diff --git a/WORKSPACE b/WORKSPACE index 74ea14d0fd7..622fa4d1412 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -49,34 +49,34 @@ remote_config_workspace() # Apple and Swift rules. http_archive( name = "build_bazel_rules_apple", - sha256 = "6efdde60c91724a2be7f89b0c0a64f01138a45e63ba5add2dca2645d981d23a1", - urls = ["https://github.com/bazelbuild/rules_apple/releases/download/0.17.2/rules_apple.0.17.2.tar.gz"], + sha256 = "a045a436b642c70fb0c10ca84ff0fd2dcbd59cc89100d597a61e8374afafb366", + urls = ["https://github.com/bazelbuild/rules_apple/releases/download/0.18.0/rules_apple.0.18.0.tar.gz"], ) # https://github.com/bazelbuild/rules_apple/releases http_archive( name = "build_bazel_rules_swift", - sha256 = "96a86afcbdab215f8363e65a10cf023b752e90b23abf02272c4fc668fcb70311", - urls = ["https://github.com/bazelbuild/rules_swift/releases/download/0.11.1/rules_swift.0.11.1.tar.gz"], + sha256 = "18cd4df4e410b0439a4935f9ca035bd979993d42372ba79e7f2d4fafe9596ef0", + urls = ["https://github.com/bazelbuild/rules_swift/releases/download/0.12.1/rules_swift.0.12.1.tar.gz"], ) # https://github.com/bazelbuild/rules_swift/releases http_archive( name = "build_bazel_apple_support", - sha256 = "7356dbd44dea71570a929d1d4731e870622151a5f27164d966dda97305f33471", - urls = ["https://github.com/bazelbuild/apple_support/releases/download/0.6.0/apple_support.0.6.0.tar.gz"], + sha256 = "122ebf7fe7d1c8e938af6aeaee0efe788a3a2449ece5a8d6a428cb18d6f88033", + urls = ["https://github.com/bazelbuild/apple_support/releases/download/0.7.1/apple_support.0.7.1.tar.gz"], ) # https://github.com/bazelbuild/apple_support/releases http_archive( name = "bazel_skylib", - sha256 = "2ef429f5d7ce7111263289644d233707dba35e39696377ebab8b0bc701f7818e", - urls = ["https://github.com/bazelbuild/bazel-skylib/releases/download/0.8.0/bazel-skylib.0.8.0.tar.gz"], + sha256 = "1dde365491125a3db70731e25658dfdd3bc5dbdfd11b840b3e987ecf043c7ca0", + urls = ["https://github.com/bazelbuild/bazel-skylib/releases/download/0.9.0/bazel-skylib.0.9.0.tar.gz"], ) # https://github.com/bazelbuild/bazel-skylib/releases http_archive( name = "com_github_apple_swift_swift_protobuf", type = "zip", - strip_prefix = "swift-protobuf-1.5.0/", - urls = ["https://github.com/apple/swift-protobuf/archive/1.5.0.zip"], + strip_prefix = "swift-protobuf-1.6.0/", + urls = ["https://github.com/apple/swift-protobuf/archive/1.6.0.zip"], ) # https://github.com/apple/swift-protobuf/releases http_file( name = "xctestrunner", executable = 1, - urls = ["https://github.com/google/xctestrunner/releases/download/0.2.7/ios_test_runner.par"], + urls = ["https://github.com/google/xctestrunner/releases/download/0.2.9/ios_test_runner.par"], ) # https://github.com/google/xctestrunner/releases # Use `swift_rules_dependencies` to fetch the toolchains. With the # `git_repository` rules above, the following call will skip redefining them. 
diff --git a/arm_compiler.BUILD b/arm_compiler.BUILD index db2e9bbe1e1..cffe3fac70d 100644 --- a/arm_compiler.BUILD +++ b/arm_compiler.BUILD @@ -3,56 +3,56 @@ package(default_visibility = ["//visibility:public"]) filegroup( name = "gcc", srcs = [ - "bin/arm-linux-gnueabihf-gcc", + "bin/arm-rpi-linux-gnueabihf-gcc", ], ) filegroup( name = "ar", srcs = [ - "bin/arm-linux-gnueabihf-ar", + "bin/arm-rpi-linux-gnueabihf-ar", ], ) filegroup( name = "ld", srcs = [ - "bin/arm-linux-gnueabihf-ld", + "bin/arm-rpi-linux-gnueabihf-ld", ], ) filegroup( name = "nm", srcs = [ - "bin/arm-linux-gnueabihf-nm", + "bin/arm-rpi-linux-gnueabihf-nm", ], ) filegroup( name = "objcopy", srcs = [ - "bin/arm-linux-gnueabihf-objcopy", + "bin/arm-rpi-linux-gnueabihf-objcopy", ], ) filegroup( name = "objdump", srcs = [ - "bin/arm-linux-gnueabihf-objdump", + "bin/arm-rpi-linux-gnueabihf-objdump", ], ) filegroup( name = "strip", srcs = [ - "bin/arm-linux-gnueabihf-strip", + "bin/arm-rpi-linux-gnueabihf-strip", ], ) filegroup( name = "as", srcs = [ - "bin/arm-linux-gnueabihf-as", + "bin/arm-rpi-linux-gnueabihf-as", ], ) diff --git a/configure.py b/configure.py index 4391fab507a..115c170238f 100644 --- a/configure.py +++ b/configure.py @@ -1388,11 +1388,18 @@ def main(): if (environ_cp.get('TF_NEED_CUDA') == '1' and 'TF_CUDA_CONFIG_REPO' not in environ_cp): + tensor_rt_question = ( + 'Do you wish to build TensorFlow with TensorRT support? NB! There ' + + 'are known ODR violations between TensorRT and cuDNN that may result ' + + 'in application crashes and/or data corruption. Please see ' + + 'https://github.com/tensorflow/tensorflow/issues/32480 for details.') + set_action_env_var( environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False, + question=tensor_rt_question, bazel_config_name='tensorrt') environ_save = dict(environ_cp) diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index ed4f10e0f77..ae6e582a421 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -159,7 +159,7 @@ TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt, TF_Status* status) { Session* session; status->status = NewSession(opt->options, &session); - if (TF_GetCode(status) == TF_OK) { + if (status->status.ok()) { return new TF_DeprecatedSession({session}); } else { DCHECK_EQ(nullptr, session); @@ -332,7 +332,7 @@ bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) { // TODO(nolivia): check this on a subset of the graph instead of all of // it. status->status = graph::ValidateGraphHasNoCycle(session->graph->graph); - if (TF_GetCode(status) != TF_OK) { + if (!status->status.ok()) { session->graph->mu.unlock(); return false; } @@ -352,7 +352,7 @@ bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) { *graph_def.mutable_library() = graph.flib_def().ToProto(); session->graph->mu.unlock(); status->status = session->session->Extend(std::move(graph_def)); - if (TF_GetCode(status) != TF_OK) { + if (!status->status.ok()) { // Contract is we always delete input_values[i]. 
return false; } @@ -382,7 +382,7 @@ static bool TF_Run_Inputs(TF_Tensor* const* c_inputs, const int ninputs = input_pairs->size(); for (int i = 0; i < ninputs; ++i) { status->status = TF_TensorToTensor(c_inputs[i], &(*input_pairs)[i].second); - if (TF_GetCode(status) != TF_OK) return false; + if (!status->status.ok()) return false; } return true; } @@ -439,7 +439,7 @@ static void TF_Run_Helper( // Serialize back to upstream client, who now owns the new buffer if (run_metadata != nullptr) { status->status = MessageToBuffer(run_metadata_proto, run_metadata); - if (TF_GetCode(status) != TF_OK) return; + if (!status->status.ok()) return; } } else { // NOTE(zongheng): PRun does not support RunOptions yet. @@ -459,7 +459,7 @@ static void TF_Run_Helper( continue; } c_outputs[i] = TF_TensorFromTensor(src, status); - if (TF_GetCode(status) != TF_OK) return; + if (!status->status.ok()) return; } } @@ -516,7 +516,7 @@ void TF_PRunSetup(TF_DeprecatedSession* s, string new_handle; status->status = s->session->PRunSetup(input_names, output_names, target_oper_names, &new_handle); - if (TF_GetCode(status) == TF_OK) { + if (status->status.ok()) { char* buf = new char[new_handle.size() + 1]; memcpy(buf, new_handle.c_str(), new_handle.size() + 1); *handle = buf; @@ -555,7 +555,7 @@ TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) { status->status = tensorflow::LoadLibrary( library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data, &lib_handle->op_list.length); - if (TF_GetCode(status) != TF_OK) { + if (!status->status.ok()) { delete lib_handle; return nullptr; } @@ -983,7 +983,7 @@ void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name, TF_Tensor* value, TF_Status* status) { Tensor t; status->status = TF_TensorToTensor(value, &t); - if (TF_GetCode(status) == TF_OK) desc->node_builder.Attr(attr_name, t); + if (status->status.ok()) desc->node_builder.Attr(attr_name, t); } void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name, @@ -993,13 +993,13 @@ void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name, std::vector t; t.reserve(num_values); - for (int i = 0; i < num_values && TF_GetCode(status) == TF_OK; ++i) { + for (int i = 0; i < num_values && status->status.ok(); ++i) { Tensor v; status->status = TF_TensorToTensor(values[i], &v); t.emplace_back(v); } - if (TF_GetCode(status) == TF_OK) desc->node_builder.Attr(attr_name, t); + if (status->status.ok()) desc->node_builder.Attr(attr_name, t); } void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name, @@ -1048,11 +1048,11 @@ static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc, status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret, /*consume=*/true); - if (TF_GetCode(status) == TF_OK) { + if (status->status.ok()) { // Run shape inference function for newly added node. status->status = desc->graph->refiner.AddNode(ret); } - if (TF_GetCode(status) == TF_OK) { + if (status->status.ok()) { // Add the node to the name-to-node mapping. 
desc->graph->name_map[ret->name()] = ret; } else if (ret != nullptr) { @@ -1101,7 +1101,7 @@ int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name, NameRangeMap name_ranges; status->status = NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges); - if (TF_GetCode(status) != TF_OK) return -1; + if (!status->status.ok()) return -1; auto iter = name_ranges.find(arg_name); if (iter == name_ranges.end()) { status->status = InvalidArgument("Output arg '", arg_name, "' not found"); @@ -1123,7 +1123,7 @@ int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name, NameRangeMap name_ranges; status->status = NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr); - if (TF_GetCode(status) != TF_OK) return -1; + if (!status->status.ok()) return -1; auto iter = name_ranges.find(arg_name); if (iter == name_ranges.end()) { status->status = InvalidArgument("Input arg '", arg_name, "' not found"); @@ -1142,6 +1142,16 @@ TF_Output TF_OperationInput(TF_Input oper_in) { return {ToOperation(edge->src()), edge->src_output()}; } +void TF_OperationAllInputs(TF_Operation* oper, TF_Output* inputs, + int max_inputs) { + for (auto* edge : oper->node.in_edges()) { + if (edge->dst_input() >= 0 && edge->dst_input() < max_inputs) { + inputs[edge->dst_input()] = {ToOperation(edge->src()), + edge->src_output()}; + } + } +} + int TF_OperationOutputNumConsumers(TF_Output oper_out) { int count = 0; for (const auto* edge : oper_out.oper->node.out_edges()) { @@ -1221,7 +1231,7 @@ TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper, TF_Status* status) { TF_AttrMetadata metadata; const auto* attr = GetAttrValue(oper, attr_name, status); - if (TF_GetCode(status) != TF_OK) return metadata; + if (!status->status.ok()) return metadata; switch (attr->value_case()) { #define SINGLE_CASE(kK, attr_type, size_expr) \ case tensorflow::AttrValue::kK: \ @@ -1328,7 +1338,7 @@ void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name, void* value, size_t max_length, TF_Status* status) { const auto* attr = GetAttrValue(oper, attr_name, status); - if (TF_GetCode(status) != TF_OK) return; + if (!status->status.ok()) return; if (attr->value_case() != tensorflow::AttrValue::kS) { status->status = InvalidArgument("Attribute '", attr_name, "' is not a string"); @@ -1346,7 +1356,7 @@ void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name, int max_values, void* storage, size_t storage_size, TF_Status* status) { const auto* attr = GetAttrValue(oper, attr_name, status); - if (TF_GetCode(status) != TF_OK) return; + if (!status->status.ok()) return; if (attr->value_case() != tensorflow::AttrValue::kList) { status->status = InvalidArgument("Value for '", attr_name, "' is not a list"); @@ -1379,7 +1389,7 @@ void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name, void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \ int max_values, TF_Status* status) { \ const auto* attr = GetAttrValue(oper, attr_name, status); \ - if (TF_GetCode(status) != TF_OK) return; \ + if (!status->status.ok()) return; \ if (attr->value_case() != tensorflow::AttrValue::kList) { \ status->status = \ InvalidArgument("Value for '", attr_name, "' is not a list."); \ @@ -1401,7 +1411,7 @@ void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name, PartialTensorShape shape; status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape); - if (TF_GetCode(status) != TF_OK) return; + if 
(!status->status.ok()) return; auto len = std::min(shape.dims(), num_dims); for (int i = 0; i < len; ++i) { value[i] = shape.dim_size(i); @@ -1415,7 +1425,7 @@ void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name, std::vector shapes; status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes); - if (TF_GetCode(status) != TF_OK) return; + if (!status->status.ok()) return; auto len = std::min(static_cast(shapes.size()), num_shapes); int64_t* p = storage; int storage_left = storage_size; @@ -1443,7 +1453,7 @@ void TF_OperationGetAttrTensorShapeProto(TF_Operation* oper, const char* attr_name, TF_Buffer* value, TF_Status* status) { const auto* attr = GetAttrValue(oper, attr_name, status); - if (TF_GetCode(status) != TF_OK) return; + if (!status->status.ok()) return; if (attr->value_case() != tensorflow::AttrValue::kShape) { status->status = InvalidArgument("Value for '", attr_name, "' is not a shape."); @@ -1457,7 +1467,7 @@ void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper, TF_Buffer** values, int max_values, TF_Status* status) { const auto* attr = GetAttrValue(oper, attr_name, status); - if (TF_GetCode(status) != TF_OK) return; + if (!status->status.ok()) return; if (attr->value_case() != tensorflow::AttrValue::kList) { status->status = InvalidArgument("Value for '", attr_name, "' is not a list"); @@ -1467,7 +1477,7 @@ void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper, for (int i = 0; i < len; ++i) { values[i] = TF_NewBuffer(); status->status = MessageToBuffer(attr->list().shape(i), values[i]); - if (TF_GetCode(status) != TF_OK) { + if (!status->status.ok()) { // Delete everything allocated to far, the operation has failed. for (int j = 0; j <= i; ++j) { TF_DeleteBuffer(values[j]); @@ -1482,7 +1492,7 @@ void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name, *value = nullptr; Tensor t; status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t); - if (TF_GetCode(status) != TF_OK) return; + if (!status->status.ok()) return; *value = TF_TensorFromTensor(t, status); } @@ -1491,7 +1501,7 @@ void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name, TF_Status* status) { std::vector ts; status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts); - if (TF_GetCode(status) != TF_OK) return; + if (!status->status.ok()) return; const auto len = std::min(max_values, static_cast(ts.size())); for (int i = 0; i < len; ++i) { values[i] = TF_TensorFromTensor(ts[i], status); @@ -1502,7 +1512,7 @@ void TF_OperationGetAttrValueProto(TF_Operation* oper, const char* attr_name, TF_Buffer* output_attr_value, TF_Status* status) { const auto* attr = GetAttrValue(oper, attr_name, status); - if (TF_GetCode(status) != TF_OK) return; + if (!status->status.ok()) return; status->status = MessageToBuffer(*attr, output_attr_value); } @@ -1583,7 +1593,7 @@ void TF_GraphGetOpDef(TF_Graph* graph, const char* op_name, { mutex_lock l(graph->mu); status->status = graph->graph.op_registry()->LookUpOpDef(op_name, &op_def); - if (TF_GetCode(status) != TF_OK) return; + if (!status->status.ok()) return; } status->status = MessageToBuffer(*op_def, output_op_def); } @@ -1701,7 +1711,7 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def, tensorflow::ImportGraphDefResults results; status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph, &graph->refiner, &results); - if (TF_GetCode(status) != TF_OK) return; + if (!status->status.ok()) return; // Add new 
nodes to name_map for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) { @@ -1755,7 +1765,7 @@ TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults( auto results = new TF_ImportGraphDefResults(); mutex_lock l(graph->mu); GraphImportGraphDefLocked(graph, def, options, results, status); - if (TF_GetCode(status) != TF_OK) { + if (!status->status.ok()) { delete results; return nullptr; } @@ -1813,7 +1823,7 @@ bool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name, TF_SetAttrType(desc, "dtype", TF_OperationOutputType(parent_input)); // TODO(skyewm): set placeholder shape TF_Operation* oper = TF_FinishOperation(desc, status); - if (TF_GetCode(status) != TF_OK) return false; + if (!status->status.ok()) return false; *input = {oper, 0}; return true; } @@ -1958,7 +1968,7 @@ TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs, TF_WhileParams params = {ninputs, cond_graph, cond_inputs, cond_output, body_graph, body_inputs, body_outputs, name}; - if (TF_GetCode(status) != TF_OK) { + if (!status->status.ok()) { FreeWhileResources(¶ms); return EmptyWhileParams(); } @@ -2160,7 +2170,7 @@ TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt, TF_Status* status) { Session* session; status->status = NewSession(opt->options, &session); - if (TF_GetCode(status) == TF_OK) { + if (status->status.ok()) { TF_Session* new_session = new TF_Session(session, graph); if (graph != nullptr) { mutex_lock l(graph->mu); @@ -2208,7 +2218,7 @@ TF_Session* TF_LoadSessionFromSavedModel( status->status = tensorflow::LoadSavedModel(session_options->options, run_options_proto, export_dir, tag_set, &bundle); - if (TF_GetCode(status) != TF_OK) return nullptr; + if (!status->status.ok()) return nullptr; // Create a TF_Graph from the MetaGraphDef. This is safe as long as Session // extends using GraphDefs. 
The Graph instance is different, but equivalent @@ -2221,11 +2231,11 @@ TF_Session* TF_LoadSessionFromSavedModel( GraphImportGraphDefLocked(graph, bundle.meta_graph_def.graph_def(), import_opts, &results, status); TF_DeleteImportGraphDefOptions(import_opts); - if (TF_GetCode(status) != TF_OK) return nullptr; + if (!status->status.ok()) return nullptr; if (meta_graph_def != nullptr) { status->status = MessageToBuffer(bundle.meta_graph_def, meta_graph_def); - if (TF_GetCode(status) != TF_OK) return nullptr; + if (!status->status.ok()) return nullptr; } TF_Session* session = new TF_Session(bundle.session.release(), graph); @@ -2325,7 +2335,7 @@ void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs, string new_handle; status->status = session->session->PRunSetup(input_names, output_names, target_names, &new_handle); - if (TF_GetCode(status) == TF_OK) { + if (status->status.ok()) { char* buf = new char[new_handle.size() + 1]; memcpy(buf, new_handle.c_str(), new_handle.size() + 1); *handle = buf; @@ -2387,9 +2397,9 @@ unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output, tensor, graph->refiner, *graph->graph.op_registry(), graph->graph.versions().producer(), &evaluated, &result_tensor); if (evaluated) { - DCHECK(TF_GetCode(status) == TF_OK); + DCHECK(status->status.ok()); *result = TF_TensorFromTensor(result_tensor, status); - if (TF_GetCode(status) != TF_OK) evaluated = false; + if (!status->status.ok()) evaluated = false; } return evaluated; } @@ -2444,7 +2454,7 @@ TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name, TF_Buffer* ret = TF_NewBuffer(); status->status = MessageToBuffer(*api_def, ret); - if (TF_GetCode(status) != TF_OK) { + if (!status->status.ok()) { TF_DeleteBuffer(ret); return nullptr; } @@ -2456,7 +2466,7 @@ TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status) { tensorflow::KernelList kernel_list = tensorflow::GetAllRegisteredKernels(); TF_Buffer* ret = TF_NewBuffer(); status->status = MessageToBuffer(kernel_list, ret); - if (TF_GetCode(status) != TF_OK) { + if (!status->status.ok()) { TF_DeleteBuffer(ret); return nullptr; } @@ -2468,7 +2478,7 @@ TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) { tensorflow::GetRegisteredKernelsForOp(name); TF_Buffer* ret = TF_NewBuffer(); status->status = MessageToBuffer(kernel_list, ret); - if (TF_GetCode(status) != TF_OK) { + if (!status->status.ok()) { TF_DeleteBuffer(ret); return nullptr; } @@ -2498,7 +2508,7 @@ TF_Server* TF_NewServer(const void* proto, size_t proto_len, std::unique_ptr out_server; status->status = tensorflow::NewServer(server_def, &out_server); - if (TF_GetCode(status) != TF_OK) return nullptr; + if (!status->status.ok()) return nullptr; return new TF_Server(std::move(out_server)); #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index 4eeedd4cbc9..0c413f6ebae 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -435,6 +435,15 @@ TF_CAPI_EXPORT extern int TF_OperationInputListLength(TF_Operation* oper, // producer.index) to consumer.oper's input (given by consumer.index). TF_CAPI_EXPORT extern TF_Output TF_OperationInput(TF_Input oper_in); +// Get list of all inputs of a specific operation. `inputs` must point to +// an array of length at least `max_inputs` (ideally set to +// TF_OperationNumInputs(oper)). Beware that a concurrent +// modification of the graph can increase the number of inputs of +// an operation. 
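// Editor's sketch (not part of the patch): one way a caller might use the new
// TF_OperationAllInputs entry point declared just below. `op` is assumed to be a
// valid TF_Operation* obtained from an existing TF_Graph.
#include <cstdio>
#include <vector>
#include "tensorflow/c/c_api.h"

void PrintAllInputs(TF_Operation* op) {
  const int n = TF_OperationNumInputs(op);      // size the caller-owned buffer up front
  std::vector<TF_Output> inputs(n);
  TF_OperationAllInputs(op, inputs.data(), n);  // fills inputs[0..n-1]
  for (int i = 0; i < n; ++i) {
    std::printf("input %d <- %s:%d\n", i, TF_OperationName(inputs[i].oper),
                inputs[i].index);
  }
}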
+TF_CAPI_EXPORT extern void TF_OperationAllInputs(TF_Operation* oper, + TF_Output* inputs, + int max_inputs); + // Get the number of current consumers of a specific output of an // operation. Note that this number can change when new operations // are added to the graph. diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index f04f0175696..1fe9276ffc6 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -510,10 +510,6 @@ TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id, return createTFEDequeue(ctx, TF_VARIANT, queue, status); } -static void CheckOk(TF_Status* status) { - CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); -} - void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) { auto* status = TF_NewStatus(); if (!TFE_TensorHandleIsConcrete(handle)) { diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 2742bead4e4..95005971e91 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -9,7 +9,6 @@ load( ) load( "//tensorflow/core/platform:default/build_config.bzl", - "tf_additional_device_tracer_test_flags", "tf_kernel_tests_linkstatic", ) load( @@ -27,6 +26,7 @@ tf_cuda_library( "c_api.cc", "c_api_debug.cc", "c_api_experimental.h", + "c_api_internal.cc", "c_api_internal.h", ], hdrs = ["c_api.h"], @@ -237,8 +237,7 @@ tf_cuda_cc_test( srcs = [ "c_api_experimental_test.cc", ], - args = - ["--heap_check=local"] + tf_additional_device_tracer_test_flags(), + args = ["--heap_check=local"], extra_copts = tfe_xla_copts(), linkstatic = tf_kernel_tests_linkstatic(), tags = tf_cuda_tests_tags() + ["nomac"], diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index ff16eaf322d..10a1fa42f57 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -33,7 +33,7 @@ limitations under the License. #include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/framework/device_attributes.pb.h" -#include "tensorflow/core/platform/host_info.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/platform.h" // NOLINT #ifdef TENSORFLOW_EAGER_USE_XLA #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -61,6 +61,7 @@ limitations under the License. #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/blocking_counter.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/cleanup.h" @@ -100,32 +101,34 @@ string DeviceName(const tensorflow::Device* d) { tensorflow::Status GetAllRemoteDevices( const std::vector& remote_workers, tensorflow::WorkerCacheInterface* worker_cache, - std::unique_ptr* device_mgr) { + std::unique_ptr* device_mgr) { std::vector> remote_devices; - tensorflow::Status status; - // TODO(nareshmodi) do this in parallel instead of serially. 
- for (const string& remote_worker : remote_workers) { - tensorflow::Notification n; + tensorflow::mutex remote_devices_mu; + int num_remote_workers = remote_workers.size(); + tensorflow::BlockingCounter counter(num_remote_workers); + std::vector statuses(num_remote_workers); + for (int i = 0; i < num_remote_workers; i++) { tensorflow::NewRemoteDevices( - tensorflow::Env::Default(), worker_cache, remote_worker, - [&status, &n, &remote_devices]( + tensorflow::Env::Default(), worker_cache, remote_workers[i], + [i, &statuses, &counter, &remote_devices, &remote_devices_mu]( const tensorflow::Status& s, std::vector* devices) { - status = s; + statuses[i] = s; if (s.ok()) { + tensorflow::mutex_lock l(remote_devices_mu); for (tensorflow::Device* d : *devices) { remote_devices.emplace_back(d); } } - n.Notify(); + counter.DecrementCount(); }); - n.WaitForNotification(); } - std::unique_ptr remote_device_mgr( - new tensorflow::StaticDeviceMgr(std::move(remote_devices))); - - TF_RETURN_IF_ERROR(status); - + counter.Wait(); + for (int i = 0; i < num_remote_workers; i++) { + TF_RETURN_IF_ERROR(statuses[i]); + } + auto remote_device_mgr = absl::make_unique(); + TF_RETURN_IF_ERROR(remote_device_mgr->AddDevices(std::move(remote_devices))); *device_mgr = std::move(remote_device_mgr); return tensorflow::Status::OK(); } @@ -135,11 +138,15 @@ tensorflow::Status CreateRemoteContexts( int keep_alive_secs, const tensorflow::ServerDef& server_def, tensorflow::eager::EagerClientCache* remote_eager_workers, bool async, const tensorflow::eager::CreateContextRequest& base_request) { - for (int i = 0; i < remote_workers.size(); i++) { + int num_remote_workers = remote_workers.size(); + tensorflow::BlockingCounter counter(num_remote_workers); + std::vector statuses(num_remote_workers); + for (int i = 0; i < num_remote_workers; i++) { const string& remote_worker = remote_workers[i]; tensorflow::eager::CreateContextRequest request(base_request); - tensorflow::eager::CreateContextResponse response; + tensorflow::eager::CreateContextResponse* response = + new tensorflow::eager::CreateContextResponse(); request.set_context_id(context_id); tensorflow::DeviceNameUtils::ParsedName parsed_name; if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker, @@ -159,16 +166,17 @@ tensorflow::Status CreateRemoteContexts( return tensorflow::errors::Internal( "Cannot find a client for the given target:", remote_worker); } - tensorflow::Notification n; - tensorflow::Status status; - // TODO(nareshmodi) do this in parallel instead of serially. 
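// Editor's sketch (not part of the patch): the fan-out/fan-in shape that the two
// loops this patch parallelizes in this file now follow. `DoAsyncWork` is a
// hypothetical stand-in for NewRemoteDevices / CreateContextAsync.
#include <functional>
#include <vector>
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/errors.h"

void DoAsyncWork(int i, std::function<void(const tensorflow::Status&)> done);  // assumed async API

tensorflow::Status RunAllInParallel(int num_workers) {
  tensorflow::BlockingCounter counter(num_workers);
  std::vector<tensorflow::Status> statuses(num_workers);
  for (int i = 0; i < num_workers; ++i) {
    DoAsyncWork(i, [i, &statuses, &counter](const tensorflow::Status& s) {
      statuses[i] = s;            // each callback writes its own slot, so no mutex needed here
      counter.DecrementCount();   // one tick per completed worker
    });
  }
  counter.Wait();                 // fan-in: block until every callback has fired
  for (const auto& s : statuses) TF_RETURN_IF_ERROR(s);
  return tensorflow::Status::OK();
}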
eager_client->CreateContextAsync( - &request, &response, [&status, &n](const tensorflow::Status& s) { - status = s; - n.Notify(); + &request, response, + [i, &statuses, &counter, response](const tensorflow::Status& s) { + statuses[i] = s; + delete response; + counter.DecrementCount(); }); - n.WaitForNotification(); - TF_RETURN_IF_ERROR(status); + } + counter.Wait(); + for (int i = 0; i < num_remote_workers; i++) { + TF_RETURN_IF_ERROR(statuses[i]); } return tensorflow::Status::OK(); } @@ -215,7 +223,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( std::remove(remote_workers.begin(), remote_workers.end(), worker_name), remote_workers.end()); - std::unique_ptr remote_device_mgr; + std::unique_ptr remote_device_mgr; LOG_AND_RETURN_IF_ERROR(GetAllRemoteDevices( remote_workers, grpc_server->master_env()->worker_cache, &remote_device_mgr)); @@ -247,7 +255,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( LOG_AND_RETURN_IF_ERROR( CreateRemoteContexts(remote_workers, context_id, keep_alive_secs, server_def, remote_eager_workers.get(), - ctx->context->Executor()->Async(), base_request)); + ctx->context->Executor().Async(), base_request)); tensorflow::RemoteRendezvous* r = grpc_server->worker_env()->rendezvous_mgr->Find(context_id); @@ -564,7 +572,7 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { const tensorflow::Tensor* t = nullptr; tensorflow::TensorHandle* h_cpu = nullptr; status->status = EagerCopyToDevice( - handle, handle->Context(), handle->Context()->Executor(), + handle, handle->Context(), &handle->Context()->Executor(), handle->Context()->HostCPU(), false, &h_cpu); if (!status->status.ok()) { return nullptr; @@ -596,33 +604,8 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, TF_Status* status) { - const char* name = op_or_function_name; // Shorthand - const tensorflow::AttrTypeMap* types; - bool is_function = false; - status->status = tensorflow::AttrTypeMapForOp(name, &types, &is_function); - if (!status->status.ok()) { - return nullptr; - } - if (!is_function) { - const tensorflow::OpDef* op_def; - status->status = tensorflow::OpDefForOp(op_or_function_name, &op_def); - if (!status->status.ok()) { - return nullptr; - } - return new TFE_Op(ctx, name, false, types, - new TFE_OpInferenceContext(op_def)); - } - if (!ctx->context->FindFunctionByName(name)) { - status->status = tensorflow::errors::NotFound( - "'", name, - "' is neither a type of a primitive operation nor a name " - "of a function registered in binary running on ", - tensorflow::port::Hostname(), - ". 
Make sure the operation or function is " - "registered in the binary running in this process."); - return nullptr; - } - return new TFE_Op(ctx, name, true, types, nullptr); + return NewOrResetOp(ctx, op_or_function_name, status, + /* op_to_reset= */ nullptr); } void TFE_DeleteOp(TFE_Op* op) { delete op; } @@ -916,7 +899,7 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, return nullptr; } status->status = tensorflow::EagerCopyToDevice(h->handle, ctx->context, - ctx->context->Executor(), + &ctx->context->Executor(), device, false, &handle); if (status->status.ok()) { return new TFE_TensorHandle(handle); @@ -967,7 +950,7 @@ TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t, void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, TF_Status* status) { - status->status = ctx->context->Executor()->WaitForAllPendingNodes(); + status->status = ctx->context->Executor().WaitForAllPendingNodes(); if (!status->status.ok()) return; tensorflow::mutex_lock ml(*ctx->context->MetadataMu()); status->status = MessageToBuffer(*ctx->context->RunMetadataProto(), buf); @@ -979,9 +962,9 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func, TF_Status* status) { TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status); for (const auto& attr : func.attr()) { - if (TF_GetCode(status) != TF_OK) return nullptr; + if (!status->status.ok()) return nullptr; SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status); - if (TF_GetCode(status) != TF_OK) return nullptr; + if (!status->status.ok()) return nullptr; } return func_op; } @@ -1029,7 +1012,7 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, } break; case tensorflow::AttrValue::kFunc: { const auto func_op = GetFunc(ctx, default_value.func(), status); - if (TF_GetCode(status) != TF_OK) return; + if (!status->status.ok()) return; // TODO(nareshmodi): TFE_OpSetAttrFunction and TFE_OpSetAttrFunctionList // require TFE_Op* and just convert it internally a NameAttrValue, so // consider adding an overload to the C API to make this case easier. diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index a9ad77198e7..a40a435065f 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -28,6 +28,16 @@ limitations under the License. using tensorflow::string; +void TFE_OpReset(TFE_Context* ctx, const char* op_or_function_name, + TF_Status* status, TFE_Op* op_to_reset) { + if (op_to_reset) { + NewOrResetOp(ctx, op_or_function_name, status, op_to_reset); + } else { + TF_SetStatus(status, TF_INVALID_ARGUMENT, + "op_to_reset should not be nullptr"); + } +} + void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) { op->operation.ConsumeInput(h->handle); } @@ -597,5 +607,5 @@ void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) { } TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) { - return new TFE_Executor(ctx->context->Executor()); + return new TFE_Executor(&ctx->context->Executor()); } diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h index e5a9459faff..cafef707706 100644 --- a/tensorflow/c/eager/c_api_experimental.h +++ b/tensorflow/c/eager/c_api_experimental.h @@ -22,6 +22,10 @@ limitations under the License. 
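// Editor's sketch (not part of the patch): the allocation-reuse pattern enabled by
// TFE_OpReset, which this header hunk declares just below. `ctx` and `status` are
// assumed to be a live TFE_Context* and TF_Status*.
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"

void RunMatMulTwice(TFE_Context* ctx, TF_Status* status) {
  TFE_Op* op = TFE_NewOp(ctx, "MatMul", status);   // first call allocates the op
  // ... add inputs, set attrs, TFE_Execute(op, ...) ...
  TFE_OpReset(ctx, "MatMul", status, op);          // later calls recycle the same TFE_Op
  // ... add inputs, set attrs, TFE_Execute(op, ...) ...
  TFE_DeleteOp(op);
}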
extern "C" { #endif +TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Context* ctx, + const char* op_or_function_name, + TF_Status* status, TFE_Op* op_to_reset); + TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status); diff --git a/tensorflow/c/eager/c_api_experimental_test.cc b/tensorflow/c/eager/c_api_experimental_test.cc index ab76ad10adc..95165b0c5dc 100644 --- a/tensorflow/c/eager/c_api_experimental_test.cc +++ b/tensorflow/c/eager/c_api_experimental_test.cc @@ -84,11 +84,6 @@ void ExecuteWithProfiling(bool async) { string profile_proto_str = profile_proto.DebugString(); if (!gpu_device_name.empty()) { EXPECT_TRUE(HasSubstr(profile_proto_str, "/device:GPU:0")); - // device name with "stream:all" is collected by Device Tracer. -#ifndef TENSORFLOW_USE_ROCM - // ROCm platform does not yet support stream level tracing - EXPECT_TRUE(HasSubstr(profile_proto_str, "stream:all")); -#endif } // "/host:CPU" is collected by TraceMe EXPECT_TRUE(HasSubstr(profile_proto_str, "/host:CPU")); diff --git a/tensorflow/c/eager/c_api_internal.cc b/tensorflow/c/eager/c_api_internal.cc new file mode 100644 index 00000000000..772fae13faf --- /dev/null +++ b/tensorflow/c/eager/c_api_internal.cc @@ -0,0 +1,58 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/eager/c_api_internal.h" + +#include "tensorflow/core/platform/host_info.h" + +TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name, + TF_Status* status, TFE_Op* op_to_reset) { + const char* name = op_or_function_name; // Shorthand + const tensorflow::AttrTypeMap* types; + bool is_function = false; + status->status = tensorflow::AttrTypeMapForOp(name, &types, &is_function); + if (!status->status.ok()) { + return nullptr; + } + auto create_or_reset = [&op_to_reset, &ctx, &name, &types]( + bool is_function, + TFE_OpInferenceContext* inference_ctx) -> TFE_Op* { + if (op_to_reset) { + op_to_reset->Reset(ctx, name, is_function, types, inference_ctx); + return op_to_reset; + } else { + return new TFE_Op(ctx, name, is_function, types, inference_ctx); + } + }; + + if (!is_function) { + const tensorflow::OpDef* op_def; + status->status = tensorflow::OpDefForOp(op_or_function_name, &op_def); + if (!status->status.ok()) { + return nullptr; + } + return create_or_reset(false, new TFE_OpInferenceContext(op_def)); + } + if (!ctx->context->FindFunctionByName(name)) { + status->status = tensorflow::errors::NotFound( + "'", name, + "' is neither a type of a primitive operation nor a name " + "of a function registered in binary running on ", + tensorflow::port::Hostname(), + ". 
Make sure the operation or function is " + "registered in the binary running in this process."); + return nullptr; + } + return create_or_reset(true, nullptr); +} diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 5efed2ca76d..964e558a01f 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -133,10 +133,25 @@ struct TFE_Op { : operation(ctx->context, op, is_function, t), inference_ctx(inference_ctx) {} + void Clear() { + operation.Clear(); + inference_ctx.reset(); + } + + void Reset(TFE_Context* ctx, const char* op, bool is_function, + const tensorflow::AttrTypeMap* t, + TFE_OpInferenceContext* infer_ctx) { + operation.Reset(ctx->context, op, is_function, t, nullptr); + inference_ctx.reset(infer_ctx); + } + tensorflow::EagerOperation operation; std::unique_ptr inference_ctx; }; +TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name, + TF_Status* status, TFE_Op* op_to_reset = nullptr); + struct TFE_Profiler { explicit TFE_Profiler() { profiler = tensorflow::ProfilerSession::Create(); } diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index d3b755fee6e..6702e26e66d 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -1069,10 +1069,13 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) { // still fail. TF_SetStatus(status, TF_OK, ""); TFE_DeleteTensorHandle(retvals[0]); + TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); + TFE_ExecutorWaitForAllPendingNodes(executor, status); + EXPECT_NE(TF_OK, TF_GetCode(status)); + TF_SetStatus(status, TF_OK, ""); retvals[0] = nullptr; TFE_Execute(matmul2, &retvals[0], &num_retvals, status); EXPECT_NE(TF_OK, TF_GetCode(status)); - TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); TFE_ExecutorClearError(executor); TFE_ExecutorWaitForAllPendingNodes(executor, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index 919e2dfc638..cc13dcf9976 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -292,7 +292,9 @@ string ToCamelCase(const string& str) { bool cap = true; while (i < str.size()) { const char c = str[i++]; - if (c == joiner) { + if (c == '>') { + cap = true; + } else if (c == joiner) { cap = true; } else if (cap) { result += toupper(c); @@ -304,6 +306,21 @@ string ToCamelCase(const string& str) { return result; } +string SeparateNamespaces(const string& str) { + string result; + const char joiner = '_'; + size_t i = 0; + while (i < str.size()) { + const char c = str[i++]; + if (c == '>') { + result += joiner; + } else { + result += c; + } + } + return result; +} + // Returns a pair. The string is the C++ type name to be used for // attr_type when defining an object of that type. 
The bool is a flag to // indicate whether to treat the type as const when accepting the C++ type as an @@ -549,7 +566,7 @@ struct OpInfo { OpInfo::OpInfo(const OpDef& graph_op_def, const ApiDef& api_def, const std::vector& aliases) : graph_op_def(graph_op_def), api_def(api_def), aliases(aliases) { - op_name = api_def.endpoint(0).name(); + op_name = SeparateNamespaces(api_def.endpoint(0).name()); InferOpAttributes(graph_op_def, &inferred_input_attrs); has_optional_attrs = HasOptionalAttrs(api_def, inferred_input_attrs); arg_types.push_back("const ::tensorflow::Scope&"); diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index b2affdd9993..c9786fa8b7e 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -127,8 +127,29 @@ cc_library( ) tf_cc_test( - name = "loader_test", - srcs = ["loader_test.cc"], + name = "saved_model_bundle_test", + srcs = ["saved_model_bundle_test.cc"], + data = [ + ":saved_model_half_plus_two", + ], + linkstatic = 1, + deps = [ + ":constants", + ":loader", + ":signature_constants", + ":tag_constants", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +tf_cc_test( + name = "saved_model_bundle_lite_test", + srcs = ["saved_model_bundle_lite_test.cc"], data = [ ":saved_model_half_plus_two", ], diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index a3b80fbdba5..0aec4f42aee 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -299,6 +299,8 @@ Status LoadSavedModelInternal(const SessionOptions& session_options, } // namespace +SavedModelBundleInterface::~SavedModelBundleInterface() {} + Status LoadSavedModel(const SessionOptions& session_options, const RunOptions& run_options, const string& export_dir, const std::unordered_set& tags, @@ -323,6 +325,133 @@ Status LoadSavedModel(const SessionOptions& session_options, return status; } +namespace { +// Session wrapper that prevents calls to Session::Create(), Session::Extend(), +// and the deprecated partial-run methods. +// +// Limiting the available methods on a returned Session gives us the option +// to replace the Session with a cut-down implementation, without breaking any +// users. 
+class LiteSessionWrapper : public Session { + public: + explicit LiteSessionWrapper(std::unique_ptr wrapped) + : wrapped_(std::move(wrapped)) {} + + Status Create(const GraphDef& graph) override { + return errors::Unimplemented("Session::Create()"); + } + Status Create(GraphDef&& graph) override { + return errors::Unimplemented("Session::Create()"); + } + + Status Extend(const GraphDef& graph) override { + return errors::Unimplemented("Session::Extend()"); + } + Status Extend(GraphDef&& graph) override { + return errors::Unimplemented("Session::Extend()"); + } + + Status Run(const std::vector>& inputs, + const std::vector& output_tensor_names, + const std::vector& target_node_names, + std::vector* outputs) override { + return wrapped_->Run(inputs, output_tensor_names, target_node_names, + outputs); + } + + Status Create(const RunOptions& run_options, const GraphDef& graph) override { + return errors::Unimplemented("Session::Create()"); + } + Status Extend(const RunOptions& run_options, const GraphDef& graph) override { + return errors::Unimplemented("Session::Extend()"); + } + Status Create(const RunOptions& run_options, GraphDef&& graph) override { + return errors::Unimplemented("Session::Create()"); + } + Status Extend(const RunOptions& run_options, GraphDef&& graph) override { + return errors::Unimplemented("Session::Extend()"); + } + Status Close(const RunOptions& run_options) override { + return wrapped_->Close(run_options); + } + + Status Run(const RunOptions& run_options, + const std::vector>& inputs, + const std::vector& output_tensor_names, + const std::vector& target_node_names, + std::vector* outputs, RunMetadata* run_metadata) override { + return wrapped_->Run(run_options, inputs, output_tensor_names, + target_node_names, outputs, run_metadata); + } + + Status PRunSetup(const std::vector& input_names, + const std::vector& output_names, + const std::vector& target_nodes, + string* handle) override { + return errors::Unimplemented("Session::PRunSetup()"); + } + + Status PRun(const string& handle, + const std::vector>& inputs, + const std::vector& output_names, + std::vector* outputs) override { + return errors::Unimplemented("Session::PRun()"); + } + + Status ListDevices(std::vector* response) override { + return wrapped_->ListDevices(response); + } + + Status Close() override { return wrapped_->Close(); } + + Status MakeCallable(const CallableOptions& callable_options, + CallableHandle* out_handle) override { + return wrapped_->MakeCallable(callable_options, out_handle); + } + + Status RunCallable(CallableHandle handle, + const std::vector& feed_tensors, + std::vector* fetch_tensors, + RunMetadata* run_metadata) override { + return wrapped_->RunCallable(handle, feed_tensors, fetch_tensors, + run_metadata); + } + + Status RunCallable( + CallableHandle handle, const std::vector& feed_tensors, + std::vector* fetch_tensors, RunMetadata* run_metadata, + const thread::ThreadPoolOptions& threadpool_options) override { + return wrapped_->RunCallable(handle, feed_tensors, fetch_tensors, + run_metadata, threadpool_options); + } + + Status ReleaseCallable(CallableHandle handle) override { + return wrapped_->ReleaseCallable(handle); + } + + private: + const std::unique_ptr wrapped_; +}; +} // namespace + +Status LoadSavedModel(const SessionOptions& session_options, + const RunOptions& run_options, const string& export_dir, + const std::unordered_set& tags, + SavedModelBundleLite* const bundle) { + SavedModelBundle legacy_bundle; + SessionOptions rewritten_options(session_options); + 
rewritten_options.config.mutable_experimental() + ->set_optimize_for_static_graph(true); + // TODO(mrry): Consider specializing the session creation to reduce peak + // RAM consumption by using `Session::Create(GraphDef&&)`. + TF_RETURN_IF_ERROR(LoadSavedModel(session_options, run_options, export_dir, + tags, &legacy_bundle)); + *bundle = SavedModelBundleLite( + absl::make_unique(std::move(legacy_bundle.session)), + std::move(*legacy_bundle.meta_graph_def.mutable_signature_def())); + return Status::OK(); +} + bool MaybeSavedModelDirectory(const string& export_dir) { const string saved_model_pb_path = io::JoinPath(export_dir, kSavedModelFilenamePb); diff --git a/tensorflow/cc/saved_model/loader.h b/tensorflow/cc/saved_model/loader.h index a8e098fa544..74094a0cc23 100644 --- a/tensorflow/cc/saved_model/loader.h +++ b/tensorflow/cc/saved_model/loader.h @@ -27,31 +27,96 @@ limitations under the License. namespace tensorflow { -/// SavedModel representation once the SavedModel is loaded from storage. -struct SavedModelBundle { - std::unique_ptr session; - MetaGraphDef meta_graph_def; +/// Represents a SavedModel that is loaded from storage. +class SavedModelBundleInterface { + public: + virtual ~SavedModelBundleInterface(); + /// Returns the TensorFlow Session that can be used to interact with the + /// SavedModel. + virtual Session* GetSession() const = 0; + + /// Returns a map from signature name to SignatureDef for all signatures in + /// in the SavedModel. + virtual const protobuf::Map& GetSignatures() const = 0; +}; + +/// SavedModel representation once the SavedModel is loaded from storage. +/// +/// NOTE: Prefer to use SavedModelBundleLite in new code, as it consumes less +/// RAM. +struct SavedModelBundle : public SavedModelBundleInterface { /// A TensorFlow Session does not Close itself on destruction. To avoid /// resource leaks, we explicitly call Close on Sessions that we create. - ~SavedModelBundle() { + ~SavedModelBundle() override { if (session) { session->Close().IgnoreError(); } } SavedModelBundle() = default; + + Session* GetSession() const override { return session.get(); } + const protobuf::Map& GetSignatures() const override { + return meta_graph_def.signature_def(); + } + + std::unique_ptr session; + MetaGraphDef meta_graph_def; }; -/// Loads a SavedModel from the specified export directory. The meta graph def +// A version of SavedModelBundle that avoids storing a potentially large +// MetaGraphDef. Prefer to use SavedModelBundleLite in new code. +class SavedModelBundleLite : public SavedModelBundleInterface { + public: + SavedModelBundleLite() = default; + SavedModelBundleLite& operator=(SavedModelBundleLite&& other) = default; + + SavedModelBundleLite(std::unique_ptr session, + protobuf::Map signatures) + : session_(std::move(session)), signatures_(std::move(signatures)) {} + + /// A TensorFlow Session does not Close itself on destruction. To avoid + /// resource leaks, we explicitly call Close on Sessions that we create. + ~SavedModelBundleLite() override { + if (session_) { + session_->Close().IgnoreError(); + } + } + + Session* GetSession() const override { return session_.get(); } + const protobuf::Map& GetSignatures() const override { + return signatures_; + } + + private: + std::unique_ptr session_; + protobuf::Map signatures_; +}; + +/// Loads a SavedModel from the specified export directory. The MetaGraphDef /// to be loaded is identified by the supplied tags, corresponding exactly to -/// the set of tags used at SavedModel build time. 
Returns a SavedModel bundle -/// with a session and the requested meta graph def, if found. +/// the set of tags used at SavedModel build time. Stores a SavedModel bundle in +/// *bundle with a session and the requested MetaGraphDef, if found. +/// +/// NOTE: Prefer the overload that takes a SavedModelBundleLite* in new code. Status LoadSavedModel(const SessionOptions& session_options, const RunOptions& run_options, const string& export_dir, const std::unordered_set& tags, SavedModelBundle* const bundle); +/// Loads a SavedModel from the specified export directory. The MetaGraphDef +/// to be loaded is identified by the supplied tags, corresponding exactly to +/// the set of tags used at SavedModel build time. Stores a SavedModel bundle +/// in *bundle with a session created from the requested MetaGraphDef if found. +/// +/// This overload creates a SavedModelBundleLite, which consumes less RAM than +/// an equivalent SavedModelBundle. +Status LoadSavedModel(const SessionOptions& session_options, + const RunOptions& run_options, const string& export_dir, + const std::unordered_set& tags, + SavedModelBundleLite* const bundle); + /// Checks whether the provided directory could contain a SavedModel. Note that /// the method does not load any data by itself. If the method returns `false`, /// the export directory definitely does not contain a SavedModel. If the method diff --git a/tensorflow/cc/saved_model/saved_model_bundle_lite_test.cc b/tensorflow/cc/saved_model/saved_model_bundle_lite_test.cc new file mode 100644 index 00000000000..7ef0b828425 --- /dev/null +++ b/tensorflow/cc/saved_model/saved_model_bundle_lite_test.cc @@ -0,0 +1,244 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/cc/saved_model/loader.h" + +#include "tensorflow/cc/saved_model/constants.h" +#include "tensorflow/cc/saved_model/signature_constants.h" +#include "tensorflow/cc/saved_model/tag_constants.h" +#include "tensorflow/core/example/example.pb.h" +#include "tensorflow/core/example/feature.pb.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +constexpr char kTestDataPbTxt[] = + "cc/saved_model/testdata/half_plus_two_pbtxt/00000123"; +constexpr char kTestDataMainOp[] = + "cc/saved_model/testdata/half_plus_two_main_op/00000123"; +constexpr char kTestDataSharded[] = + "cc/saved_model/testdata/half_plus_two/00000123"; +constexpr char kTestDataInitOpV2[] = + "cc/saved_model/testdata/half_plus_two_v2/00000123"; + +class LoaderTest : public ::testing::Test { + protected: + LoaderTest() {} + + string MakeSerializedExample(float x) { + tensorflow::Example example; + auto* feature_map = example.mutable_features()->mutable_feature(); + (*feature_map)["x"].mutable_float_list()->add_value(x); + return example.SerializeAsString(); + } + + void ValidateAssets(const string& export_dir, + const SavedModelBundleLite& bundle) { + const string asset_directory = + io::JoinPath(export_dir, kSavedModelAssetsDirectory); + const string asset_filename = "foo.txt"; + const string asset_filepath = io::JoinPath(asset_directory, asset_filename); + TF_EXPECT_OK(Env::Default()->FileExists(asset_filepath)); + + std::vector path_outputs; + TF_ASSERT_OK( + bundle.GetSession()->Run({}, {"filename_tensor:0"}, {}, &path_outputs)); + ASSERT_EQ(1, path_outputs.size()); + + test::ExpectTensorEqual( + test::AsTensor({"foo.txt"}, TensorShape({})), path_outputs[0]); + } + + void CheckSavedModelBundle(const string& export_dir, + const SavedModelBundleLite& bundle) { + ValidateAssets(export_dir, bundle); + // Retrieve the regression signature from the bundle. + const auto& signature_def = bundle.GetSignatures().at("regress_x_to_y"); + + const string input_name = signature_def.inputs().at(kRegressInputs).name(); + const string output_name = + signature_def.outputs().at(kRegressOutputs).name(); + + std::vector serialized_examples; + for (float x : {0, 1, 2, 3}) { + serialized_examples.push_back(MakeSerializedExample(x)); + } + + // Validate the half plus two behavior. + Tensor input = + test::AsTensor(serialized_examples, TensorShape({4})); + std::vector outputs; + TF_ASSERT_OK(bundle.GetSession()->Run({{input_name, input}}, {output_name}, + {}, &outputs)); + ASSERT_EQ(outputs.size(), 1); + test::ExpectTensorEqual( + outputs[0], + test::AsTensor({2, 2.5, 3, 3.5}, TensorShape({4, 1}))); + } +}; + +// Test for resource leaks related to TensorFlow session closing requirements +// when loading and unloading large numbers of SavedModelBundles. +// TODO(sukritiramesh): Increase run iterations and move outside of the test +// suite. 
+TEST_F(LoaderTest, ResourceLeakTest) { + SavedModelBundleLite bundle; + SessionOptions session_options; + RunOptions run_options; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded); + for (int i = 0; i < 100; ++i) { + TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir, + {kSavedModelTagServe}, &bundle)); + CheckSavedModelBundle(export_dir, bundle); + } +} + +TEST_F(LoaderTest, TagMatch) { + SavedModelBundleLite bundle; + SessionOptions session_options; + RunOptions run_options; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded); + TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir, + {kSavedModelTagServe}, &bundle)); + CheckSavedModelBundle(export_dir, bundle); +} + +TEST_F(LoaderTest, NoTagMatch) { + SavedModelBundleLite bundle; + RunOptions run_options; + SessionOptions session_options; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded); + Status st = LoadSavedModel(session_options, run_options, export_dir, + {"missing-tag"}, &bundle); + EXPECT_FALSE(st.ok()); + EXPECT_TRUE(absl::StrContains( + st.error_message(), + "Could not find meta graph def matching supplied tags: { missing-tag }")) + << st.error_message(); +} + +TEST_F(LoaderTest, NoTagMatchMultiple) { + SavedModelBundleLite bundle; + RunOptions run_options; + SessionOptions session_options; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded); + Status st = LoadSavedModel(session_options, run_options, export_dir, + {kSavedModelTagServe, "missing-tag"}, &bundle); + EXPECT_FALSE(st.ok()); + EXPECT_TRUE(absl::StrContains( + st.error_message(), + "Could not find meta graph def matching supplied tags: ")) + << st.error_message(); +} + +TEST_F(LoaderTest, SessionCreationFailure) { + SavedModelBundleLite bundle; + // Use invalid SessionOptions to cause session creation to fail. Default + // options work, so provide an invalid value for the target field. 
+ SessionOptions session_options; + constexpr char kInvalidTarget[] = "invalid target"; + session_options.target = kInvalidTarget; + RunOptions run_options; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded); + Status st = LoadSavedModel(session_options, run_options, export_dir, + {kSavedModelTagServe}, &bundle); + EXPECT_FALSE(st.ok()); + EXPECT_TRUE(absl::StrContains(st.error_message(), kInvalidTarget)) + << st.error_message(); +} + +TEST_F(LoaderTest, PbtxtFormat) { + SavedModelBundleLite bundle; + SessionOptions session_options; + RunOptions run_options; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPbTxt); + TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir, + {kSavedModelTagServe}, &bundle)); + CheckSavedModelBundle(export_dir, bundle); +} + +TEST_F(LoaderTest, MainOpFormat) { + SavedModelBundleLite bundle; + SessionOptions session_options; + RunOptions run_options; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataMainOp); + TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir, + {kSavedModelTagServe}, &bundle)); + CheckSavedModelBundle(export_dir, bundle); +} + +TEST_F(LoaderTest, InvalidExportPath) { + SavedModelBundleLite bundle; + RunOptions run_options; + SessionOptions session_options; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), "missing-path"); + Status st = LoadSavedModel(session_options, run_options, export_dir, + {kSavedModelTagServe}, &bundle); + EXPECT_FALSE(st.ok()); +} + +TEST_F(LoaderTest, MaybeSavedModelDirectory) { + // Valid SavedModel directory. + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded); + EXPECT_TRUE(MaybeSavedModelDirectory(export_dir)); + + // Directory that does not exist. + const string missing_export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), "missing-path"); + EXPECT_FALSE(MaybeSavedModelDirectory(missing_export_dir)); + + // Directory that exists but is an invalid SavedModel location. + const string invalid_export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), "cc/saved_model"); + EXPECT_FALSE(MaybeSavedModelDirectory(invalid_export_dir)); +} + +TEST_F(LoaderTest, SavedModelInitOpV2Format) { + SavedModelBundleLite bundle; + SessionOptions session_options; + RunOptions run_options; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataInitOpV2); + TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir, + {kSavedModelTagServe}, &bundle)); + CheckSavedModelBundle(export_dir, bundle); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/cc/saved_model/loader_test.cc b/tensorflow/cc/saved_model/saved_model_bundle_test.cc similarity index 98% rename from tensorflow/cc/saved_model/loader_test.cc rename to tensorflow/cc/saved_model/saved_model_bundle_test.cc index aa2031d17d2..a2ce9b5f5e9 100644 --- a/tensorflow/cc/saved_model/loader_test.cc +++ b/tensorflow/cc/saved_model/saved_model_bundle_test.cc @@ -71,8 +71,7 @@ class LoaderTest : public ::testing::Test { const SavedModelBundle& bundle) { ValidateAssets(export_dir, bundle); // Retrieve the regression signature from meta graph def. 
- const auto signature_def_map = bundle.meta_graph_def.signature_def(); - const auto signature_def = signature_def_map.at("regress_x_to_y"); + const auto& signature_def = bundle.GetSignatures().at("regress_x_to_y"); const string input_name = signature_def.inputs().at(kRegressInputs).name(); const string output_name = diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 1ebfe235b4d..2b15b12ec24 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -1,5 +1,5 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test", "cc_header_only_library") -load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") +load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library", "tf_jit_compilation_passes_extra_deps") load("//tensorflow/core/platform:default/build_config.bzl", "tf_additional_all_protos", "tf_proto_library") @@ -38,7 +38,7 @@ cc_library( ":xla_cpu_device", ":xla_cpu_jit", "//tensorflow/compiler/plugin", - ] + if_cuda([ + ] + if_cuda_or_rocm([ ":xla_gpu_device", ":xla_gpu_jit", ]), @@ -61,7 +61,7 @@ cc_library( cc_library( name = "xla_gpu_jit", visibility = ["//visibility:public"], - deps = if_cuda([ + deps = if_cuda_or_rocm([ ":jit_compilation_passes", "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc index 6498436fbd9..b0c78469118 100644 --- a/tensorflow/compiler/jit/compilability_check_util.cc +++ b/tensorflow/compiler/jit/compilability_check_util.cc @@ -130,17 +130,24 @@ RecursiveCompilabilityChecker::FindUncompilableNodes( return uncompilable_nodes; } -bool RecursiveCompilabilityChecker::HasXLAKernel(const Node& node) const { +bool RecursiveCompilabilityChecker::HasXLAKernel( + const Node& node, string* uncompilable_reason) const { // There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient // is really a kind of function call and will be handled by // IsCompilableCall(). - if (node.type_string() == "SymbolicGradient") return false; + if (node.type_string() == "SymbolicGradient") { + *uncompilable_reason = + "SymbolicGradient should be handled by IsCompilableCall()."; + return false; + } if (node.type_string() == "Const") { // Skip Const op with type DT_STRING, since XLA doesn't support it, but the // registered Const KernelDef says that it does, to support no-op Assert for // tfcompile. const AttrValue* attr = node.attrs().Find("dtype"); if (attr != nullptr && attr->type() == DT_STRING) { + *uncompilable_reason = + "Const op with type DT_STRING is not supported by XLA."; return false; } } @@ -150,10 +157,16 @@ bool RecursiveCompilabilityChecker::HasXLAKernel(const Node& node) const { // such nodes out of XLA clusters. if (HasForwardedRefInput(node)) { VLOG(2) << "Rejecting " << node.name() << ": Identity with unsafe cast."; + *uncompilable_reason = "Identity with unsafe cast."; return false; } - return FindKernelDef(jit_device_type_, node.def(), nullptr, nullptr).ok(); + Status s = FindKernelDef(jit_device_type_, node.def(), nullptr, nullptr); + if (!s.ok()) { + *uncompilable_reason = s.error_message(); + return false; + } + return true; } // Tests whether 'if_node' is compilable. 
Every operator in the then_branch and @@ -336,16 +349,17 @@ bool RecursiveCompilabilityChecker::IsCompilableNode( return false; } + string uncompilable_reason; if (IsFunctionCall(*lib_runtime->GetFunctionLibraryDefinition(), node)) { if (!IsCompilableCall(node.def(), lib_runtime, stack_trace, encapsulating_function, uncompilable_nodes)) { LogNotCompilable(node, "unsupported function"); return false; } - } else if (!HasXLAKernel(node)) { - absl::string_view uncompilable_reason = "unsupported op"; - MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace, - encapsulating_function, uncompilable_nodes); + } else if (!HasXLAKernel(node, &uncompilable_reason)) { + MaybeMarkUncompilableNode( + absl::StrCat("unsupported op: ", uncompilable_reason), *stack_trace, + encapsulating_function, uncompilable_nodes); LogNotCompilable(node, uncompilable_reason); return false; } diff --git a/tensorflow/compiler/jit/compilability_check_util.h b/tensorflow/compiler/jit/compilability_check_util.h index 04639df14a1..43b2689b522 100644 --- a/tensorflow/compiler/jit/compilability_check_util.h +++ b/tensorflow/compiler/jit/compilability_check_util.h @@ -247,7 +247,8 @@ class RecursiveCompilabilityChecker { absl::c_any_of(node.output_types(), is_variant); } - bool HasXLAKernel(const Node& node) const; + bool HasXLAKernel(const Node& node, + string* uncompilable_reason = nullptr) const; static void MaybeMarkUncompilableNode( const absl::string_view reason, diff --git a/tensorflow/compiler/jit/compilability_check_util_test.cc b/tensorflow/compiler/jit/compilability_check_util_test.cc index 0dd3b8141c9..cdce1e92799 100644 --- a/tensorflow/compiler/jit/compilability_check_util_test.cc +++ b/tensorflow/compiler/jit/compilability_check_util_test.cc @@ -125,7 +125,8 @@ TEST_F(CompilabilityCheckUtilTest, CheckNonFunctionalNodes) { const auto& uncompilable_nodes_inside_function = node_info_it->second.second; ASSERT_EQ(1, uncompilable_nodes_inside_function.size()); const auto& uncompilable_node_info = uncompilable_nodes_inside_function.at(0); - EXPECT_EQ("unsupported op", uncompilable_node_info.uncompilable_reason); + EXPECT_TRUE(absl::StrContains(uncompilable_node_info.uncompilable_reason, + "unsupported op")); ASSERT_EQ(1, uncompilable_node_info.stack_trace.size()); ASSERT_EQ("", uncompilable_node_info.stack_trace.at(0).function_name); } @@ -167,7 +168,8 @@ TEST_F(CompilabilityCheckUtilTest, CheckSimpleFunctionNode) { EXPECT_EQ("D", node_stack.at(0).name); EXPECT_EQ(kUncompilableFunctionNodeName, node_stack.at(1).name); EXPECT_EQ(kUncompilableFunctionNodeName, node_info.name); - EXPECT_EQ("unsupported op", node_info.uncompilable_reason); + EXPECT_TRUE( + absl::StrContains(node_info.uncompilable_reason, "unsupported op")); } TEST_F(CompilabilityCheckUtilTest, CheckFunctionalWhileNode) { @@ -246,7 +248,8 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalWhileNode) { stacktrace_second_node_info.function_name); EXPECT_EQ(kUncompilableFunctionNodeName, node_info.name); - EXPECT_EQ("unsupported op", node_info.uncompilable_reason); + EXPECT_TRUE( + absl::StrContains(node_info.uncompilable_reason, "unsupported op")); } TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) { @@ -322,7 +325,8 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) { stacktrace_second_node_info.function_name); EXPECT_EQ(kUncompilableFunctionNodeName, uncompilable_node_one.name); - EXPECT_EQ("unsupported op", uncompilable_node_one.uncompilable_reason); + EXPECT_TRUE(absl::StrContains(uncompilable_node_one.uncompilable_reason, + 
"unsupported op")); NameAttrList function_two; function_two.set_name(kUncompilableFunctionTwoName); @@ -345,7 +349,8 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) { node_two_stacktrace_second_node.function_name); EXPECT_EQ(kUncompilableFunctionNodeTwoName, uncompilable_node_two.name); - EXPECT_EQ("unsupported op", uncompilable_node_two.uncompilable_reason); + EXPECT_TRUE(absl::StrContains(uncompilable_node_one.uncompilable_reason, + "unsupported op")); } } // namespace diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index dbe0d66b0c8..4e0f62b4351 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -45,7 +45,8 @@ cc_library( "//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_fold_switch", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_test_passes", "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", "//tensorflow/compiler/mlir/xla:hlo", "//tensorflow/compiler/mlir/xla:lhlo", diff --git a/tensorflow/compiler/mlir/g3doc/README.md b/tensorflow/compiler/mlir/g3doc/README.md new file mode 100644 index 00000000000..39734828d19 --- /dev/null +++ b/tensorflow/compiler/mlir/g3doc/README.md @@ -0,0 +1,3 @@ +# TensorFlow MLIR + +These are the docs for: https://www.tensorflow.org/mlir diff --git a/tensorflow/compiler/mlir/g3doc/_book.yaml b/tensorflow/compiler/mlir/g3doc/_book.yaml new file mode 100644 index 00000000000..9e8aa655c09 --- /dev/null +++ b/tensorflow/compiler/mlir/g3doc/_book.yaml @@ -0,0 +1,24 @@ +upper_tabs: +# Tabs left of dropdown menu +- include: /_upper_tabs_left.yaml +- include: /api_docs/_upper_tabs_api.yaml +# Dropdown menu +- name: Resources + path: /resources + is_default: true + menu: + - include: /resources/_menu_toc.yaml + lower_tabs: + # Subsite tabs + other: + - name: Guide & Tutorials + contents: + - title: Overview + path: /mlir/overview + - heading: Dialects + - title: TensorFlow + path: /mlir/tf_ops + - title: TensorFlow Lite + path: /mlir/tfl_ops + +- include: /_upper_tabs_right.yaml diff --git a/tensorflow/compiler/mlir/g3doc/_index.yaml b/tensorflow/compiler/mlir/g3doc/_index.yaml new file mode 100644 index 00000000000..9090eefe875 --- /dev/null +++ b/tensorflow/compiler/mlir/g3doc/_index.yaml @@ -0,0 +1,48 @@ +book_path: /mlir/_book.yaml +project_path: /mlir/_project.yaml +description: +landing_page: + custom_css_path: /site-assets/css/style.css + rows: + - heading: MLIR unifies the infrastructure for high-performance ML models in TensorFlow. + items: + - description: > + The MLIR project defines a common intermediate representation (IR) that + unifies the infrastructure required to execute high performance machine + learning models in TensorFlow and similar ML frameworks. This project + will include the application of HPC techniques, along with integration of + search algorithms like reinforcement learning. MLIR aims to reduce the + cost to bring up new hardware, and improve usability for existing + TensorFlow users. + + - code_block: | +
+        // Syntactically similar to LLVM:
+        func @testFunction(%arg0: i32) -> i32 {
+          %x = call @thingToCall(%arg0) : (i32) -> i32
+          br ^bb1
+        ^bb1:
+          %y = addi %x, %x : i32
+          return %y : i32
+        }
+        
+ + - classname: devsite-landing-row-cards + items: + - heading: "Multi-Level Intermediate Representation for Compiler Infrastructure" + youtube_id: qzljG6DKgic + buttons: + - label: Watch the video + path: https://www.youtube.com/watch?v=qzljG6DKgic + - heading: "A new intermediate representation and compiler framework" + image_path: /resources/images/tf-logo-card-16x9.png + path: https://medium.com/tensorflow/mlir-a-new-intermediate-representation-and-compiler-framework-beba999ed18d + buttons: + - label: Read on TensorFlow blog + path: https://medium.com/tensorflow/mlir-a-new-intermediate-representation-and-compiler-framework-beba999ed18d + - heading: TensorFlow MLIR on GitHub + image_path: /resources/images/github-card-16x9.png + path: https://github.com/tensorflow/mlir + buttons: + - label: View on GitHub + path: https://github.com/tensorflow/mlir diff --git a/tensorflow/compiler/mlir/g3doc/_project.yaml b/tensorflow/compiler/mlir/g3doc/_project.yaml new file mode 100644 index 00000000000..be0e46ac0ca --- /dev/null +++ b/tensorflow/compiler/mlir/g3doc/_project.yaml @@ -0,0 +1,11 @@ +name: TensorFlow MLIR +breadcrumb_name: MLIR +home_url: /mlir/ +parent_project_metadata_path: /_project.yaml +description: > + MLIR unifies the infrastructure for high-performance ML models in TensorFlow. +use_site_branding: true +hide_from_products_list: true +content_license: cc-apache +buganizer_id: 443907 +include: /_project_included.yaml diff --git a/tensorflow/compiler/mlir/g3doc/images/mlir-infra.svg b/tensorflow/compiler/mlir/g3doc/images/mlir-infra.svg new file mode 100644 index 00000000000..aec0986ba02 --- /dev/null +++ b/tensorflow/compiler/mlir/g3doc/images/mlir-infra.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/tensorflow/compiler/mlir/g3doc/overview.md b/tensorflow/compiler/mlir/g3doc/overview.md new file mode 100644 index 00000000000..885c04b9588 --- /dev/null +++ b/tensorflow/compiler/mlir/g3doc/overview.md @@ -0,0 +1,5 @@ +# MLIR overview + +## Overview + +MLIR overview diagram diff --git a/tensorflow/compiler/mlir/g3doc/tf_ops.md b/tensorflow/compiler/mlir/g3doc/tf_ops.md new file mode 100644 index 00000000000..cedeba5dae1 --- /dev/null +++ b/tensorflow/compiler/mlir/g3doc/tf_ops.md @@ -0,0 +1,2761 @@ + +# Operation definition +## tf.Abs (TF::AbsOp) +Computes the absolute value of a tensor. + +### Description: + +Given a tensor `x`, this operation returns a tensor containing the absolute +value of each element in `x`. For example, if x is an input element and y is +an output element, this operation computes \\(y = |x|\\). + +### Operands: +1. `x`: tensor of floating-point or 32/64-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `y`: tensor of floating-point or 32/64-bit integer values + +## tf.AddN (TF::AddNOp) +Add all input tensors element wise. + +### Description: + + +### Operands: +1. `inputs`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 16-bit integer or 32-bit integer or 64-bit integer or 8-bit integer or complex128 type or complex64 type or TensorFlow variant type values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `N` | `IntegerAttr` | 64-bit integer attribute whose minimal value is 1 attribute | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. 
`sum`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 16-bit integer or 32-bit integer or 64-bit integer or 8-bit integer or complex128 type or complex64 type or TensorFlow variant type values + +## tf.Add (TF::AddOp) +Returns x + y element-wise. + +### Description: + +*NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + +### Operands: +1. `x`: tensor of number or TensorFlow string type values +1. `y`: tensor of number or TensorFlow string type values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `z`: tensor of number or TensorFlow string type values + +## tf.AddV2 (TF::AddV2Op) +Returns x + y element-wise. + +### Description: + +*NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + +### Operands: +1. `x`: tensor of number values +1. `y`: tensor of number values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `z`: tensor of number values + +## tf.AvgPool (TF::AvgPoolOp) +Performs average pooling on the input. + +### Description: + +Each entry in `output` is the mean of the corresponding size `ksize` +window in `value`. + +### Operands: +1. `value`: tensor of floating-point values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `ksize` | `ArrayAttr` | 64-bit integer array attribute with at least 4 elements attribute | +| `strides` | `ArrayAttr` | 64-bit integer array attribute with at least 4 elements attribute | +| `padding` | `StringAttr` | string attribute whose value is SAME, or VALID attribute | +| `data_format` | `StringAttr` | 'NHWC' or 'NCHW' convnet data format attribute | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of floating-point values + +## tf.BatchToSpaceND (TF::BatchToSpaceNDOp) +BatchToSpace for N-D tensors of type T. + +### Description: + +This operation reshapes the "batch" dimension 0 into `M + 1` dimensions of shape +`block_shape + [batch]`, interleaves these blocks back into the grid defined by +the spatial dimensions `[1, ..., M]`, to obtain a result with the same rank as +the input. The spatial dimensions of this intermediate result are then +optionally cropped according to `crops` to produce the output. This is the +reverse of SpaceToBatch. See below for a precise description. + +### Operands: +1. `input`: tensor of tf.dtype values +1. `block_shape`: tensor of 32/64-bit integer values +1. `crops`: tensor of 32/64-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | +| `Tcrops` | `Attribute` | derived attribute attribute | +| `Tblock_shape` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of tf.dtype values + +## tf.BiasAdd (TF::BiasAddOp) +Adds `bias` to `value`. + +### Description: + +This is a special case of `tf.add` where `bias` is restricted to be 1-D. +Broadcasting is supported, so `value` may have any number of dimensions. + +### Operands: +1. `value`: tensor of number values +1. 
`bias`: tensor of number values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `data_format` | `StringAttr` | 'NHWC' or 'NCHW' convnet data format attribute | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of number values + +## tf.Bitcast (TF::BitcastOp) + +Bitcasts a tensor from one type to another without copying data. + + +### Description: + +Given a tensor `input`, this operation returns a tensor that has the same buffer +data as `input` with datatype `type`. + +If the input datatype `T` is larger than the output datatype `type` then the +shape changes from [...] to [..., sizeof(`T`)/sizeof(`type`)]. + +If `T` is smaller than `type`, the operator requires that the rightmost +dimension be equal to sizeof(`type`)/sizeof(`T`). The shape then goes from +[..., sizeof(`type`)/sizeof(`T`)] to [...]. + +tf.bitcast() and tf.cast() work differently when real dtype is casted as a complex dtype +(e.g. tf.complex64 or tf.complex128) as tf.cast() make imaginary part 0 while tf.bitcast() +gives module error. +For example, + +Example 1: +```python +>>> a = [1., 2., 3.] +>>> equality_bitcast = tf.bitcast(a,tf.complex128) +tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot bitcast from float to complex128: shape [3] [Op:Bitcast] +>>> equality_cast = tf.cast(a,tf.complex128) +>>> print(equality_cast) +tf.Tensor([1.+0.j 2.+0.j 3.+0.j], shape=(3,), dtype=complex128) +``` +Example 2: +```python +>>> tf.bitcast(tf.constant(0xffffffff, dtype=tf.uint32), tf.uint8) + +``` +Example 3: +```python +>>> x = [1., 2., 3.] +>>> y = [0., 2., 3.] +>>> equality= tf.equal(x,y) +>>> equality_cast = tf.cast(equality,tf.float32) +>>> equality_bitcast = tf.bitcast(equality_cast,tf.uint8) +>>> print(equality) +tf.Tensor([False True True], shape=(3,), dtype=bool) +>>> print(equality_cast) +tf.Tensor([0. 1. 1.], shape=(3,), dtype=float32) +>>> print(equality_bitcast) +tf.Tensor( +[[ 0 0 0 0] + [ 0 0 128 63] + [ 0 0 128 63]], shape=(3, 4), dtype=uint8) +``` + +*NOTE*: Bitcast is implemented as a low-level cast, so machines with different +endian orderings will give different results. + +### Operands: +1. `input`: tensor of number values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | +| `type` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of number values + +## tf.BroadcastTo (TF::BroadcastToOp) +Broadcast an array for a compatible shape. + +### Description: + +Broadcasting is the process of making arrays to have compatible shapes +for arithmetic operations. Two shapes are compatible if for each +dimension pair they are either equal or one of them is one. When trying +to broadcast a Tensor to a shape, it starts with the trailing dimensions, +and works its way forward. + +For example, + +```python +>>> x = tf.constant([1, 2, 3]) +>>> y = tf.broadcast_to(x, [3, 3]) +>>> sess.run(y) +array([[1, 2, 3], + [1, 2, 3], + [1, 2, 3]], dtype=int32) +``` + +In the above example, the input Tensor with the shape of `[1, 3]` +is broadcasted to output Tensor with shape of `[3, 3]`. + +### Operands: +1. `input`: tensor of tf.dtype values +1. 
`shape`: tensor of 32/64-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | +| `Tidx` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of tf.dtype values + +## tf.Cast (TF::CastOp) +Cast x of type SrcT to y of DstT. + +### Description: + + +### Operands: +1. `x`: tensor of tf.dtype values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `Truncate` | `BoolAttr` | bool attribute attribute | +| `SrcT` | `Attribute` | derived attribute attribute | +| `DstT` | `Attribute` | derived attribute attribute | + +### Results: +1. `y`: tensor of tf.dtype values + +## tf.Ceil (TF::CeilOp) +Returns element-wise smallest integer not less than x. + +### Description: + + +### Operands: +1. `x`: tensor of floating-point values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `y`: tensor of floating-point values + +## tf.Concat (TF::ConcatOp) +Concatenates tensors along one dimension. + +### Description: + + +### Operands: +1. `concat_dim`: tensor of 32-bit integer values +1. `values`: tensor of tf.dtype values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `N` | `IntegerAttr` | 64-bit integer attribute whose minimal value is 2 attribute | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of tf.dtype values + +## tf.ConcatV2 (TF::ConcatV2Op) +Concatenates tensors along one dimension. + +### Description: + + +### Operands: +1. `values`: tensor of tf.dtype values +1. `axis`: tensor of 32/64-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `N` | `IntegerAttr` | 64-bit integer attribute whose minimal value is 2 attribute | +| `T` | `Attribute` | derived attribute attribute | +| `Tidx` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of tf.dtype values + +## tf.Conj (TF::ConjOp) +Returns the complex conjugate of a complex number. + +### Description: + +Given a tensor `input` of complex numbers, this operation returns a tensor of +complex numbers that are the complex conjugate of each element in `input`. The +complex numbers in `input` must be of the form \\(a + bj\\), where *a* is the +real part and *b* is the imaginary part. + +The complex conjugate returned by this operation is of the form \\(a - bj\\). + +For example: + +``` +# tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] +tf.conj(input) ==> [-2.25 - 4.75j, 3.25 - 5.75j] +``` + +### Operands: +1. `input`: tensor of complex128 type or complex64 type or TensorFlow variant type values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of complex128 type or complex64 type or TensorFlow variant type values + +## tf.Const (TF::ConstOp) +Constant tensor op + +### Description: + + +### Operands: + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `value` | `ElementsAttr` | constant vector/tensor attribute attribute | +| `dtype` | `Attribute` | derived attribute attribute | + +### Results: +1. 
`output`: tensor of tf.dtype values + +## tf.Conv2D (TF::Conv2DOp) + +Computes a 2-D convolution given 4-D `input` and `filter` tensors. + + +### Description: + +Given an input tensor of shape `[batch, in_height, in_width, in_channels]` +and a filter / kernel tensor of shape +`[filter_height, filter_width, in_channels, out_channels]`, this op +performs the following: + +1. Flattens the filter to a 2-D matrix with shape + `[filter_height * filter_width * in_channels, output_channels]`. +2. Extracts image patches from the input tensor to form a *virtual* + tensor of shape `[batch, out_height, out_width, + filter_height * filter_width * in_channels]`. +3. For each patch, right-multiplies the filter matrix and the image patch + vector. + +In detail, with the default NHWC format, + + output[b, i, j, k] = + sum_{di, dj, q} input[b, strides[1] * i + di, strides[2] * j + dj, q] * + filter[di, dj, q, k] + +Must have `strides[0] = strides[3] = 1`. For the most common case of the same +horizontal and vertical strides, `strides = [1, stride, stride, 1]`. + +### Operands: +1. `input`: tensor of floating-point values +1. `filter`: tensor of floating-point values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `strides` | `ArrayAttr` | 64-bit integer array attribute attribute | +| `use_cudnn_on_gpu` | `BoolAttr` | bool attribute attribute | +| `padding` | `StringAttr` | string attribute whose value is SAME, or VALID, or EXPLICIT attribute | +| `explicit_paddings` | `ArrayAttr` | 64-bit integer array attribute attribute | +| `data_format` | `StringAttr` | 'NHWC' or 'NCHW' convnet data format attribute | +| `dilations` | `ArrayAttr` | 64-bit integer array attribute attribute | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of floating-point values + +## tf.Cos (TF::CosOp) +Computes cos of x element-wise. + +### Description: + + +### Operands: +1. `x`: tensor of floating-point or 64/128-bit complex type values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `y`: tensor of floating-point or 64/128-bit complex type values + +## tf.DepthwiseConv2dNative (TF::DepthwiseConv2dNativeOp) + +Computes a 2-D depthwise convolution given 4-D `input` and `filter` tensors. + + +### Description: + +Given an input tensor of shape `[batch, in_height, in_width, in_channels]` +and a filter / kernel tensor of shape +`[filter_height, filter_width, in_channels, channel_multiplier]`, containing +`in_channels` convolutional filters of depth 1, `depthwise_conv2d` applies +a different filter to each input channel (expanding from 1 channel to +`channel_multiplier` channels for each), then concatenates the results +together. Thus, the output has `in_channels * channel_multiplier` channels. + +``` +for k in 0..in_channels-1 + for q in 0..channel_multiplier-1 + output[b, i, j, k * channel_multiplier + q] = + sum_{di, dj} input[b, strides[1] * i + di, strides[2] * j + dj, k] * + filter[di, dj, k, q] +``` + +Must have `strides[0] = strides[3] = 1`. For the most common case of the same +horizontal and vertical strides, `strides = [1, stride, stride, 1]`. + +### Operands: +1. `input`: tensor of floating-point values +1. 
`filter`: tensor of floating-point values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `strides` | `ArrayAttr` | 64-bit integer array attribute attribute | +| `padding` | `StringAttr` | string attribute whose value is SAME, or VALID attribute | +| `data_format` | `StringAttr` | 'NHWC' or 'NCHW' convnet data format attribute | +| `dilations` | `ArrayAttr` | 64-bit integer array attribute attribute | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of floating-point values + +## tf.Div (TF::DivOp) +Returns x / y element-wise. + +### Description: + +*NOTE*: `Div` supports broadcasting. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + +### Operands: +1. `x`: tensor of number values +1. `y`: tensor of number values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `z`: tensor of number values + +## tf.Elu (TF::EluOp) + +Computes exponential linear: `exp(features) - 1` if < 0, `features` otherwise. + + +### Description: + +See [Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs) +](http://arxiv.org/abs/1511.07289) + +### Operands: +1. `features`: tensor of floating-point values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `activations`: tensor of floating-point values + +## tf.Equal (TF::EqualOp) +Returns the truth value of (x == y) element-wise. + +### Description: + +*NOTE*: `Equal` supports broadcasting. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + +```python +x = tf.constant([2, 4]) +y = tf.constant(2) +tf.math.equal(x, y) ==> array([True, False]) + +x = tf.constant([2, 4]) +y = tf.constant([2, 4]) +tf.math.equal(x, y) ==> array([True, True]) +``` + +### Operands: +1. `x`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 1-bit integer or 16-bit integer or 32-bit integer or 64-bit integer or 8-bit integer or complex128 type or complex64 type or TensorFlow string type values +1. `y`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 1-bit integer or 16-bit integer or 32-bit integer or 64-bit integer or 8-bit integer or complex128 type or complex64 type or TensorFlow string type values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `z`: tensor of 1-bit integer values + +## tf.ExpandDims (TF::ExpandDimsOp) +Inserts a dimension of 1 into a tensor's shape. + +### Description: + +Given a tensor `input`, this operation inserts a dimension of 1 at the +dimension index `axis` of `input`'s shape. The dimension index `axis` starts at +zero; if you specify a negative number for `axis` it is counted backward from +the end. + +This operation is useful if you want to add a batch dimension to a single +element. For example, if you have a single image of shape `[height, width, +channels]`, you can make it a batch of 1 image with `expand_dims(image, 0)`, +which will make the shape `[1, height, width, channels]`. 
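+
+As a short illustrative sketch (added here for clarity; it is not part of the
+generated op documentation), the batch-dimension case looks like:
+
+```python
+import tensorflow as tf
+
+image = tf.zeros([28, 28, 3])     # a single image: [height, width, channels]
+batch = tf.expand_dims(image, 0)  # shape becomes [1, 28, 28, 3]
+```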
+ +Other examples: + +``` +# 't' is a tensor of shape [2] +shape(expand_dims(t, 0)) ==> [1, 2] +shape(expand_dims(t, 1)) ==> [2, 1] +shape(expand_dims(t, -1)) ==> [2, 1] + +# 't2' is a tensor of shape [2, 3, 5] +shape(expand_dims(t2, 0)) ==> [1, 2, 3, 5] +shape(expand_dims(t2, 2)) ==> [2, 3, 1, 5] +shape(expand_dims(t2, 3)) ==> [2, 3, 5, 1] +``` + +This operation requires that: + +`-1-input.dims() <= dim <= input.dims()` + +This operation is related to `squeeze()`, which removes dimensions of +size 1. + +### Operands: +1. `input`: tensor of tf.dtype values +1. `dim`: tensor of 32/64-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | +| `Tdim` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of tf.dtype values + +## tf.FakeQuantWithMinMaxArgs (TF::FakeQuantWithMinMaxArgsOp) + +Fake-quantize the 'inputs' tensor, type float to 'outputs' tensor of same type. + + +### Description: + +Attributes `[min; max]` define the clamping range for the `inputs` data. +`inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]` +when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and +then de-quantized and output as floats in `[min; max]` interval. +`num_bits` is the bitwidth of the quantization; between 2 and 16, inclusive. + +Before quantization, `min` and `max` values are adjusted with the following +logic. +It is suggested to have `min <= 0 <= max`. If `0` is not in the range of values, +the behavior can be unexpected: +If `0 < min < max`: `min_adj = 0` and `max_adj = max - min`. +If `min < max < 0`: `min_adj = min - max` and `max_adj = 0`. +If `min <= 0 <= max`: `scale = (max - min) / (2^num_bits - 1) `, +`min_adj = scale * round(min / scale)` and `max_adj = max + min_adj - min`. + +Quantization is called fake since the output is still in floating point. + +### Operands: +1. `inputs`: tensor of 32-bit float values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `min` | `FloatAttr` | 32-bit float attribute attribute | +| `max` | `FloatAttr` | 32-bit float attribute attribute | +| `num_bits` | `IntegerAttr` | 64-bit integer attribute attribute | +| `narrow_range` | `BoolAttr` | bool attribute attribute | + +### Results: +1. `outputs`: tensor of 32-bit float values + +## tf.FakeQuantWithMinMaxVars (TF::FakeQuantWithMinMaxVarsOp) + +Fake-quantize the 'inputs' tensor of type float via global float scalars `min` + + +### Description: + +and `max` to 'outputs' tensor of same shape as `inputs`. + +`[min; max]` define the clamping range for the `inputs` data. +`inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]` +when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and +then de-quantized and output as floats in `[min; max]` interval. +`num_bits` is the bitwidth of the quantization; between 2 and 16, inclusive. + +Before quantization, `min` and `max` values are adjusted with the following +logic. +It is suggested to have `min <= 0 <= max`. If `0` is not in the range of values, +the behavior can be unexpected: +If `0 < min < max`: `min_adj = 0` and `max_adj = max - min`. +If `min < max < 0`: `min_adj = min - max` and `max_adj = 0`. +If `min <= 0 <= max`: `scale = (max - min) / (2^num_bits - 1) `, +`min_adj = scale * round(min / scale)` and `max_adj = max + min_adj - min`. 
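+
+A minimal plain-Python sketch of the adjustment rules above (the helper name
+`adjust_range` is illustrative only and is not a TensorFlow API):
+
+```python
+def adjust_range(min_val, max_val, num_bits):
+    # Reproduces the three documented cases for nudging [min, max].
+    if 0 < min_val:                    # 0 < min < max
+        return 0.0, max_val - min_val
+    if max_val < 0:                    # min < max < 0
+        return min_val - max_val, 0.0
+    scale = (max_val - min_val) / (2 ** num_bits - 1)  # min <= 0 <= max
+    min_adj = scale * round(min_val / scale)
+    return min_adj, max_val + min_adj - min_val
+```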
+ +This operation has a gradient and thus allows for training `min` and `max` +values. + +### Operands: +1. `inputs`: tensor of 32-bit float values +1. `min`: tensor of 32-bit float values +1. `max`: tensor of 32-bit float values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `num_bits` | `IntegerAttr` | 64-bit integer attribute attribute | +| `narrow_range` | `BoolAttr` | bool attribute attribute | + +### Results: +1. `outputs`: tensor of 32-bit float values + +## tf.Fill (TF::FillOp) +Creates a tensor filled with a scalar value. + +### Description: + +This operation creates a tensor of shape `dims` and fills it with `value`. + +For example: + +``` +# Output tensor has shape [2, 3]. +fill([2, 3], 9) ==> [[9, 9, 9] + [9, 9, 9]] +``` + +`tf.fill` differs from `tf.constant` in a few ways: + +* `tf.fill` only supports scalar contents, whereas `tf.constant` supports + Tensor values. +* `tf.fill` creates an Op in the computation graph that constructs the actual + Tensor value at runtime. This is in contrast to `tf.constant` which embeds + the entire Tensor into the graph with a `Const` node. +* Because `tf.fill` evaluates at graph runtime, it supports dynamic shapes + based on other runtime Tensors, unlike `tf.constant`. + +### Operands: +1. `dims`: tensor of 32/64-bit integer values +1. `value`: tensor of tf.dtype values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | +| `index_type` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of tf.dtype values + +## tf.FloorDiv (TF::FloorDivOp) +Returns x // y element-wise. + +### Description: + +*NOTE*: `FloorDiv` supports broadcasting. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + +### Operands: +1. `x`: tensor of number values +1. `y`: tensor of number values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `z`: tensor of number values + +## tf.Floor (TF::FloorOp) +Returns element-wise largest integer not greater than x. + +### Description: + + +### Operands: +1. `x`: tensor of floating-point values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `y`: tensor of floating-point values + +## tf.FusedBatchNorm (TF::FusedBatchNormOp) +Batch normalization. + +### Description: + +Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". +The size of 1D Tensors matches the dimension C of the 4D Tensors. + +### Operands: +1. `x`: tensor of 32-bit float values +1. `scale`: tensor of 32-bit float values +1. `offset`: tensor of 32-bit float values +1. `mean`: tensor of 32-bit float values +1. `variance`: tensor of 32-bit float values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `epsilon` | `FloatAttr` | 32-bit float attribute attribute | +| `data_format` | `StringAttr` | 'NHWC' or 'NCHW' convnet data format attribute | +| `is_training` | `BoolAttr` | bool attribute attribute | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `y`: tensor of 32-bit float values +1. `batch_mean`: tensor of 32-bit float values +1. `batch_variance`: tensor of 32-bit float values +1. 
`reserve_space_1`: tensor of 32-bit float values +1. `reserve_space_2`: tensor of 32-bit float values + +## tf.Gather (TF::GatherOp) +Gather slices from `params` according to `indices`. + +### Description: + +`indices` must be an integer tensor of any dimension (usually 0-D or 1-D). +Produces an output tensor with shape `indices.shape + params.shape[1:]` where: + +```python + # Scalar indices + output[:, ..., :] = params[indices, :, ... :] + + # Vector indices + output[i, :, ..., :] = params[indices[i], :, ... :] + + # Higher rank indices + output[i, ..., j, :, ... :] = params[indices[i, ..., j], :, ..., :] +``` + +If `indices` is a permutation and `len(indices) == params.shape[0]` then +this operation will permute `params` accordingly. + +`validate_indices`: DEPRECATED. If this operation is assigned to CPU, values in +`indices` are always validated to be within range. If assigned to GPU, +out-of-bound indices result in safe but unspecified behavior, which may include +raising an error. + +
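+
+A concrete illustration of the indexing rule above (this example is added for
+clarity and is not part of the generated documentation):
+
+```python
+import tensorflow as tf
+
+params = tf.constant([[1, 2], [3, 4], [5, 6]])
+indices = tf.constant([2, 0])
+# output.shape == indices.shape + params.shape[1:] == [2, 2]
+tf.gather(params, indices)  # ==> [[5, 6], [1, 2]]
+```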
+ +
+ +### Operands: +1. `params`: tensor of tf.dtype values +1. `indices`: tensor of 32/64-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `validate_indices` | `BoolAttr` | bool attribute attribute | +| `Tindices` | `Attribute` | derived attribute attribute | +| `Tparams` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of tf.dtype values + +## tf.GatherV2 (TF::GatherV2Op) + +Gather slices from `params` axis `axis` according to `indices`. + + +### Description: + +`indices` must be an integer tensor of any dimension (usually 0-D or 1-D). +Produces an output tensor with shape `params.shape[:axis] + indices.shape + +params.shape[axis + 1:]` where: + +```python + # Scalar indices (output is rank(params) - 1). + output[a_0, ..., a_n, b_0, ..., b_n] = + params[a_0, ..., a_n, indices, b_0, ..., b_n] + + # Vector indices (output is rank(params)). + output[a_0, ..., a_n, i, b_0, ..., b_n] = + params[a_0, ..., a_n, indices[i], b_0, ..., b_n] + + # Higher rank indices (output is rank(params) + rank(indices) - 1). + output[a_0, ..., a_n, i, ..., j, b_0, ... b_n] = + params[a_0, ..., a_n, indices[i, ..., j], b_0, ..., b_n] +``` + +
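+
+For instance (an illustrative example added for clarity, not part of the
+generated documentation), gathering along `axis = 1`:
+
+```python
+import tensorflow as tf
+
+params = tf.constant([[1, 2, 3],
+                      [4, 5, 6]])
+tf.gather(params, [2, 0], axis=1)  # ==> [[3, 1], [6, 4]]
+```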
+ +
+ +Note that on CPU, if an out of bound index is found, an error is returned. +On GPU, if an out of bound index is found, a 0 is stored in the +corresponding output value. + +See also `tf.batch_gather` and `tf.gather_nd`. + +### Operands: +1. `params`: tensor of tf.dtype values +1. `indices`: tensor of 32/64-bit integer values +1. `axis`: tensor of 32/64-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `batch_dims` | `IntegerAttr` | 64-bit integer attribute attribute | +| `Tindices` | `Attribute` | derived attribute attribute | +| `Tparams` | `Attribute` | derived attribute attribute | +| `Taxis` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of tf.dtype values + +## tf.GreaterEqual (TF::GreaterEqualOp) +Returns the truth value of (x >= y) element-wise. + +### Description: + +*NOTE*: `GreaterEqual` supports broadcasting. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + +### Operands: +1. `x`: tensor of 8/16/32/64-bit integer or floating-point values +1. `y`: tensor of 8/16/32/64-bit integer or floating-point values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `z`: tensor of 1-bit integer values + +## tf.Greater (TF::GreaterOp) +Returns the truth value of (x > y) element-wise. + +### Description: + +*NOTE*: `Greater` supports broadcasting. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + +### Operands: +1. `x`: tensor of 8/16/32/64-bit integer or floating-point values +1. `y`: tensor of 8/16/32/64-bit integer or floating-point values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `z`: tensor of 1-bit integer values + +## tf.IdentityN (TF::IdentityNOp) + +Returns a list of tensors with the same shapes and contents as the input + + +### Description: + +tensors. + +This op can be used to override the gradient for complicated functions. For +example, suppose y = f(x) and we wish to apply a custom function g for backprop +such that dx = g(dy). In Python, + +```python +with tf.get_default_graph().gradient_override_map( + {'IdentityN': 'OverrideGradientWithG'}): + y, _ = identity_n([f(x), x]) + +@tf.RegisterGradient('OverrideGradientWithG') +def ApplyG(op, dy, _): + return [None, g(dy)] # Do not backprop to f(x). +``` + +### Operands: +1. `input`: tensor of tf.dtype values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of tf.dtype values + +## tf.Identity (TF::IdentityOp) +Identity op + +### Description: + +Returns a tensor with the same shape and contents as input. + +### Operands: +1. `input`: tensor of tf.dtype values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of tf.dtype values + +## tf.Invert (TF::InvertOp) + +Invert (flip) each bit of supported types; for example, type `uint8` value 01010101 becomes 10101010. + + +### Description: + +Flip each bit of supported types. 
For example, type `int8` (decimal 2) binary 00000010 becomes (decimal -3) binary 11111101. +This operation is performed on each element of the tensor argument `x`. + +### Operands: +1. `x`: tensor of 8/16/32/64-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `y`: tensor of 8/16/32/64-bit integer values + +## tf.LeakyRelu (TF::LeakyReluOp) +Computes rectified linear: `max(features, features * alpha)`. + +### Description: + + +### Operands: +1. `features`: tensor of floating-point values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `alpha` | `FloatAttr` | 32-bit float attribute attribute | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `activations`: tensor of floating-point values + +## tf.LessEqual (TF::LessEqualOp) +Returns the truth value of (x <= y) element-wise. + +### Description: + +*NOTE*: `LessEqual` supports broadcasting. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + +### Operands: +1. `x`: tensor of 8/16/32/64-bit integer or floating-point values +1. `y`: tensor of 8/16/32/64-bit integer or floating-point values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `z`: tensor of 1-bit integer values + +## tf.Less (TF::LessOp) +Returns the truth value of (x < y) element-wise. + +### Description: + +*NOTE*: `Less` supports broadcasting. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + +### Operands: +1. `x`: tensor of 8/16/32/64-bit integer or floating-point values +1. `y`: tensor of 8/16/32/64-bit integer or floating-point values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `z`: tensor of 1-bit integer values + +## tf.Log (TF::LogOp) +Computes natural logarithm of x element-wise. + +### Description: + +I.e., \\(y = \log_e x\\). + +### Operands: +1. `x`: tensor of floating-point or 64/128-bit complex type values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `y`: tensor of floating-point or 64/128-bit complex type values + +## tf.LogSoftmax (TF::LogSoftmaxOp) +Computes log softmax activations. + +### Description: + +For each batch `i` and class `j` we have + + logsoftmax[i, j] = logits[i, j] - log(sum(exp(logits[i]))) + +### Operands: +1. `logits`: tensor of floating-point values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `logsoftmax`: tensor of floating-point values + +## tf.LogicalAnd (TF::LogicalAndOp) +Returns the truth value of x AND y element-wise. + +### Description: + +*NOTE*: `LogicalAnd` supports broadcasting. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + +### Operands: +1. `x`: tensor of 1-bit integer values +1. `y`: tensor of 1-bit integer values + +### Attributes: + +### Results: +1. 
`z`: tensor of 1-bit integer values + +## tf.LogicalNot (TF::LogicalNotOp) +Returns the truth value of NOT x element-wise. + +### Description: + + +### Operands: +1. `x`: tensor of 1-bit integer values + +### Attributes: + +### Results: +1. `y`: tensor of 1-bit integer values + +## tf.LogicalOr (TF::LogicalOrOp) +Returns the truth value of x OR y element-wise. + +### Description: + +*NOTE*: `LogicalOr` supports broadcasting. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + +### Operands: +1. `x`: tensor of 1-bit integer values +1. `y`: tensor of 1-bit integer values + +### Attributes: + +### Results: +1. `z`: tensor of 1-bit integer values + +## tf.MatMul (TF::MatMulOp) + +Multiply the matrix "a" by the matrix "b". + + +### Description: + +The inputs must be two-dimensional matrices and the inner dimension of +"a" (after being transposed if transpose_a is true) must match the +outer dimension of "b" (after being transposed if transposed_b is +true). + +*Note*: The default kernel implementation for MatMul on GPUs uses +cublas. + +### Operands: +1. `a`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values +1. `b`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `transpose_a` | `BoolAttr` | bool attribute attribute | +| `transpose_b` | `BoolAttr` | bool attribute attribute | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `product`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values + +## tf.Max (TF::MaxOp) + +Computes the maximum of elements across dimensions of a tensor. + + +### Description: + +Reduces `input` along the dimensions given in `axis`. Unless +`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +`axis`. If `keep_dims` is true, the reduced dimensions are +retained with length 1. + +### Operands: +1. `input`: tensor of number values +1. `reduction_indices`: tensor of 32/64-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `keep_dims` | `BoolAttr` | bool attribute attribute | +| `T` | `Attribute` | derived attribute attribute | +| `Tidx` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of number values + +## tf.MaxPool (TF::MaxPoolOp) +Performs max pooling on the input. + +### Description: + + +### Operands: +1. `input`: tensor of 8/16/32/64-bit integer or floating-point values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `ksize` | `ArrayAttr` | 64-bit integer array attribute with at least 4 elements attribute | +| `strides` | `ArrayAttr` | 64-bit integer array attribute with at least 4 elements attribute | +| `padding` | `StringAttr` | string attribute whose value is SAME, or VALID attribute | +| `data_format` | `StringAttr` | string attribute whose value is NHWC, or NCHW, or NCHW_VECT_C attribute | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of 8/16/32/64-bit integer or floating-point values + +## tf.Maximum (TF::MaximumOp) +Returns the max of x and y (i.e. x > y ? 
x : y) element-wise. + +### Description: + +*NOTE*: `Maximum` supports broadcasting. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + +### Operands: +1. `x`: tensor of floating-point or 32/64-bit integer values +1. `y`: tensor of floating-point or 32/64-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `z`: tensor of floating-point or 32/64-bit integer values + +## tf.Mean (TF::MeanOp) +Computes the mean of elements across dimensions of a tensor. + +### Description: + +Reduces `input` along the dimensions given in `axis`. Unless +`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +`axis`. If `keep_dims` is true, the reduced dimensions are +retained with length 1. + +### Operands: +1. `input`: tensor of number values +1. `reduction_indices`: tensor of 32/64-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `keep_dims` | `BoolAttr` | bool attribute attribute | +| `T` | `Attribute` | derived attribute attribute | +| `Tidx` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of number values + +## tf.Min (TF::MinOp) + +Computes the minimum of elements across dimensions of a tensor. + + +### Description: + +Reduces `input` along the dimensions given in `axis`. Unless +`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +`axis`. If `keep_dims` is true, the reduced dimensions are +retained with length 1. + +### Operands: +1. `input`: tensor of number values +1. `reduction_indices`: tensor of 32/64-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `keep_dims` | `BoolAttr` | bool attribute attribute | +| `T` | `Attribute` | derived attribute attribute | +| `Tidx` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of number values + +## tf.Minimum (TF::MinimumOp) +Returns the min of x and y (i.e. x < y ? x : y) element-wise. + +### Description: + +*NOTE*: `Minimum` supports broadcasting. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + +### Operands: +1. `x`: tensor of floating-point or 32/64-bit integer values +1. `y`: tensor of floating-point or 32/64-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `z`: tensor of floating-point or 32/64-bit integer values + +## tf.MulNoNan (TF::MulNoNanOp) + +Returns x * y element-wise. Returns zero if y is zero, even if x is infinite or NaN. + + +### Description: + +*NOTE*: `MulNoNan` supports broadcasting. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + +### Operands: +1. `x`: tensor of 16-bit float or 32-bit float or 64-bit float or complex128 type or complex64 type values +1. `y`: tensor of 16-bit float or 32-bit float or 64-bit float or complex128 type or complex64 type values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. 
`z`: tensor of 16-bit float or 32-bit float or 64-bit float or complex128 type or complex64 type values + +## tf.Mul (TF::MulOp) +Returns x * y element-wise. + +### Description: + +*NOTE*: `Multiply` supports broadcasting. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + +### Operands: +1. `x`: tensor of number values +1. `y`: tensor of number values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `z`: tensor of number values + +## tf.Neg (TF::NegOp) +Computes numerical negative value element-wise. + +### Description: + +I.e., \\(y = -x\\). + +### Operands: +1. `x`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `y`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values + +## tf.NoOp (TF::NoOp) +Does nothing. Only useful as a placeholder for control edges. + +### Description: + + +### Operands: + +### Attributes: + +### Results: + +## tf.NotEqual (TF::NotEqualOp) +Returns the truth value of (x != y) element-wise. + +### Description: + +*NOTE*: `NotEqual` supports broadcasting. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + +### Operands: +1. `x`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 1-bit integer or 16-bit integer or 32-bit integer or 64-bit integer or 8-bit integer or complex128 type or complex64 type or TensorFlow string type values +1. `y`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 1-bit integer or 16-bit integer or 32-bit integer or 64-bit integer or 8-bit integer or complex128 type or complex64 type or TensorFlow string type values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `z`: tensor of 1-bit integer values + +## tf.Pack (TF::PackOp) + +Packs a list of `N` rank-`R` tensors into one rank-`(R+1)` tensor. + + +### Description: + +Packs the `N` tensors in `values` into a tensor with rank one higher than each +tensor in `values`, by packing them along the `axis` dimension. +Given a list of tensors of shape `(A, B, C)`; + +if `axis == 0` then the `output` tensor will have the shape `(N, A, B, C)`. +if `axis == 1` then the `output` tensor will have the shape `(A, N, B, C)`. +Etc. + +For example: + +``` +# 'x' is [1, 4] +# 'y' is [2, 5] +# 'z' is [3, 6] +pack([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim. +pack([x, y, z], axis=1) => [[1, 2, 3], [4, 5, 6]] +``` + +This is the opposite of `unpack`. + +### Operands: +1. `values`: tensor of tf.dtype values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `N` | `IntegerAttr` | 64-bit integer attribute whose minimal value is 1 attribute | +| `axis` | `IntegerAttr` | 64-bit integer attribute attribute | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of tf.dtype values + +## tf.Pad (TF::PadOp) +Pads a tensor with zeros. 
+ +### Description: + +This operation pads a `input` with zeros according to the `paddings` you +specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is the +rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates +how many zeros to add before the contents of `input` in that dimension, and +`paddings[D, 1]` indicates how many zeros to add after the contents of `input` +in that dimension. + +The padded size of each dimension D of the output is: + +`paddings(D, 0) + input.dim_size(D) + paddings(D, 1)` + +For example: + +``` +# 't' is [[1, 1], [2, 2]] +# 'paddings' is [[1, 1], [2, 2]] +# rank of 't' is 2 +pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0] + [0, 0, 1, 1, 0, 0] + [0, 0, 2, 2, 0, 0] + [0, 0, 0, 0, 0, 0]] +``` + +### Operands: +1. `input`: tensor of tf.dtype values +1. `paddings`: tensor of 32/64-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | +| `Tpaddings` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of tf.dtype values + +## tf.PadV2 (TF::PadV2Op) +Pads a tensor. + +### Description: + +This operation pads `input` according to the `paddings` and `constant_values` +you specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is +the rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates +how many padding values to add before the contents of `input` in that dimension, +and `paddings[D, 1]` indicates how many padding values to add after the contents +of `input` in that dimension. `constant_values` is a scalar tensor of the same +type as `input` that indicates the value to use for padding `input`. + +The padded size of each dimension D of the output is: + +`paddings(D, 0) + input.dim_size(D) + paddings(D, 1)` + +For example: + +``` +# 't' is [[1, 1], [2, 2]] +# 'paddings' is [[1, 1], [2, 2]] +# 'constant_values' is 0 +# rank of 't' is 2 +pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0] + [0, 0, 1, 1, 0, 0] + [0, 0, 2, 2, 0, 0] + [0, 0, 0, 0, 0, 0]] +``` + +### Operands: +1. `input`: tensor of tf.dtype values +1. `paddings`: tensor of 32/64-bit integer values +1. `constant_values`: tensor of tf.dtype values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | +| `Tpaddings` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of tf.dtype values + +## tf.Placeholder.input (TF::PlaceholderInputOp) +PlaceholderInput op + +### Description: + +Inserts a placeholder for a tensor that will be always fed. + +### Operands: +1. `arg`: tensor of tf.dtype values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `min` | `FloatAttr` | 32-bit float attribute attribute | +| `max` | `FloatAttr` | 32-bit float attribute attribute | +| `type` | `TypeAttr` | integer type attribute | +| `dtype` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of tf.dtype values + +## tf.Placeholder (TF::PlaceholderOp) +Placeholder op + +### Description: + +Inserts a placeholder for a tensor that will be always fed. + +### Operands: + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `dtype` | `Attribute` | derived attribute attribute | + +### Results: +1. 
`output`: tensor of tf.dtype values + +## tf.QuantizeAndDequantize (TF::QuantizeAndDequantizeOp) +Use QuantizeAndDequantizeV2 instead. + +### Description: + + +### Operands: +1. `input`: tensor of floating-point values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `signed_input` | `BoolAttr` | bool attribute attribute | +| `num_bits` | `IntegerAttr` | 64-bit integer attribute attribute | +| `range_given` | `BoolAttr` | bool attribute attribute | +| `input_min` | `FloatAttr` | 32-bit float attribute attribute | +| `input_max` | `FloatAttr` | 32-bit float attribute attribute | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of floating-point values + +## tf.QuantizeAndDequantizeV2 (TF::QuantizeAndDequantizeV2Op) +Quantizes then dequantizes a tensor. + +### Description: + +This op simulates the precision loss from the quantized forward pass by: + +1. Quantizing the tensor to fixed point numbers, which should match the target + quantization method when it is used in inference. +2. Dequantizing it back to floating point numbers for the following ops, most + likely matmul. + +There are different ways to quantize. This version uses only scaling, so 0.0 +maps to 0. + +From the specified 'num_bits' in the quantized output type, it determines +minimum and maximum representable quantized values. + +e.g. + +* [-128, 127] for signed, num_bits = 8, or +* [0, 255] for unsigned, num_bits = 8. + +If range_given == False, the initial input_min, input_max will be determined +automatically as the minimum and maximum values in the input tensor, otherwise +the specified values of input_min, input_max are used. + +Note: If the input_min, input_max are specified, they do not need to equal the +actual minimum and maximum values in the tensor. e.g. in some cases it may be +beneficial to specify these values such that the low probability extremes of the +input distribution are clipped. + +This op determines the maximum scale_factor that would map the initial +[input_min, input_max] range to a range that lies within the representable +quantized range. + +It determines the scale from one of input_min and input_max, then updates the +other one to maximize the representable range. + +e.g. + +* if the output is signed, num_bits = 8, [input_min, input_max] = [-10.0, + 5.0]: it would use a scale_factor of -128 / -10.0 = 12.8. In this case, it + would update input_max to be 127 / 12.8 = 9.921875 +* if the output is signed, num_bits = 8, [input_min, input_max] = [-10.0, + 10.0]: it would use a scale_factor of 127 / 10.0 = 12.7. In this case, it + would update input_min to be -128.0 / 12.7 = -10.07874 +* if the output is unsigned, input_min is forced to be 0, and only the + specified input_max is used. + +After determining the scale_factor and updating the input range, it applies the +following to each value in the 'input' tensor. + +output = round(clamp(value, input_min, input_max) * scale_factor) / scale_factor. + +The above round function rounds the value based on the given round_mode. + +### Operands: +1. `input`: tensor of floating-point values +1. `input_min`: tensor of floating-point values +1.
`input_max`: tensor of floating-point values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `signed_input` | `BoolAttr` | bool attribute attribute | +| `num_bits` | `IntegerAttr` | 64-bit integer attribute attribute | +| `range_given` | `BoolAttr` | bool attribute attribute | +| `round_mode` | `StringAttr` | string attribute whose value is HALF_TO_EVEN, or HALF_UP attribute | +| `narrow_range` | `BoolAttr` | bool attribute attribute | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of floating-point values + +## tf.QuantizeAndDequantizeV3 (TF::QuantizeAndDequantizeV3Op) +Quantizes then dequantizes a tensor. + +### Description: + +This is almost identical to QuantizeAndDequantizeV2, except that num_bits is a +tensor, so its value can change during training. + +### Operands: +1. `input`: tensor of floating-point values +1. `input_min`: tensor of floating-point values +1. `input_max`: tensor of floating-point values +1. `num_bits`: tensor of 32-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `signed_input` | `BoolAttr` | bool attribute attribute | +| `range_given` | `BoolAttr` | bool attribute attribute | +| `narrow_range` | `BoolAttr` | bool attribute attribute | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of floating-point values + +## tf.RandomUniform (TF::RandomUniformOp) +Outputs random values from a uniform distribution. + +### Description: + +The generated values follow a uniform distribution in the range `[0, 1)`. The +lower bound 0 is included in the range, while the upper bound 1 is excluded. + +### Operands: +1. `shape`: tensor of 32/64-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `seed` | `IntegerAttr` | 64-bit integer attribute attribute | +| `seed2` | `IntegerAttr` | 64-bit integer attribute attribute | +| `T` | `Attribute` | derived attribute attribute | +| `dtype` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of floating-point values + +## tf.Range (TF::RangeOp) +Creates a sequence of numbers. + +### Description: + +This operation creates a sequence of numbers that begins at `start` and +extends by increments of `delta` up to but not including `limit`. + +For example: + +``` +# 'start' is 3 +# 'limit' is 18 +# 'delta' is 3 +tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15] +``` + +### Operands: +1. `start`: tensor of bfloat16 type or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer values +1. `limit`: tensor of bfloat16 type or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer values +1. `delta`: tensor of bfloat16 type or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `Tidx` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of bfloat16 type or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer values + +## tf.Rank (TF::RankOp) +Returns the rank of a tensor. + +### Description: + +This operation returns an integer representing the rank of `input`. 
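+
+As a quick, hypothetical Python-level illustration (using the `tf.rank` API rather than the raw op), the dynamic rank returned here is the same value the static shape exposes when it is fully known:
+
+```python
+import tensorflow as tf
+
+t = tf.zeros([2, 2, 3])
+print(tf.rank(t))    # 0-D int32 tensor holding 3, computed at runtime
+print(t.shape.rank)  # 3, known statically from the tensor's shape
+```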
+ +For example: + +``` +# 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] +# shape of tensor 't' is [2, 2, 3] +rank(t) ==> 3 +``` + +**Note**: The rank of a tensor is not the same as the rank of a matrix. The rank +of a tensor is the number of indices required to uniquely select each element +of the tensor. Rank is also known as "order", "degree", or "ndims." + +### Operands: +1. `input`: tensor of tf.dtype values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of 32-bit integer values + +## tf.RealDiv (TF::RealDivOp) +Returns x / y element-wise for real types. + +### Description: + +If `x` and `y` are reals, this will return the floating-point division. + +*NOTE*: `Div` supports broadcasting. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + +### Operands: +1. `x`: tensor of number values +1. `y`: tensor of number values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `z`: tensor of number values + +## tf.Reciprocal (TF::ReciprocalOp) +Computes the reciprocal of x element-wise. + +### Description: + +I.e., \\(y = 1 / x\\). + +### Operands: +1. `x`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `y`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values + +## tf.Relu6 (TF::Relu6Op) +Computes rectified linear 6: `min(max(features, 0), 6)`. + +### Description: + + +### Operands: +1. `features`: tensor of 8/16/32/64-bit integer or floating-point values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `activations`: tensor of 8/16/32/64-bit integer or floating-point values + +## tf.Relu (TF::ReluOp) +Computes rectified linear: `max(features, 0)`. + +### Description: + + +### Operands: +1. `features`: tensor of 8/16/32/64-bit integer or floating-point values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `activations`: tensor of 8/16/32/64-bit integer or floating-point values + +## tf.Reshape (TF::ReshapeOp) +Reshapes a tensor. + +### Description: + +Given `tensor`, this operation returns a tensor that has the same values +as `tensor` with shape `shape`. + +If one component of `shape` is the special value -1, the size of that dimension +is computed so that the total size remains constant. In particular, a `shape` +of `[-1]` flattens into 1-D. At most one component of `shape` can be -1. + +If `shape` is 1-D or higher, then the operation returns a tensor with shape +`shape` filled with the values of `tensor`. In this case, the number of elements +implied by `shape` must be the same as the number of elements in `tensor`. 
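+
+A minimal Python sketch of the `-1` inference rule described above (hypothetical usage via `tf.reshape`; the canonical examples follow below):
+
+```python
+import tensorflow as tf
+
+t = tf.range(12)                     # 12 elements
+print(tf.reshape(t, [3, -1]).shape)  # -1 is inferred as 12 / 3 = 4, giving (3, 4)
+print(tf.reshape(t, [-1]).shape)     # [-1] flattens to 1-D, giving (12,)
+```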
+ +For example: + +``` +# tensor 't' is [1, 2, 3, 4, 5, 6, 7, 8, 9] +# tensor 't' has shape [9] +reshape(t, [3, 3]) ==> [[1, 2, 3], + [4, 5, 6], + [7, 8, 9]] + +# tensor 't' is [[[1, 1], [2, 2]], +# [[3, 3], [4, 4]]] +# tensor 't' has shape [2, 2, 2] +reshape(t, [2, 4]) ==> [[1, 1, 2, 2], + [3, 3, 4, 4]] + +# tensor 't' is [[[1, 1, 1], +# [2, 2, 2]], +# [[3, 3, 3], +# [4, 4, 4]], +# [[5, 5, 5], +# [6, 6, 6]]] +# tensor 't' has shape [3, 2, 3] +# pass '[-1]' to flatten 't' +reshape(t, [-1]) ==> [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6] + +# -1 can also be used to infer the shape + +# -1 is inferred to be 9: +reshape(t, [2, -1]) ==> [[1, 1, 1, 2, 2, 2, 3, 3, 3], + [4, 4, 4, 5, 5, 5, 6, 6, 6]] +# -1 is inferred to be 2: +reshape(t, [-1, 9]) ==> [[1, 1, 1, 2, 2, 2, 3, 3, 3], + [4, 4, 4, 5, 5, 5, 6, 6, 6]] +# -1 is inferred to be 3: +reshape(t, [ 2, -1, 3]) ==> [[[1, 1, 1], + [2, 2, 2], + [3, 3, 3]], + [[4, 4, 4], + [5, 5, 5], + [6, 6, 6]]] + +# tensor 't' is [7] +# shape `[]` reshapes to a scalar +reshape(t, []) ==> 7 +``` + +### Operands: +1. `tensor`: tensor of tf.dtype values +1. `shape`: tensor of 32/64-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | +| `Tshape` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of tf.dtype values + +## tf.ResizeBilinear (TF::ResizeBilinearOp) +Resize `images` to `size` using bilinear interpolation. + +### Description: + +Input images can be of different types but output images are always float. + +### Operands: +1. `images`: tensor of 8/16/32/64-bit integer or floating-point values +1. `size`: tensor of 32-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `align_corners` | `BoolAttr` | bool attribute attribute | +| `half_pixel_centers` | `BoolAttr` | bool attribute attribute | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `resized_images`: tensor of 32-bit float values + +## tf.ReverseV2 (TF::ReverseV2Op) +Reverses specific dimensions of a tensor. + +### Description: + +NOTE `tf.reverse` has now changed behavior in preparation for 1.0. +`tf.reverse_v2` is currently an alias that will be deprecated before TF 1.0. + +Given a `tensor`, and a `int32` tensor `axis` representing the set of +dimensions of `tensor` to reverse. This operation reverses each dimension +`i` for which there exists `j` s.t. `axis[j] == i`. + +`tensor` can have up to 8 dimensions. The number of dimensions specified +in `axis` may be 0 or more entries. If an index is specified more than +once, a InvalidArgument error is raised. + +For example: + +``` +# tensor 't' is [[[[ 0, 1, 2, 3], +# [ 4, 5, 6, 7], +# [ 8, 9, 10, 11]], +# [[12, 13, 14, 15], +# [16, 17, 18, 19], +# [20, 21, 22, 23]]]] +# tensor 't' shape is [1, 2, 3, 4] + +# 'dims' is [3] or 'dims' is [-1] +reverse(t, dims) ==> [[[[ 3, 2, 1, 0], + [ 7, 6, 5, 4], + [ 11, 10, 9, 8]], + [[15, 14, 13, 12], + [19, 18, 17, 16], + [23, 22, 21, 20]]]] + +# 'dims' is '[1]' (or 'dims' is '[-3]') +reverse(t, dims) ==> [[[[12, 13, 14, 15], + [16, 17, 18, 19], + [20, 21, 22, 23] + [[ 0, 1, 2, 3], + [ 4, 5, 6, 7], + [ 8, 9, 10, 11]]]] + +# 'dims' is '[2]' (or 'dims' is '[-2]') +reverse(t, dims) ==> [[[[8, 9, 10, 11], + [4, 5, 6, 7], + [0, 1, 2, 3]] + [[20, 21, 22, 23], + [16, 17, 18, 19], + [12, 13, 14, 15]]]] +``` + +### Operands: +1. 
`tensor`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 1-bit integer or 16-bit integer or 32-bit integer or 64-bit integer or 8-bit integer or complex128 type or complex64 type or TensorFlow string type values +1. `axis`: tensor of 32/64-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | +| `Tidx` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 1-bit integer or 16-bit integer or 32-bit integer or 64-bit integer or 8-bit integer or complex128 type or complex64 type or TensorFlow string type values + +## tf.Rsqrt (TF::RsqrtOp) +Computes reciprocal of square root of x element-wise. + +### Description: + +I.e., \\(y = 1 / \sqrt{x}\\). + +### Operands: +1. `x`: tensor of floating-point or 64/128-bit complex type values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `y`: tensor of floating-point or 64/128-bit complex type values + +## tf.Select (TF::SelectOp) +Selects elements from `x` or `y`, depending on `condition`. + +### Description: + +The `x`, and `y` tensors must all have the same shape, and the +output will also have that shape. + +The `condition` tensor must be a scalar if `x` and `y` are scalars. +If `x` and `y` are vectors or higher rank, then `condition` must be either a +scalar, a vector with size matching the first dimension of `x`, or must have +the same shape as `x`. + +The `condition` tensor acts as a mask that chooses, based on the value at each +element, whether the corresponding element / row in the output should be +taken from `x` (if true) or `y` (if false). + +If `condition` is a vector and `x` and `y` are higher rank matrices, then +it chooses which row (outer dimension) to copy from `x` and `y`. +If `condition` has the same shape as `x` and `y`, then it chooses which +element to copy from `x` and `y`. + +For example: + +```python +# 'condition' tensor is [[True, False] +# [False, True]] +# 't' is [[1, 2], +# [3, 4]] +# 'e' is [[5, 6], +# [7, 8]] +select(condition, t, e) # => [[1, 6], [7, 4]] + + +# 'condition' tensor is [True, False] +# 't' is [[1, 2], +# [3, 4]] +# 'e' is [[5, 6], +# [7, 8]] +select(condition, t, e) ==> [[1, 2], + [7, 8]] + +``` + +### Operands: +1. `condition`: tensor of 1-bit integer values +1. `t`: tensor of tf.dtype values +1. `e`: tensor of tf.dtype values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of tf.dtype values + +## tf.Shape (TF::ShapeOp) +Returns the shape of a tensor. + +### Description: + +This operation returns a 1-D integer tensor representing the shape of `input`. + +For example: + +``` +# 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] +shape(t) ==> [2, 2, 3] +``` + +### Operands: +1. `input`: tensor of tf.dtype values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | +| `out_type` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of 32/64-bit integer values + +## tf.Sigmoid (TF::SigmoidOp) +Computes sigmoid of `x` element-wise. 
+ +### Description: + +Specifically, `y = 1 / (1 + exp(-x))`. + +### Operands: +1. `x`: tensor of floating-point or 64/128-bit complex type values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `y`: tensor of floating-point or 64/128-bit complex type values + +## tf.Sin (TF::SinOp) +Computes sin of x element-wise. + +### Description: + + +### Operands: +1. `x`: tensor of floating-point or 64/128-bit complex type values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `y`: tensor of floating-point or 64/128-bit complex type values + +## tf.Slice (TF::SliceOp) +Return a slice from 'input'. + +### Description: + +The output tensor is a tensor with dimensions described by 'size' +whose values are extracted from 'input' starting at the offsets in +'begin'. + +*Requirements*: + 0 <= begin[i] <= begin[i] + size[i] <= Di for i in [0, n) + +### Operands: +1. `input`: tensor of tf.dtype values +1. `begin`: tensor of 32/64-bit integer values +1. `size`: tensor of 32/64-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | +| `Index` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of tf.dtype values + +## tf.Softmax (TF::SoftmaxOp) +Computes softmax activations. + +### Description: + +For each batch `i` and class `j` we have + + $$softmax[i, j] = exp(logits[i, j]) / sum_j(exp(logits[i, j]))$$ + +### Operands: +1. `logits`: tensor of floating-point values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `softmax`: tensor of floating-point values + +## tf.SpaceToBatchND (TF::SpaceToBatchNDOp) +SpaceToBatch for N-D tensors of type T. + +### Description: + +This operation divides "spatial" dimensions `[1, ..., M]` of the input into a +grid of blocks of shape `block_shape`, and interleaves these blocks with the +"batch" dimension (0) such that in the output, the spatial dimensions +`[1, ..., M]` correspond to the position within the grid, and the batch +dimension combines both the position within a spatial block and the original +batch position. Prior to division into blocks, the spatial dimensions of the +input are optionally zero padded according to `paddings`. See below for a +precise description. + +### Operands: +1. `input`: tensor of tf.dtype values +1. `block_shape`: tensor of 32/64-bit integer values +1. `paddings`: tensor of 32/64-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | +| `Tpaddings` | `Attribute` | derived attribute attribute | +| `Tblock_shape` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of tf.dtype values + +## tf.Split (TF::SplitOp) +Splits a tensor into `num_split` tensors along one dimension. + +### Description: + + +### Operands: +1. `split_dim`: tensor of 32-bit integer values +1. 
`value`: tensor of tf.dtype values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `num_split` | `IntegerAttr` | 64-bit integer attribute whose minimal value is 1 attribute | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of tf.dtype values + +## tf.SplitV (TF::SplitVOp) +Splits a tensor into `num_split` tensors along one dimension. + +### Description: + + +### Operands: +1. `value`: tensor of tf.dtype values +1. `size_splits`: tensor of 32/64-bit integer values +1. `split_dim`: tensor of 32-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `num_split` | `IntegerAttr` | 64-bit integer attribute whose minimal value is 1 attribute | +| `Tlen` | `Attribute` | derived attribute attribute | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of tf.dtype values + +## tf.Sqrt (TF::SqrtOp) +Computes square root of x element-wise. + +### Description: + +I.e., \\(y = \sqrt{x} = x^{1/2}\\). + +### Operands: +1. `x`: tensor of floating-point or 64/128-bit complex type values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `y`: tensor of floating-point or 64/128-bit complex type values + +## tf.Square (TF::SquareOp) +Computes square of x element-wise. + +### Description: + +I.e., \\(y = x * x = x^2\\). + +### Operands: +1. `x`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `y`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values + +## tf.SquaredDifference (TF::SquaredDifferenceOp) +Returns (x - y)(x - y) element-wise. + +### Description: + +*NOTE*: `SquaredDifference` supports broadcasting. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + +### Operands: +1. `x`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values +1. `y`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `z`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values + +## tf.Squeeze (TF::SqueezeOp) +Removes dimensions of size 1 from the shape of a tensor. + +### Description: + +Given a tensor `input`, this operation returns a tensor of the same type with +all dimensions of size 1 removed. If you don't want to remove all size 1 +dimensions, you can remove specific size 1 dimensions by specifying +`axis`. 
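+
+A small Python sketch of both forms (hypothetical usage via `tf.squeeze`, whose `axis` argument corresponds to the `squeeze_dims` attribute listed below; the canonical examples follow):
+
+```python
+import tensorflow as tf
+
+t = tf.zeros([1, 2, 1, 3, 1, 1])
+print(tf.squeeze(t).shape)               # every size-1 dimension removed: (2, 3)
+print(tf.squeeze(t, axis=[2, 4]).shape)  # only dimensions 2 and 4 removed: (1, 2, 3, 1)
+```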
+ +For example: + +``` +# 't' is a tensor of shape [1, 2, 1, 3, 1, 1] +shape(squeeze(t)) ==> [2, 3] +``` + +Or, to remove specific size 1 dimensions: + +``` +# 't' is a tensor of shape [1, 2, 1, 3, 1, 1] +shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1] +``` + +### Operands: +1. `input`: tensor of tf.dtype values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `squeeze_dims` | `ArrayAttr` | 64-bit integer array attribute attribute | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of tf.dtype values + +## tf.StridedSlice (TF::StridedSliceOp) +Return a strided slice from `input`. + +### Description: + +Note, most python users will want to use the Python `Tensor.__getitem__` +or `Variable.__getitem__` rather than this op directly. + +The goal of this op is to produce a new tensor with a subset of +the elements from the `n` dimensional `input` tensor. The subset is chosen using +a sequence of `m` sparse range specifications encoded into the arguments +of this function. Note, in some cases +`m` could be equal to `n`, but this need not be the case. Each +range specification entry can be one of the following: + +- An ellipsis (...). Ellipses are used to imply zero or more + dimensions of full-dimension selection and are produced using + `ellipsis_mask`. For example, `foo[...]` is the identity slice. + +- A new axis. This is used to insert a new shape=1 dimension and is + produced using `new_axis_mask`. For example, `foo[:, ...]` where + `foo` is shape `(3, 4)` produces a `(1, 3, 4)` tensor. + + +- A range `begin:end:stride`. This is used to specify how much to choose from + a given dimension. `stride` can be any integer but 0. `begin` is an integer + which represents the index of the first value to select while `end` represents + the index of the last value to select. The number of values selected in each + dimension is `end - begin` if `stride > 0` and `begin - end` if `stride < 0`. + `begin` and `end` can be negative where `-1` is the last element, `-2` is + the second to last. `begin_mask` controls whether to replace the explicitly + given `begin` with an implicit effective value of `0` if `stride > 0` and + `-1` if `stride < 0`. `end_mask` is analogous but produces the number + required to create the largest open interval. For example, given a shape + `(3,)` tensor `foo[:]`, the effective `begin` and `end` are `0` and `3`. Do + not assume this is equivalent to `foo[0:-1]` which has an effective `begin` + and `end` of `0` and `2`. Another example is `foo[-2::-1]` which reverses the + first dimension of a tensor while dropping the last two (in the original + order elements). For example `foo = [1,2,3,4]; foo[-2::-1]` is `[4,3]`. + +- A single index. This is used to keep only elements that have a given + index. For example (`foo[2, :]` on a shape `(5,6)` tensor produces a + shape `(6,)` tensor. This is encoded in `begin` and `end` and + `shrink_axis_mask`. + +Each conceptual range specification is encoded in the op's argument. This +encoding is best understand by considering a non-trivial example. 
In +particular, +`foo[1, 2:4, None, ..., :-3:-1, :]` will be encoded as + +``` +begin = [1, 2, x, x, 0, x] # x denotes don't care (usually 0) +end = [2, 4, x, x, -3, x] +strides = [1, 1, x, x, -1, 1] +begin_mask = 1<<4 | 1<<5 = 48 +end_mask = 1<<5 = 32 +ellipsis_mask = 1<<3 = 8 +new_axis_mask = 1<<2 = 4 +shrink_axis_mask = 1<<0 = 1 +``` + +In this case if `foo.shape` is (5, 5, 5, 5, 5, 5) the final shape of +the slice becomes (2, 1, 5, 5, 2, 5). +Let us walk step by step through each argument specification. + +1. The first argument in the example slice is turned into `begin = 1` and +`end = begin + 1 = 2`. To disambiguate from the original spec `2:4` we +also set the appropriate bit in `shrink_axis_mask`. + +2. `2:4` contributes 2, 4, 1 to begin, end, and stride. All masks have +zero bits contributed. + +3. None is a synonym for `tf.newaxis`. This means insert a dimension of size 1 +in the final shape. Dummy values are contributed to begin, +end and stride, while the new_axis_mask bit is set. + +4. `...` grabs the full ranges from as many dimensions as needed to +fully specify a slice for every dimension of the input shape. + +5. `:-3:-1` shows the use of negative indices. A negative index `i` associated +with a dimension that has shape `s` is converted to a positive index +`s + i`. So `-1` becomes `s-1` (i.e. the last element). This conversion +is done internally so begin, end and strides receive x, -3, and -1. +The appropriate begin_mask bit is set to indicate the start range is the +full range (ignoring the x). + +6. `:` indicates that the entire contents of the corresponding dimension +is selected. This is equivalent to `::` or `0::1`. begin, end, and strides +receive 0, 0, and 1, respectively. The appropriate bits in `begin_mask` and +`end_mask` are also set. + +*Requirements*: + `0 != strides[i] for i in [0, m)` + `ellipsis_mask must be a power of two (only one ellipsis)` + +### Operands: +1. `input`: tensor of tf.dtype values +1. `begin`: tensor of 32/64-bit integer values +1. `end`: tensor of 32/64-bit integer values +1. `strides`: tensor of 32/64-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `begin_mask` | `IntegerAttr` | 64-bit integer attribute attribute | +| `end_mask` | `IntegerAttr` | 64-bit integer attribute attribute | +| `ellipsis_mask` | `IntegerAttr` | 64-bit integer attribute attribute | +| `new_axis_mask` | `IntegerAttr` | 64-bit integer attribute attribute | +| `shrink_axis_mask` | `IntegerAttr` | 64-bit integer attribute attribute | +| `T` | `Attribute` | derived attribute attribute | +| `Index` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of tf.dtype values + +## tf.Sub (TF::SubOp) +Returns x - y element-wise. + +### Description: + +*NOTE*: `Subtract` supports broadcasting. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + +### Operands: +1. `x`: tensor of number values +1. `y`: tensor of number values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `z`: tensor of number values + +## tf.Sum (TF::SumOp) +Computes the sum of elements across dimensions of a tensor. + +### Description: + +Reduces `input` along the dimensions given in `axis`. Unless +`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +`axis`.
If `keep_dims` is true, the reduced dimensions are +retained with length 1. + +### Operands: +1. `input`: tensor of number values +1. `reduction_indices`: tensor of 32/64-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `keep_dims` | `BoolAttr` | bool attribute attribute | +| `T` | `Attribute` | derived attribute attribute | +| `Tidx` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of number values + +## tf.TensorListFromTensor (TF::TensorListFromTensorOp) + +Creates a TensorList which, when stacked, has the value of `tensor`. + + +### Description: + +Each tensor in the result list corresponds to one row of the input tensor. + +tensor: The input tensor. +output_handle: The list. + +### Operands: +1. `tensor`: tensor of tf.dtype values +1. `element_shape`: tensor of 32/64-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `shape_type` | `Attribute` | derived attribute attribute | +| `element_dtype` | `Attribute` | derived attribute attribute | + +### Results: +1. `output_handle`: tensor of TensorFlow variant type values + +## tf.TensorListGetItem (TF::TensorListGetItemOp) + + +### Description: + + +### Operands: +1. `input_handle`: tensor of TensorFlow variant type values +1. `index`: tensor of 32-bit integer values +1. `element_shape`: tensor of 32-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `element_dtype` | `Attribute` | derived attribute attribute | + +### Results: +1. `item`: tensor of tf.dtype values + +## tf.TensorListReserve (TF::TensorListReserveOp) +List of the given size with empty elements. + +### Description: + +element_shape: the shape of the future elements of the list +num_elements: the number of elements to reserve +handle: the output list +element_dtype: the desired type of elements in the list. + +### Operands: +1. `element_shape`: tensor of 32/64-bit integer values +1. `num_elements`: tensor of 32-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `element_dtype` | `TypeAttr` | any type attribute attribute | +| `shape_type` | `Attribute` | derived attribute attribute | + +### Results: +1. `handle`: tensor of TensorFlow variant type values + +## tf.TensorListSetItem (TF::TensorListSetItemOp) + + +### Description: + + +### Operands: +1. `input_handle`: tensor of TensorFlow variant type values +1. `index`: tensor of 32-bit integer values +1. `item`: tensor of tf.dtype values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `element_dtype` | `Attribute` | derived attribute attribute | + +### Results: +1. `output_handle`: tensor of TensorFlow variant type values + +## tf.TensorListStack (TF::TensorListStackOp) +Stacks all tensors in the list. + +### Description: + +Requires that all tensors have the same shape. + +input_handle: the input list +tensor: the gathered result +num_elements: optional. If not -1, the number of elements in the list. + +### Operands: +1. `input_handle`: tensor of TensorFlow variant type values +1. 
`element_shape`: tensor of 32-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `num_elements` | `IntegerAttr` | 64-bit integer attribute attribute | +| `element_dtype` | `Attribute` | derived attribute attribute | + +### Results: +1. `tensor`: tensor of tf.dtype values + +## tf.TopKV2 (TF::TopKV2Op) + +Finds values and indices of the `k` largest elements for the last dimension. + + +### Description: + +If the input is a vector (rank-1), finds the `k` largest entries in the vector +and outputs their values and indices as vectors. Thus `values[j]` is the +`j`-th largest entry in `input`, and its index is `indices[j]`. + +For matrices (resp. higher rank input), computes the top `k` entries in each +row (resp. vector along the last dimension). Thus, + + values.shape = indices.shape = input.shape[:-1] + [k] + +If two elements are equal, the lower-index element appears first. + +### Operands: +1. `input`: tensor of 8/16/32/64-bit integer or floating-point values +1. `k`: tensor of 32-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `sorted` | `BoolAttr` | bool attribute attribute | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `values`: tensor of 8/16/32/64-bit integer or floating-point values +1. `indices`: tensor of 32-bit integer values + +## tf.Transpose (TF::TransposeOp) +Shuffle dimensions of x according to a permutation. + +### Description: + +The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy: + `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]` + +### Operands: +1. `x`: tensor of tf.dtype values +1. `perm`: tensor of 32/64-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | +| `Tperm` | `Attribute` | derived attribute attribute | + +### Results: +1. `y`: tensor of tf.dtype values + +## tf.TruncateDiv (TF::TruncateDivOp) +Returns x / y element-wise for integer types. + +### Description: + +Truncation designates that negative numbers will round fractional quantities +toward zero. I.e. -7 / 5 = -1. This matches C semantics but it is different +than Python semantics. See `FloorDiv` for a division function that matches +Python Semantics. + +*NOTE*: `TruncateDiv` supports broadcasting. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + +### Operands: +1. `x`: tensor of number values +1. `y`: tensor of number values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `z`: tensor of number values + +## tf.Unpack (TF::UnpackOp) + +Unpacks a given dimension of a rank-`R` tensor into `num` rank-`(R-1)` tensors. + + +### Description: + +Unpacks `num` tensors from `value` by chipping it along the `axis` dimension. +For example, given a tensor of shape `(A, B, C, D)`; + +If `axis == 0` then the i'th tensor in `output` is the slice `value[i, :, :, :]` + and each tensor in `output` will have shape `(B, C, D)`. (Note that the + dimension unpacked along is gone, unlike `split`). + +If `axis == 1` then the i'th tensor in `output` is the slice `value[:, i, :, :]` + and each tensor in `output` will have shape `(A, C, D)`. +Etc. + +This is the opposite of `pack`. + +### Operands: +1. 
`value`: tensor of tf.dtype values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `num` | `IntegerAttr` | 64-bit integer attribute whose minimal value is 0 attribute | +| `axis` | `IntegerAttr` | 64-bit integer attribute attribute | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of tf.dtype values + +## tf.Xdivy (TF::XdivyOp) +Returns 0 if x == 0, and x / y otherwise, elementwise. + +### Description: + + +### Operands: +1. `x`: tensor of 16-bit float or 32-bit float or 64-bit float or complex128 type or complex64 type values +1. `y`: tensor of 16-bit float or 32-bit float or 64-bit float or complex128 type or complex64 type values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `z`: tensor of 16-bit float or 32-bit float or 64-bit float or complex128 type or complex64 type values + +## tf.ZerosLike (TF::ZerosLikeOp) +Returns a tensor of zeros with the same shape and type as x. + +### Description: + + +### Operands: +1. `x`: tensor of tf.dtype values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `T` | `Attribute` | derived attribute attribute | + +### Results: +1. `y`: tensor of tf.dtype values + diff --git a/tensorflow/compiler/mlir/g3doc/tfl_ops.md b/tensorflow/compiler/mlir/g3doc/tfl_ops.md new file mode 100644 index 00000000000..74e4fc47868 --- /dev/null +++ b/tensorflow/compiler/mlir/g3doc/tfl_ops.md @@ -0,0 +1,1606 @@ + +# Operation definition +## tfl.abs (TFL::AbsOp) +Absolute value operator + +### Description: + +Given a tensor `x`, this operation returns a tensor containing the absolute +value of each element in `x`. For example, if x is an input element and y is +an output element, this operation computes \\(y = |x|\\). + +### Operands: +1. `x`: tensor of any type values + +### Attributes: + +### Results: +1. `y`: tensor of any type values + +## tfl.add_n (TFL::AddNOp) +add_n operator + +### Description: + +Adds all input tensors element-wise. + +### Operands: +1. `inputs`: tensor of 32-bit float or 32-bit integer values + +### Attributes: + +### Results: +1. `sum`: tensor of 32-bit float or 32-bit integer values + +## tfl.add (TFL::AddOp) +Addition operator + +### Description: + +Element-wise addition operation. + +### Operands: +1. `lhs`: tensor of any type values +1. `rhs`: tensor of any type values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `fused_activation_function` | `StringAttr` | fused activation enum attribute | + +### Results: +1. `output`: tensor of any type values + +## tfl.average_pool_2d (TFL::AveragePool2DOp) +Average_pool_2d operator + +### Description: + +Performs average-pooling operation on input. + +### Operands: +1. `input`: tensor of any type values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `filter_height` | `IntegerAttr` | 32-bit integer attribute attribute | +| `filter_width` | `IntegerAttr` | 32-bit integer attribute attribute | +| `padding` | `StringAttr` | padding enum attribute | +| `stride_h` | `IntegerAttr` | 32-bit integer attribute attribute | +| `stride_w` | `IntegerAttr` | 32-bit integer attribute attribute | +| `fused_activation_function` | `StringAttr` | fused activation enum attribute | + +### Results: +1. 
`output`: tensor of any type values + +## tfl.batch_to_space_nd (TFL::BatchToSpaceNdOp) +BatchToSpaceNd operator + +### Description: + +This operation reshapes the "batch" dimension 0 into space dimensions. + +### Operands: +1. `input`: tensor of 32-bit float or 8-bit integer or 32-bit integer or 64-bit integer values +1. `block_shape`: tensor of 32-bit integer values +1. `indices`: tensor of 32-bit integer values + +### Attributes: + +### Results: +1. `output`: tensor of 32-bit float or 16-bit integer or 32-bit integer or 64-bit integer values + +## tfl.ceil (TFL::CeilOp) +Ceil operator + +### Description: + +Returns element-wise ceil value of the input. + +### Operands: +1. `x`: tensor of floating-point values + +### Attributes: + +### Results: +1. `y`: tensor of floating-point values + +## tfl.concatenation (TFL::ConcatenationOp) +Concatenation operator + +### Description: + +Concatenates tensors along one dimension + +### Operands: +1. `values`: tensor of 32-bit float or 64-bit integer or 32-bit integer or 16-bit integer or 8-bit integer or quantized type with 8 bits storage type values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `axis` | `IntegerAttr` | 32-bit integer attribute attribute | +| `fused_activation_function` | `StringAttr` | fused activation enum attribute | + +### Results: +1. `output`: tensor of 32-bit float or 64-bit integer or 32-bit integer or 16-bit integer or 8-bit integer or quantized type with 8 bits storage type values + +## tfl.pseudo_const (TFL::ConstOp) +Constant pseudo op. + +### Description: + +Represents a constant value in TensorFlow Lite dialect. This is not an +actual operation and it will be lowered to buffer instead. + +The op is allowed to have all the same type of attributes as tf.Const does +(e.g., opaque TF attributes are allowed). + +### Operands: + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `value` | `ElementsAttr` | constant vector/tensor attribute attribute | + +### Results: +1. `output`: tensor of any type values + +## tfl.conv_2d (TFL::Conv2DOp) +Convolution operator + +### Description: + +Performs convolution operation on inputs. + +Inputs: + `inputs[0]`: required: the input activation tensor + `inputs[1]`: required: the filter weight tensor + `inputs[2]`: optional: the bias tensor + +### Operands: +1. `input`: tensor of any type values +1. `filter`: tensor of any type values +1. `bias`: tensor of any type values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `dilation_h_factor` | `IntegerAttr` | 32-bit integer attribute attribute | +| `dilation_w_factor` | `IntegerAttr` | 32-bit integer attribute attribute | +| `fused_activation_function` | `StringAttr` | fused activation enum attribute | +| `padding` | `StringAttr` | padding enum attribute | +| `stride_h` | `IntegerAttr` | 32-bit integer attribute attribute | +| `stride_w` | `IntegerAttr` | 32-bit integer attribute attribute | + +### Results: +1. `output`: tensor of any type values + +## tfl.cos (TFL::CosOp) +Cosine operator + +### Description: + +Computes element-wise Cosine of input + +### Operands: +1. `x`: tensor of floating-point values + +### Attributes: + +### Results: +1. `y`: tensor of floating-point values + +## tfl.depthwise_conv_2d (TFL::DepthwiseConv2DOp) +Depthwise-separable convolution operator + +### Description: + +Performs convolution operation on inputs. 
+ +Inputs: + `inputs[0]`: required: the input activation tensor + `inputs[1]`: required: the filter weight tensor + `inputs[2]`: optional: the bias tensor + +### Operands: +1. `input`: tensor of any type values +1. `filter`: tensor of any type values +1. `bias`: tensor of any type values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `dilation_h_factor` | `IntegerAttr` | 32-bit integer attribute attribute | +| `dilation_w_factor` | `IntegerAttr` | 32-bit integer attribute attribute | +| `fused_activation_function` | `StringAttr` | fused activation enum attribute | +| `padding` | `StringAttr` | padding enum attribute | +| `stride_h` | `IntegerAttr` | 32-bit integer attribute attribute | +| `stride_w` | `IntegerAttr` | 32-bit integer attribute attribute | +| `depth_multiplier` | `IntegerAttr` | 32-bit integer attribute attribute | + +### Results: +1. `output`: tensor of any type values + +## tfl.dequantize (TFL::DequantizeOp) +Dequantize operator + +### Description: + +Converts quantized array of integers to floating-points according to the +quantization parameters. + +### Operands: +1. `input`: tensor of any type values + +### Attributes: + +### Results: +1. `output`: tensor of any type values + +## tfl.div (TFL::DivOp) +Division operator + +### Description: + +Element-wise division operation. + +### Operands: +1. `lhs`: tensor of any type values +1. `rhs`: tensor of any type values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `fused_activation_function` | `StringAttr` | fused activation enum attribute | + +### Results: +1. `output`: tensor of any type values + +## tfl.elu (TFL::EluOp) +Exponential Linear Unit operator + +### Description: + +Computes the exponential linear + f(x) -> exp(x) - 1 for x < 0, x for x >= 0. +element-wise. + +### Operands: +1. `x`: tensor of floating-point values + +### Attributes: + +### Results: +1. `y`: tensor of any type values + +## tfl.equal (TFL::EqualOp) +Equal operator + +### Description: + +Returns the truth element of x == y element-wise + +### Operands: +1. `x`: tensor of 1-bit integer or 32-bit float or 32-bit integer or 64-bit integer or 8-bit integer values +1. `y`: tensor of 1-bit integer or 32-bit float or 32-bit integer or 64-bit integer or 8-bit integer values + +### Attributes: + +### Results: +1. `output`: tensor of 1-bit integer values + +## tfl.exp (TFL::ExpOp) +Natural exponentiation operator + +### Description: + +Performs element-wise natural exponentiation operation on input. + +### Operands: +1. `x`: tensor of any type values + +### Attributes: + +### Results: +1. `y`: tensor of any type values + +## tfl.expand_dims (TFL::ExpandDimsOp) +Inserts a dimension of 1 into a tensor's shape. + +### Description: + +Given a tensor `input`, this operation inserts a dimension of 1 at the +dimension index `axis` of `input`'s shape. The dimension index `axis` starts at +zero; if you specify a negative number for `axis` it is counted backward from +the end. + +This operation is useful if you want to add a batch dimension to a single +element. For example, if you have a single image of shape `[height, width, +channels]`, you can make it a batch of 1 image with `expand_dims(image, 0)`, +which will make the shape `[1, height, width, channels]`. 
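+
+The batch-dimension case above can be sketched in Python as follows (a hypothetical illustration using the corresponding `tf.expand_dims` API rather than the TFLite kernel itself):
+
+```python
+import tensorflow as tf
+
+image = tf.zeros([224, 224, 3])   # a single [height, width, channels] image
+batch = tf.expand_dims(image, 0)  # insert a new dimension at index 0
+print(batch.shape)                # (1, 224, 224, 3)
+```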
+ +Other examples: + +``` +# 't' is a tensor of shape [2] +shape(expand_dims(t, 0)) ==> [1, 2] +shape(expand_dims(t, 1)) ==> [2, 1] +shape(expand_dims(t, -1)) ==> [2, 1] + +# 't2' is a tensor of shape [2, 3, 5] +shape(expand_dims(t2, 0)) ==> [1, 2, 3, 5] +shape(expand_dims(t2, 2)) ==> [2, 3, 1, 5] +shape(expand_dims(t2, 3)) ==> [2, 3, 5, 1] +``` + +This operation requires that: + +`-1-input.dims() <= dim <= input.dims()` + +This operation is related to `squeeze()`, which removes dimensions of +size 1. + +### Operands: +1. `input`: tensor of any type values +1. `dim`: tensor of any integer type + +### Attributes: + +### Results: +1. `output`: tensor of any type values + +## tfl.fake_quant (TFL::FakeQuantOp) +FakeQuant operator + +### Description: + +Fake-quantize the 'inputs' tensor of type float via float scalars min and +max to 'outputs' tensor of same shape as inputs. + +### Operands: +1. `input`: tensor of any type values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `minmax` | `ArrayAttr` | min-max range pair attribute | +| `num_bits` | `IntegerAttr` | 32-bit integer attribute attribute | +| `narrow_range` | `BoolAttr` | bool attribute attribute | + +### Results: +1. `output`: tensor of any type values + +## tfl.fill (TFL::FillOp) +Fill the tensor with given value. + +### Description: + +Fill the tensor with given value. + +### Operands: +1. `dims`: tensor of 32/64-bit integer values +1. `value`: tensor of any type values + +### Attributes: + +### Results: +1. `res`: tensor of any type values + +## tfl.floor_div (TFL::FloorDivOp) +Floor div operator + +### Description: + +Element-wise floor div operation. + +### Operands: +1. `lhs`: tensor of any type values +1. `rhs`: tensor of any type values + +### Attributes: + +### Results: +1. `output`: tensor of any type values + +## tfl.floor_mod (TFL::FloorModOp) +Division remainder + +### Description: + +Element-wise division remainder operation. + +### Operands: +1. `lhs`: tensor of any type values +1. `rhs`: tensor of any type values + +### Attributes: + +### Results: +1. `output`: tensor of any type values + +## tfl.floor (TFL::FloorOp) +Floor operator + +### Description: + +Returns element-wise floor value of the input. + +### Operands: +1. `x`: tensor of floating-point values + +### Attributes: + +### Results: +1. `y`: tensor of floating-point values + +## tfl.fully_connected (TFL::FullyConnectedOp) +Fully connected op + +### Description: + + +### Operands: +1. `input`: tensor of 32-bit float values +1. `filter`: tensor of 32-bit float values +1. `bias`: tensor of 32-bit float values or none type + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `fused_activation_function` | `StringAttr` | fused activation enum attribute | +| `weights_format` | `StringAttr` | fully connected options weights format attribute | +| `keep_num_dims` | `BoolAttr` | bool attribute attribute | + +### Results: +1. `output`: tensor of 32-bit float values + +## tfl.gather (TFL::GatherOp) +Gather operator + +### Description: + +Gather slices from `params` axis `axis` according to `indices`. + +### Operands: +1. `params`: tensor of 32-bit float or 8-bit integer or 32-bit integer or 64-bit integer or TFLite string type values +1.
`indices`: tensor of 32-bit integer or 64-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `axis` | `IntegerAttr` | 32-bit integer attribute attribute | + +### Results: +1. `output`: tensor of 32-bit float or 16-bit integer or 32-bit integer or 64-bit integer or TFLite string type values + +## tfl.greater_equal (TFL::GreaterEqualOp) +Greater_equal operator + +### Description: + +Element-wise greater_equal operation. + +### Operands: +1. `lhs`: tensor of any type values +1. `rhs`: tensor of any type values + +### Attributes: + +### Results: +1. `output`: tensor of 1-bit integer values + +## tfl.greater (TFL::GreaterOp) +Greater operator + +### Description: + +Element-wise greater operation. + +### Operands: +1. `lhs`: tensor of any type values +1. `rhs`: tensor of any type values + +### Attributes: + +### Results: +1. `output`: tensor of any type values + +## tfl.pseudo_input (TFL::InputOp) +Input pseudo operator + +### Description: + +Takes one of the function arguments as input and returns it as result. This +is a NOP and is used to attach attributes such as tensor name to function +arguments. + +### Operands: +1. `input`: tensor of any type values + +### Attributes: + +### Results: +1. `output`: tensor of any type values + +## tfl.leaky_relu (TFL::LeakyReluOp) +Leaky Relu operator + +### Description: + +Element-wise Leaky ReLU operator + x -> x >= 0 ? x : (alpha * x) + +### Operands: +1. `input`: tensor of any type values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `alpha` | `FloatAttr` | 32-bit float attribute attribute | + +### Results: +1. `output`: tensor of any type values + +## tfl.less_equal (TFL::LessEqualOp) +Less_equal operator + +### Description: + +Element-wise less_equal operation. + +### Operands: +1. `lhs`: tensor of 32-bit float or 32-bit integer or 64-bit integer or 8-bit integer values +1. `rhs`: tensor of 32-bit float or 32-bit integer or 64-bit integer or 8-bit integer values + +### Attributes: + +### Results: +1. `output`: tensor of 1-bit integer values + +## tfl.less (TFL::LessOp) +Less operator + +### Description: + +Element-wise less operation. + +### Operands: +1. `lhs`: tensor of any type values +1. `rhs`: tensor of any type values + +### Attributes: + +### Results: +1. `output`: tensor of 1-bit integer values + +## tfl.log (TFL::LogOp) +Natural logarithm operator + +### Description: + +Performs element-wise natural logarithm operation on input. + +### Operands: +1. `x`: tensor of any type values + +### Attributes: + +### Results: +1. `y`: tensor of any type values + +## tfl.log_softmax (TFL::LogSoftmaxOp) +Log softmax operator + +### Description: + +Computes element-wise log softmax activations with the following formula + + input - log(reduce_sum(exp(input), dim)) + +### Operands: +1. `input`: tensor of any type values + +### Attributes: + +### Results: +1. `output`: tensor of any type values + +## tfl.logical_and (TFL::LogicalAndOp) +Logical AND operator + +### Description: + +Element-wise logical AND operation. + +### Operands: +1. `lhs`: tensor of 1-bit integer values +1. `rhs`: tensor of 1-bit integer values + +### Attributes: + +### Results: +1. `output`: tensor of 1-bit integer values + +## tfl.logical_not (TFL::LogicalNotOp) +Logical NOT operator + +### Description: + +Element-wise logical NOT operation. + +### Operands: +1. `lhs`: tensor of 1-bit integer values + +### Attributes: + +### Results: +1. 
`output`: tensor of 1-bit integer values + +## tfl.logical_or (TFL::LogicalOrOp) +Logical OR operator + +### Description: + +Element-wise logical OR operation. + +### Operands: +1. `lhs`: tensor of 1-bit integer values +1. `rhs`: tensor of 1-bit integer values + +### Attributes: + +### Results: +1. `output`: tensor of 1-bit integer values + +## tfl.logistic (TFL::LogisticOp) +Logistic operator + +### Description: + +Computes element-wise Sigmoid of input + +### Operands: +1. `x`: tensor of floating-point values + +### Attributes: + +### Results: +1. `y`: tensor of floating-point values + +## tfl.max_pool_2d (TFL::MaxPool2DOp) +Max Pool 2D op + +### Description: + +Performs max pool 2D on input. + +Inputs: + `inputs[0]`: required: the input tensor + +### Operands: +1. `input`: tensor of any type values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `padding` | `StringAttr` | padding enum attribute | +| `stride_w` | `IntegerAttr` | 32-bit integer attribute attribute | +| `stride_h` | `IntegerAttr` | 32-bit integer attribute attribute | +| `filter_width` | `IntegerAttr` | 32-bit integer attribute attribute | +| `filter_height` | `IntegerAttr` | 32-bit integer attribute attribute | +| `fused_activation_function` | `StringAttr` | fused activation enum attribute | + +### Results: +1. `output`: tensor of any type values + +## tfl.maximum (TFL::MaximumOp) +Max operator + +### Description: + +Element-wise max operation. + +### Operands: +1. `lhs`: tensor of floating-point or 32/64-bit integer values +1. `rhs`: tensor of floating-point or 32/64-bit integer values + +### Attributes: + +### Results: +1. `max`: tensor of floating-point or 32/64-bit integer values + +## tfl.mean (TFL::MeanOp) +Mean operator + +### Description: + +Computes the mean of elements across dimensions of a tensor. +Reduces input_tensor along the dimensions given in axis. +Unless keepdims is true, the rank of the tensor is reduced by 1 for +each entry in axis. If keepdims is true, the reduced dimensions are retained +with length 1. + +### Operands: +1. `input`: tensor of 32-bit float or 8-bit integer or 32-bit integer or 64-bit integer values +1. `axis`: tensor of 32-bit integer or 64-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `keep_dims` | `BoolAttr` | bool attribute attribute | + +### Results: +1. `output`: tensor of 32-bit float or 32-bit integer or 64-bit integer or 8-bit integer values + +## tfl.minimum (TFL::MinimumOp) +Min operator + +### Description: + +Element-wise min operation. + +### Operands: +1. `lhs`: tensor of floating-point or 32/64-bit integer values +1. `rhs`: tensor of floating-point or 32/64-bit integer values + +### Attributes: + +### Results: +1. `min`: tensor of floating-point or 32/64-bit integer values + +## tfl.mul (TFL::MulOp) +Multiplication operator + +### Description: + +Element-wise multiplication operation. + +### Operands: +1. `lhs`: tensor of any type values +1. `rhs`: tensor of any type values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `fused_activation_function` | `StringAttr` | fused activation enum attribute | + +### Results: +1. `output`: tensor of any type values + +## tfl.neg (TFL::NegOp) +Negation operator + +### Description: + +Computes element-wise negation of input + +### Operands: +1. `x`: tensor of any type values + +### Attributes: + +### Results: +1. 
`y`: tensor of any type values + +## tfl.not_equal (TFL::NotEqualOp) +Not_equal operator + +### Description: + +Element-wise not_equal operation. + +### Operands: +1. `lhs`: tensor of any type values +1. `rhs`: tensor of any type values + +### Attributes: + +### Results: +1. `output`: tensor of 1-bit integer values + +## tfl.pack (TFL::PackOp) +Packs a list of tensors along a dimension into one tensor + +### Description: + +Packs a list of `values_count` rank-`R` tensors into one rank-`(R+1)` +tensor. + +Packs the `values_count` tensors in `values` into a tensor with rank one +higher than each tensor in `values`, by packing them along the `axis` +dimension. + +Given a list of tensors of shape `(A, B, C)`; + +if `axis == 0` then the `output` tensor will have the shape `(N, A, B, C)`. +if `axis == 1` then the `output` tensor will have the shape `(A, N, B, C)`. +Etc. + +For example: + +``` +# 'x' is [1, 4] +# 'y' is [2, 5] +# 'z' is [3, 6] +pack([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim. +pack([x, y, z], axis=1) => [[1, 2, 3], [4, 5, 6]] +``` + +This is the opposite of `unpack`. + +### Operands: +1. `values`: tensor of 32-bit float or 8-bit integer or 16-bit integer or 32-bit integer or 64-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `values_count` | `IntegerAttr` | 32-bit integer attribute attribute | +| `axis` | `IntegerAttr` | 32-bit integer attribute attribute | + +### Results: +1. `output`: tensor of 32-bit float or 8-bit integer or 16-bit integer or 32-bit integer or 64-bit integer values + +## tfl.pad (TFL::PadOp) +Padding operator + +### Description: + +This operation pads a `input` with zeros according to the `paddings` you +specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is +the rank of `input`. For each dimension D of `input`, `paddings[D, 0]` +indicates how many zeros to add before the contents of `input` in that +dimension, and `paddings[D, 1]` indicates how many zeros to add after the +contents of `input` in that dimension. + +The padded size of each dimension D of the output is: + + `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)` + +For example: + +``` +# 't' is [[1, 1], [2, 2]] +# 'paddings' is [[1, 1], [2, 2]] +# rank of 't' is 2 +pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0] + [0, 0, 1, 1, 0, 0] + [0, 0, 2, 2, 0, 0] + [0, 0, 0, 0, 0, 0]] + +### Operands: +1. `input`: tensor of 32-bit float or 8-bit integer or 32-bit integer or 64-bit integer values +1. `padding`: tensor of 32/64-bit integer values + +### Attributes: + +### Results: +1. `output`: tensor of 32-bit float or 8-bit integer or 32-bit integer or 64-bit integer values + +## tfl.padv2 (TFL::PadV2Op) +Padding operator v2 + +### Description: + +This operation pads a `input` according to the `paddings` and +`constant_values` you specify. `paddings` is an integer tensor with shape +`[Dn, 2]`, where n is the rank of `input`. For each dimension D of `input`, +`paddings[D, 0]` indicates how many zeros to add before the contents of +`input` in that dimension, and `paddings[D, 1]` indicates how many zeros to +add after the contents of `input` in that dimension. `constant_values` is a +scalar tensor of the same type as `input` that indicates the value to use +for padding `input`. 
+ +The padded size of each dimension D of the output is: + + `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)` + +For example: + +``` +# 't' is [[1, 1], [2, 2]] +# 'paddings' is [[1, 1], [2, 2]] +# rank of 't' is 2 +pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0] + [0, 0, 1, 1, 0, 0] + [0, 0, 2, 2, 0, 0] + [0, 0, 0, 0, 0, 0]] + +### Operands: +1. `input`: tensor of 32-bit float or 8-bit integer or 32-bit integer or 64-bit integer values +1. `padding`: tensor of 32/64-bit integer values +1. `constant_values`: tensor of 32-bit float or 8-bit integer or 32-bit integer or 64-bit integer values + +### Attributes: + +### Results: +1. `output`: tensor of 32-bit float or 8-bit integer or 32-bit integer or 64-bit integer values + +## tfl.pow (TFL::PowOp) +Power operator + +### Description: + +Element-wise power operation. + +### Operands: +1. `lhs`: tensor of any type values +1. `rhs`: tensor of any type values + +### Attributes: + +### Results: +1. `output`: tensor of any type values + +## tfl.pseudo_qconst (TFL::QConstOp) +Quantized constant pseudo op + +### Description: + +Represents a quantized constant value in TensorFlow Lite dialect. This is +not an actual operation and it will be lowered to buffer instead. The +quantization parameters are stored as a type attribute in this constant. + +### Operands: + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `qtype` | `TypeAttr` | Tensor type attribute attribute | +| `value` | `ElementsAttr` | constant vector/tensor attribute attribute | + +### Results: +1. `output`: tensor of any type values + +## tfl.quantize (TFL::QuantizeOp) +Quantize operator + +### Description: + +Converts floating point tensors to quantized integer tensors according to +the quantization parameters defined in the type attribute. + +### Operands: +1. `input`: tensor of any type values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `qtype` | `TypeAttr` | Tensor type attribute attribute | + +### Results: +1. `output`: tensor of any type values + +## tfl.range (TFL::RangeOp) +Range operator + +### Description: + +Returns a 1D tensor defined by a sequence from `start` to `limit` with +a given `delta`. + +### Operands: +1. `start`: tensor of any type values +1. `limit`: tensor of any type values +1. `delta`: tensor of any type values + +### Attributes: + +### Results: +1. `result`: tensor of any type values + +## tfl.rank (TFL::RankOp) +Rank operator. + +### Description: + +Returns the rank of a tensor. + +### Operands: +1. `input`: tensor of any type values + +### Attributes: + +### Results: +1. `output`: tensor of any integer type + +## tfl.reduce_max (TFL::ReduceMaxOp) +Max-reduction operator + +### Description: + +Computes the max reduction along the specified axes + +### Operands: +1. `input`: tensor of any type values +1. `axes`: tensor of 32/64-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `keep_dims` | `BoolAttr` | bool attribute attribute | + +### Results: +1. «unnamed»: tensor of any type values + +## tfl.reduce_min (TFL::ReduceMinOp) +Min-reduction operator + +### Description: + +Computes the min reduction along the specified axes + +### Operands: +1. `input`: tensor of any type values +1. 
`axes`: tensor of 32/64-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `keep_dims` | `BoolAttr` | bool attribute attribute | + +### Results: +1. «unnamed»: tensor of any type values + +## tfl.relu6 (TFL::Relu6Op) +Relu6 operator + +### Description: + +Element-wise Relu6 operator + x -> max(0, min(6, x)) + +### Operands: +1. `x`: tensor of any type values + +### Attributes: + +### Results: +1. `y`: tensor of any type values + +## tfl.relu (TFL::ReluOp) +Relu operator + +### Description: + +Element-wise Relu operator + x -> max(0, x) + +### Operands: +1. `x`: tensor of any type values + +### Attributes: + +### Results: +1. `y`: tensor of any type values + +## tfl.reshape (TFL::ReshapeOp) +Reshape operator + +### Description: + +Produces a tensor with the same values but different static shape defined +by the output type. + +### Operands: +1. `input`: tensor of any type values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `new_shape` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of any type values + +## tfl.resize_bilinear (TFL::ResizeBilinearOp) +ResizeBilinear Op + +### Description: + +Resize `images` to `size` using bilinear interpolation. + +### Operands: +1. `input`: tensor of 32-bit float or 32-bit integer values +1. `size`: tensor of 32-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `align_corners` | `BoolAttr` | bool attribute attribute | + +### Results: +1. `output`: tensor of 32-bit float values + +## tfl.reverse_v2 (TFL::ReverseV2Op) +ReverseV2 Operator + +### Description: + +Reverses specific dimensions of a tensor. + +Given a tensor, and a int32/int64 tensor axis representing the set +of dimensions of tensor to reverse. +This operation reverses each dimension i for +which there exists j s.t. axis[j] == i. + +Args: + tensor: A Tensor. Must be one of the following types: + int16, int32, int64, float32 Up to 8-D. + + axis: A Tensor. Must be one of the following types: int32, int64. + with only 1 element which is the axis index. + TODO: Add support for multiple elements. + +### Operands: +1. `input`: tensor of 32-bit float or 16-bit integer or 32-bit integer or 64-bit integer values +1. `axis`: tensor of 32-bit integer or 64-bit integer values + +### Attributes: + +### Results: +1. `output`: tensor of 32-bit float or 16-bit integer or 32-bit integer or 64-bit integer or 8-bit integer values + +## tfl.rsqrt (TFL::RsqrtOp) +Reciprocal of square root operator + +### Description: + +Computes element-wise reverse square root of input + +### Operands: +1. `x`: tensor of any type values + +### Attributes: + +### Results: +1. `y`: tensor of any type values + +## tfl.select (TFL::SelectOp) +Select operator + +### Description: + +Select values of 'x' if the corresponding value of 'condition' is true or +the value of 'y' if false. There are valid condition input sizes: + +1. Either the same shape (in which case the select is elementwise), or +2. condition must be Rank 1 and match over the first dimension. + +### Operands: +1. `condition`: tensor of 1-bit integer values +1. `x`: tensor of 32-bit float or 1-bit integer or 8-bit integer or 16-bit integer or 32-bit integer or 64-bit integer values +1. 
`y`: tensor of 32-bit float or 1-bit integer or 8-bit integer or 16-bit integer or 32-bit integer or 64-bit integer values + +### Attributes: + +### Results: +1. `output`: tensor of any type values + +## tfl.shape (TFL::ShapeOp) +Shape operator + +### Description: + +Returns the shape of a tensor. + +### Operands: +1. `input`: tensor of any type values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `out_type` | `Attribute` | derived attribute attribute | + +### Results: +1. `output`: tensor of any type values + +## tfl.sin (TFL::SinOp) +Sine operator + +### Description: + +Computes element-wise Sine of input + +### Operands: +1. `x`: tensor of floating-point values + +### Attributes: + +### Results: +1. `y`: tensor of floating-point values + +## tfl.softmax (TFL::SoftmaxOp) +Softmax operator + +### Description: + +Computes element-wise softmax activiations with the following formula + + exp(input) / tf.reduce_sum(exp(input * beta), dim) + +### Operands: +1. `input`: tensor of any type values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `beta` | `FloatAttr` | 32-bit float attribute attribute | + +### Results: +1. `output`: tensor of any type values + +## tfl.space_to_batch_nd (TFL::SpaceToBatchNdOp) +SpaceToBatchNd operator + +### Description: + +This operation reshapes space dimensions into the "batch" dimension 0 + +### Operands: +1. `input`: tensor of 32-bit float or 8-bit integer or 32-bit integer or 64-bit integer values +1. `block_shape`: tensor of 32-bit integer values +1. `paddings`: tensor of 32-bit integer values + +### Attributes: + +### Results: +1. `output`: tensor of 32-bit float or 16-bit integer or 32-bit integer or 64-bit integer values + +## tfl.split (TFL::SplitOp) +Splits a tensor into `num_split` tensors along one dimension. + +### Description: + +Splits the `value` tensor along `split_dim` into a number of sub-tensors +with same shape as the original one, except for `split_dim`. Same as +tf.Split. + +### Operands: +1. `split_dim`: tensor of 32-bit integer values +1. `value`: tensor of 32-bit float or 16-bit integer or 32-bit integer or 64-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `num_splits` | `IntegerAttr` | 32-bit integer attribute attribute | + +### Results: +1. `outputs`: tensor of 32-bit float or 16-bit integer or 32-bit integer or 64-bit integer values + +## tfl.split_v (TFL::SplitVOp) +Splits a tensor into `num_split` tensors along one dimension. + +### Description: + +Splits the `value` tensor along `split_dim` into a number of sub-tensors +with same shape as the original one, except for `split_dim`. The grouping +of the resultant sub-tensors is decided by `size-splits`. Same as tf.SplitV. + +### Operands: +1. `value`: tensor of 32-bit float or 16-bit integer or 32-bit integer or 64-bit integer values +1. `size_splits`: tensor of 32-bit integer values +1. `split_dim`: tensor of 32-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `num_splits` | `IntegerAttr` | 32-bit integer attribute attribute | + +### Results: +1. `outputs`: tensor of 32-bit float or 16-bit integer or 32-bit integer or 64-bit integer values + +## tfl.sqrt (TFL::SqrtOp) +Square root operator + +### Description: + +Computes element-wise Square root of input + +### Operands: +1. 
`x`: tensor of any type values + +### Attributes: + +### Results: +1. `y`: tensor of any type values + +## tfl.square (TFL::SquareOp) +Square operator + +### Description: + +Computes element-wise Square of input + +### Operands: +1. `x`: tensor of any type values + +### Attributes: + +### Results: +1. `y`: tensor of any type values + +## tfl.squared_difference (TFL::SquaredDifferenceOp) +Squared difference operator + +### Description: + +Element-wise squared difference operation. + +### Operands: +1. `lhs`: tensor of any type values +1. `rhs`: tensor of any type values + +### Attributes: + +### Results: +1. `output`: tensor of any type values + +## tfl.squeeze (TFL::SqueezeOp) +Removes dimensions of size 1 from the shape of a tensor. + +### Description: + +Given a tensor `input`, this operation returns a tensor of the same type with +all dimensions of size 1 removed. If you don't want to remove all size 1 +dimensions, you can remove specific size 1 dimensions by specifying +`axis`. + +For example: + +``` +# 't' is a tensor of shape [1, 2, 1, 3, 1, 1] +shape(squeeze(t)) ==> [2, 3] +``` + +Or, to remove specific size 1 dimensions: + +``` +# 't' is a tensor of shape [1, 2, 1, 3, 1, 1] +shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1] +``` + +### Operands: +1. `input`: tensor of any type values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `squeeze_dims` | `ArrayAttr` | 64-bit integer array attribute attribute | + +### Results: +1. `output`: tensor of any type values + +## tfl.strided_slice (TFL::StridedSliceOp) +StridedSlice Op + +### Description: + +Return a strided slice from `input`. + +### Operands: +1. `input`: tensor of 32-bit float or 32-bit integer or 64-bit integer or 8-bit integer values +1. `begin`: tensor of 32-bit integer values +1. `end`: tensor of 32-bit integer values +1. `strides`: tensor of 32-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `begin_mask` | `IntegerAttr` | 32-bit integer attribute attribute | +| `end_mask` | `IntegerAttr` | 32-bit integer attribute attribute | +| `ellipsis_mask` | `IntegerAttr` | 32-bit integer attribute attribute | +| `new_axis_mask` | `IntegerAttr` | 32-bit integer attribute attribute | +| `shrink_axis_mask` | `IntegerAttr` | 32-bit integer attribute attribute | + +### Results: +1. `output`: tensor of 32-bit float or 32-bit integer or 64-bit integer or 8-bit integer values + +## tfl.sub (TFL::SubOp) +Subtraction operator + +### Description: + +Element-wise subtraction operation. + +### Operands: +1. `lhs`: tensor of any type values +1. `rhs`: tensor of any type values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `fused_activation_function` | `StringAttr` | fused activation enum attribute | + +### Results: +1. `output`: tensor of any type values + +## tfl.sum (TFL::SumOp) +Sum operator + +### Description: + +Computes the sum reduction along the specified axes + +### Operands: +1. `input`: tensor of any type values +1. `axes`: tensor of 32/64-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `keep_dims` | `BoolAttr` | bool attribute attribute | + +### Results: +1. «unnamed»: tensor of any type values + +## tfl.tanh (TFL::TanhOp) +Hyperbolic tangent operator + +### Description: + +Computes element-wise Hyperbolic tangent of input + +### Operands: +1. 
`x`: tensor of any type values + +### Attributes: + +### Results: +1. `y`: tensor of any type values + +## tfl.tile (TFL::TileOp) +Tile operator. + +### Description: + + Constructs a tensor by tiling a given tensor. + +This operation creates a new tensor by replicating input +multiples times. The output tensor's i'th dimension has +input.dims(i) * multiples[i] elements, and the values of input +are replicated multiples[i] times along the 'i'th dimension. +For example, tiling [a b c d] by [2] produces [a b c d a b c d]. + +### Operands: +1. `input`: tensor of any type values +1. `multiples`: tensor of 32/64-bit integer values + +### Attributes: + +### Results: +1. `output`: tensor of any type values + +## tfl.topk_v2 (TFL::TopKV2Op) +TopK operator + +### Description: + +Returns the top `k` largest element along each last dimensional slice of +`input` and the indices of values within the last dimension of the input +tensor. + +### Operands: +1. `input`: tensor of 32-bit float or 8-bit integer or 32-bit integer or 64-bit integer values +1. `k`: tensor of 32-bit integer values + +### Attributes: + +### Results: +1. `values`: tensor of any type values +1. `indices`: tensor of 32-bit integer values + +## tfl.transpose (TFL::TransposeOp) +Transpose operator + +### Description: + +Returns the Transpose of x + +### Operands: +1. `x`: tensor of any type values +1. `perm`: tensor of any type values + +### Attributes: + +### Results: +1. `y`: tensor of any type values + +## tfl.unidirectional_sequence_lstm (TFL::UnidirectionalSequenceLSTMOp) +Unidirectional sequence lstm operator + +### Description: + +A recurrent neural network specified by an LSTM cell. This Op supports +unrolling the input along the time or batch dimensions, and +implements the following operation for +each element in the sequence s = 1...sequence_length: + outputs[s] = state = activation(LSTMOp(inputs[s])) + +where LSTMOp is LSTM TF Lite Op and the “activation” is the function passed +as the “fused_activation_function” argument (if not “NONE”). + +### Operands: +1. `input`: tensor of 32-bit float or 8-bit integer values +1. `input_to_input_weights`: tensor of 32-bit float or 8-bit integer values or none type +1. `input_to_forget_weights`: tensor of 32-bit float or 8-bit integer values +1. `input_to_cell_weights`: tensor of 32-bit float or 8-bit integer values +1. `input_to_output_weights`: tensor of 32-bit float or 8-bit integer values +1. `recurrent_to_input_weights`: tensor of 32-bit float or 8-bit integer values or none type +1. `recurrent_to_forget_weights`: tensor of 32-bit float or 8-bit integer values +1. `recurrent_to_cell_weights`: tensor of 32-bit float or 8-bit integer values +1. `recurrent_to_output_weights`: tensor of 32-bit float or 8-bit integer values +1. `cell_to_input_weights`: tensor of 32-bit float or 8-bit integer values or none type +1. `cell_to_forget_weights`: tensor of 32-bit float or 8-bit integer values or none type +1. `cell_to_output_weights`: tensor of 32-bit float or 8-bit integer values or none type +1. `input_gate_bias`: tensor of 32-bit float values or none type +1. `forget_gate_bias`: tensor of 32-bit float values +1. `cell_bias`: tensor of 32-bit float values +1. `output_gate_bias`: tensor of 32-bit float values +1. `projection_weights`: tensor of 32-bit float or 8-bit integer values or none type +1. `projection_bias`: tensor of 32-bit float values or none type +1. `input_activation_state`: stateful tensor +1. `input_cell_state`: stateful tensor +1. 
`input_layer_norm_coefficients`: tensor of 32-bit float or 8-bit integer values or none type +1. `forget_layer_norm_coefficients`: tensor of 32-bit float or 8-bit integer values or none type +1. `cell_layer_norm_coefficients`: tensor of 32-bit float or 8-bit integer values or none type +1. `output_layer_norm_coefficients`: tensor of 32-bit float or 8-bit integer values or none type + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `fused_activation_function` | `StringAttr` | fused activation enum attribute | +| `cell_clip` | `FloatAttr` | 32-bit float attribute attribute | +| `proj_clip` | `FloatAttr` | 32-bit float attribute attribute | +| `time_major` | `BoolAttr` | bool attribute attribute | + +### Results: +1. `output`: tensor of any type values + +## tfl.unpack (TFL::UnpackOp) +Unpacks a tensor along a dimension into multiple tensors + +### Description: + +Unpacks a given dimension of a rank-`R` tensor into `num` rank-`(R-1)` tensors. + +Unpacks `num` tensors from `value` by chipping it along the `axis` dimension. +For example, given a tensor of shape `(A, B, C, D)`; + +If `axis == 0` then the i'th tensor in `output` is the slice `value[i, :, :, :]` + and each tensor in `output` will have shape `(B, C, D)`. (Note that the + dimension unpacked along is gone, unlike `split`). + +If `axis == 1` then the i'th tensor in `output` is the slice `value[:, i, :, :]` + and each tensor in `output` will have shape `(A, C, D)`. +Etc. + +This is the opposite of `pack`. + +### Operands: +1. `input`: tensor of 32-bit float or 8-bit integer or 32-bit integer values + +### Attributes: +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `num` | `IntegerAttr` | 32-bit integer attribute attribute | +| `axis` | `IntegerAttr` | 32-bit integer attribute attribute | + +### Results: +1. `outputs`: tensor of 32-bit float or 8-bit integer or 32-bit integer values + +## tfl.zeros_like (TFL::ZerosLikeOp) +ZerosLike operator + +### Description: + +Returns a tensor of zeros with the same shape and type as the input tensor. + +### Operands: +1. `input`: tensor of any type values + +### Attributes: + +### Results: +1. 
`output`: tensor of any type values + diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 99740515a48..c4a3275d557 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -1,4 +1,4 @@ -load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_native_cc_binary") +load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_native_cc_binary") load( "@local_config_mlir//:tblgen.bzl", "gentbl", @@ -185,6 +185,41 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "lstm_utils", + srcs = [ + "utils/lstm_utils.cc", + ], + hdrs = [ + "utils/lstm_utils.h", + ], + copts = ["-std=c++14"], + deps = [ + ":tensorflow_lite", + "//tensorflow/compiler/mlir/tensorflow", + "@llvm//:support", + "@local_config_mlir//:IR", + "@local_config_mlir//:StandardOps", + "@local_config_mlir//:Support", + ], +) + +tf_cc_test( + name = "lstm_utils_test", + size = "small", + srcs = ["utils/lstm_utils_test.cc"], + deps = [ + ":lstm_utils", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@llvm//:support", + "@local_config_mlir//:IR", + "@local_config_mlir//:StandardOps", + "@local_config_mlir//:Support", + ], +) + cc_library( name = "tensorflow_lite_legalize_tf", srcs = [ @@ -198,9 +233,11 @@ cc_library( "transforms/prepare_composite_functions_tf.cc", "transforms/prepare_tf.cc", "transforms/trim_functions_tf.cc", + "transforms/unroll_batch_matmul.cc", ], hdrs = [ "transforms/passes.h", + "transforms/unroll_batch_matmul.h", ], deps = [ ":common", @@ -249,6 +286,7 @@ cc_library( name = "tensorflow_lite_quantize", srcs = [ "transforms/generated_quantize.inc", + "transforms/load_quantization_recipe.cc", "transforms/post_quantize.cc", "transforms/prepare_quantize.cc", "transforms/quantize.cc", @@ -521,7 +559,7 @@ cc_library( ":tensorflow_lite_quantize", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:error_util", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_fold_switch", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib", "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", "//tensorflow/compiler/mlir/tensorflow:translate_lib", diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc index aa57ff7f751..783696ecac3 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc @@ -99,7 +99,10 @@ using xla::StatusOr; template using BufferOffset = flatbuffers::Offset; -using CustomOptionsOffset = BufferOffset>; +template +using VectorBufferOffset = flatbuffers::Offset>; + +using CustomOptionsOffset = VectorBufferOffset; namespace error = tensorflow::error; namespace tfl = mlir::TFL; @@ -415,6 +418,15 @@ class Translator { Optional> BuildSubGraph(FuncOp fn); + // Builds Metadata with the given `name` and buffer `content`. + BufferOffset BuildMetadata(StringRef name, + StringRef content); + + // Encodes the `tfl.metadata` dictionary attribute of the module to the + // metadata section in the final model. + Optional>> + CreateMetadataVector(); + // Uses the tf.entry_function attribute (if set) to initialize the op to name // mapping. 
void InitializeNamesFromAttribute(FuncOp fn); @@ -977,6 +989,36 @@ Optional> Translator::BuildSubGraph(FuncOp fn) { /*name=*/builder_.CreateString(fn.getName().str())); } +BufferOffset Translator::BuildMetadata(StringRef name, + StringRef content) { + auto buffer_index = buffers_.size(); + auto buffer_data = builder_.CreateVector( + reinterpret_cast(content.data()), content.size()); + buffers_.push_back(tflite::CreateBuffer(builder_, buffer_data)); + return tflite::CreateMetadataDirect(builder_, name.data(), buffer_index); +} + +Optional>> +Translator::CreateMetadataVector() { + auto dict_attr = module_.getAttrOfType("tfl.metadata"); + if (!dict_attr) return VectorBufferOffset>(); + + std::vector> metadata; + for (const auto& named_attr : dict_attr) { + StringRef name = named_attr.first; + mlir::Attribute attr = named_attr.second; + if (auto content = attr.dyn_cast()) { + metadata.push_back(BuildMetadata(name, content.getValue())); + } else { + module_.emitError( + "all values in tfl.metadata's dictionary key-value pairs should be " + "string attributes"); + return llvm::None; + } + } + return builder_.CreateVector(metadata); +} + Optional Translator::Translate(ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, @@ -1024,12 +1066,17 @@ Optional Translator::TranslateInternal() { } else { model_description = "MLIR Converted."; } + // Build the model and finish the model building process. auto description = builder_.CreateString(model_description.data()); + VectorBufferOffset metadata_buffer = 0; // Deprecated + auto metadata = CreateMetadataVector(); + if (!metadata) return llvm::None; + auto model = tflite::CreateModel( builder_, TFLITE_SCHEMA_VERSION, builder_.CreateVector(opcodes_), builder_.CreateVector(subgraphs), description, - builder_.CreateVector(buffers_)); + builder_.CreateVector(buffers_), metadata_buffer, *metadata); tflite::FinishModelBuffer(builder_, model); // Return serialized string for the built FlatBuffer. diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index e92b7ac60a3..4f3d71a7fd4 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include #include #include "llvm/ADT/APFloat.h" @@ -30,6 +31,7 @@ limitations under the License. 
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir #include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir #include "mlir/Support/LLVM.h" // TF:local_config_mlir +#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" namespace mlir { @@ -1167,6 +1169,54 @@ OpFoldResult TransposeOp::fold(ArrayRef operands) { return DenseElementsAttr::get(result_type, new_values); } +static LogicalResult Verify(TransposeOp op) { + auto input_type = op.x()->getType().cast(); + auto perm_type = op.perm()->getType().cast(); + auto output_type = op.y()->getType().cast(); + if (input_type.hasStaticShape() && perm_type.hasStaticShape()) { + if (perm_type.getNumElements() != input_type.getRank()) { + return op.emitOpError( + "perm tensor elements size is not equal to input tensor rank"); + } + } + + DenseIntElementsAttr perm; + if (!matchPattern(op.perm(), m_Constant(&perm))) { + return success(); + } + + int index = 0; + llvm::SmallVector axes; + for (auto axis_int : perm.getValues()) { + const int64_t axis = axis_int.getSExtValue(); + if (axis < 0 || (input_type.hasRank() && axis >= input_type.getRank())) { + return op.emitOpError( + llvm::formatv("perm[{0}] must be in [0, rank)", index)); + } + if (std::count(axes.begin(), axes.end(), axis) > 0) { + return op.emitOpError( + llvm::formatv("perm[{0}] cannot have duplicated axis", index)); + } + axes.push_back(axis); + index++; + } + + if (input_type.hasStaticShape() && output_type.hasStaticShape()) { + llvm::SmallVector transposed_shape; + for (int64_t axis : axes) { + transposed_shape.push_back(input_type.getDimSize(axis)); + } + auto expected_output_type = + RankedTensorType::get(transposed_shape, input_type.getElementType()); + if (output_type != expected_output_type) { + return op.emitOpError(llvm::formatv("expect output type {0}, got {1}", + expected_output_type, output_type)); + } + } + + return success(); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 1d7b909f762..018e6605197 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -132,15 +132,35 @@ def TFL_FpOrI32OrI64Tensor : TensorOf<[AnyFloat, TFL_Int32Or64]>; // Rank/Shape helpers. //===----------------------------------------------------------------------===// +class TFL_OperandIsUnrankedPred : + CPred<"$_op.getOperand(" # n # ")->getType().isa()">; + // TODO: Some of these could be generalized and/or moved to more general // location. // Returns true if the n-th operand has unknown rank or has rank m. class TFL_OperandHasRank : PredOpTrait<"operand " # n # " is " # m # "-D", - Or<[CPred<"$_op.getOperand(" # n # ")->getType().isa()">, + Or<[TFL_OperandIsUnrankedPred, CPred<"$_op.getOperand(" # n # ")->getType().cast().getRank() == " # m>]>>; +// CPred version of TFL_OperandHasRank. +class TFL_OperandHasRankPred : + Or<[TFL_OperandIsUnrankedPred, + CPred<"$_op.getOperand(" # n # + ")->getType().cast().getRank() == " # m>]>; + +// True if operand n is ranked and has a rank > dim. 
+class TFL_OperandIsRankedAndHasDimPred : And<[ + CPred<"$_op.getOperand(" # n # ")->getType().isa()">, + CPred<"$_op.getOperand(" # n # ")->getType().cast().getRank() > " + # dim>]>; + +class TFL_OperandDimEquals : And<[ + TFL_OperandIsRankedAndHasDimPred, + CPred<"$_op.getOperand(" # n # ")->getType().cast()" + ".getShape()[" # dim # " ] == " # size>]>; + // Returns true if the n-th operand has unknown rank or at least rank m. class TFL_OperandHasAtleastRank : PredOpTrait<"operand " # n # " is " # m # "-D", @@ -155,6 +175,32 @@ class TFL_OperandRankEquals1DimOfOperand : "$_op.getOperand(" # y # ")->getType().cast().getShape()[0]">>; +// True if x_shape[dim] == y_shape[dim]. +class TFL_DimOfOperandEqualsDimOfOperandPred : And<[ + TFL_OperandIsRankedAndHasDimPred, + TFL_OperandIsRankedAndHasDimPred, + CPred<"$_op.getOperand(" # x # + ")->getType().cast().getShape()[" # dim # "] == " + "$_op.getOperand(" # y # + ")->getType().cast().getShape()[" # dim # "]">]>; + +// Select operands must satisfy one of the following constraints: +// All inputs are unranked/scalars +// OR +// All inputs are ranked AND have equal dim[0] AND X & Y have same rank. +def SelectShapeConstraints : + PredOpTrait<"Select operands meet shape criteria", + Or<[ + And<[ + TFL_OperandHasRankPred<0, 0>, + TFL_OperandHasRankPred<1, 0>, + TFL_OperandHasRankPred<2, 0>]>, + And<[ + TFL_DimOfOperandEqualsDimOfOperandPred<0, 1, 0>, + TFL_DimOfOperandEqualsDimOfOperandPred<0, 2, 0>, + CPred<"$_op.getOperand(1)->getType().cast().getRank() == " + "$_op.getOperand(2)->getType().cast().getRank()">]>]>>; + // This is a quantization-aware version of TCresVTEtIsSameAsOp class TFL_TCresVTEtIsSameAsOp : And<[ TCOpResIsShapedTypePred, @@ -315,7 +361,7 @@ def TFL_AddOp : TFL_Op<"add", [Broadcastable, NoSideEffect, Commutative]> { // TODO(haoliang): Implement legalization pass after pattern rewrite generator // supports variadic inputs. -def TFL_AddNOp : TFL_Op<"add_n", [Commutative, NoSideEffect]> { +def TFL_AddNOp : TFL_Op<"add_n", [Commutative, NoSideEffect, SameOperandsAndResultsScale]> { let summary = "add_n operator"; let description = [{ @@ -323,11 +369,11 @@ def TFL_AddNOp : TFL_Op<"add_n", [Commutative, NoSideEffect]> { }]; let arguments = (ins - Variadic>:$inputs + Variadic>:$inputs ); let results = (outs - TensorOf<[F32, I32]>:$sum + TensorOf<[F32, I32, QI16, QUI16]>:$sum ); } @@ -680,6 +726,117 @@ def TFL_GreaterEqualOp : TFL_Op<"greater_equal", [ let hasOptions = 0; } +// These ops are named NonMaxSuppressionV4 & NonMaxSuppressionV5 to be +// consistent with TensorFlow's naming. They are NOT 'versions' of NMS in the +// sense that one is an incremental change over the other. +// In reality NonMaxSuppressionV5 implements Soft Non Max Suppression and +// NonMaxSuppressionV4 performs hard NMS. + +def TFL_NonMaxSuppressionV4Op : TFL_Op<"non_max_suppression_v4", [ + NoSideEffect, + // Operand 0 (boxes) should have rank 2 with the dim[1] == 4 (box corners) + TFL_OperandHasRank<0, 2>, + PredOpTrait<"boxes should have dim[1] == 4", + TFL_OperandDimEquals<0, 1, 4>>, + // Operand 1 (scores) should be a 1-dim tensor + TFL_OperandHasRank<1, 1>, + // Other operands are scalar params. + TFL_OperandHasRank<2, 0>, TFL_OperandHasRank<3, 0>, + TFL_OperandHasRank<4, 0>]> { + let summary = [{ +Greedily selects a subset of bounding boxes in descending order of score, + }]; + + let description = [{ +pruning away boxes that have high intersection-over-union (IOU) overlap +with previously selected boxes. 
Bounding boxes with score less than +`score_threshold` are removed. Bounding boxes are supplied as +[y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any +diagonal pair of box corners and the coordinates can be provided as normalized +(i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm +is agnostic to where the origin is in the coordinate system and more +generally is invariant to orthogonal transformations and translations +of the coordinate system; thus translating or reflections of the coordinate +system result in the same boxes being selected by the algorithm. +The output of this operation is a set of integers indexing into the input +collection of bounding boxes representing the selected boxes. The bounding +box coordinates corresponding to the selected indices can then be obtained +using the `tf.gather operation`. For example: + selected_indices = tf.image.non_max_suppression_v2( + boxes, scores, max_output_size, iou_threshold, score_threshold) + selected_boxes = tf.gather(boxes, selected_indices) + }]; + + let arguments = (ins + TFL_FpTensor:$boxes, + TFL_FpTensor:$scores, + I32Tensor:$max_output_size, + TFL_FpTensor:$iou_threshold, + TFL_FpTensor:$score_threshold + ); + + let results = (outs + I32Tensor:$selected_indices, + I32Tensor:$valid_outputs + ); +} + +def TFL_NonMaxSuppressionV5Op : TFL_Op<"non_max_suppression_v5", [ + NoSideEffect, + // Operand 0 (boxes) should have rank 2 with the dim[1] == 4 (box corners) + TFL_OperandHasRank<0, 2>, + PredOpTrait<"boxes should have dim[1] == 4", + TFL_OperandDimEquals<0, 1, 4>>, + // Operand 1 (scores) should be a 1-dim tensor + TFL_OperandHasRank<1, 1>, + // Other operands are scalar params. + TFL_OperandHasRank<2, 0>, TFL_OperandHasRank<3, 0>, + TFL_OperandHasRank<4, 0>, TFL_OperandHasRank<5, 0>]> { + let summary = [{ +Greedily selects a subset of bounding boxes in descending order of score, + }]; + + let description = [{ +pruning away boxes that have high intersection-over-union (IOU) overlap +with previously selected boxes. Bounding boxes with score less than +`score_threshold` are removed. Bounding boxes are supplied as +[y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any +diagonal pair of box corners and the coordinates can be provided as normalized +(i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm +is agnostic to where the origin is in the coordinate system and more +generally is invariant to orthogonal transformations and translations +of the coordinate system; thus translating or reflections of the coordinate +system result in the same boxes being selected by the algorithm. +The output of this operation is a set of integers indexing into the input +collection of bounding boxes representing the selected boxes. The bounding +box coordinates corresponding to the selected indices can then be obtained +using the `tf.gather operation`. For example: + selected_indices = tf.image.non_max_suppression_v2( + boxes, scores, max_output_size, iou_threshold, score_threshold) + selected_boxes = tf.gather(boxes, selected_indices) +This op also supports a Soft-NMS (with Gaussian weighting) mode (c.f. +Bodla et al, https://arxiv.org/abs/1704.04503) where boxes reduce the score +of other overlapping boxes instead of directly causing them to be pruned. +To enable this Soft-NMS mode, set the `soft_nms_sigma` parameter to be +larger than 0. 
+ }]; + + let arguments = (ins + TFL_FpTensor:$boxes, + TFL_FpTensor:$scores, + I32Tensor:$max_output_size, + TFL_FpTensor:$iou_threshold, + TFL_FpTensor:$score_threshold, + TFL_FpTensor:$soft_nms_sigma + ); + + let results = (outs + I32Tensor:$selected_indices, + TFL_FpTensor:$selected_scores, + I32Tensor:$valid_outputs + ); +} + def TFL_NotEqualOp : TFL_Op<"not_equal", [ Broadcastable, Commutative, NoSideEffect, NoQuantizableResult]> { let summary = "Not_equal operator"; @@ -987,11 +1144,11 @@ def TFL_L2NormalizationOp : TFL_Op<"l2_normalization", [NoSideEffect]> { }]; let arguments = (ins - TensorOf<[F32, QUI8, QI8, I8]>:$input, + TensorOf<[F32, QUI8, QI8, QUI16, QI16, I8]>:$input, TFL_AFAttr:$fused_activation_function ); - let results = (outs TensorOf<[F32, QUI8, QI8, I8]>:$output); + let results = (outs TensorOf<[F32, QUI8, QI8, QUI16, QI16, I8]>:$output); let hasOptions = 1; @@ -1100,9 +1257,9 @@ def TFL_LogisticOp: TFL_Op<"logistic", [ Computes element-wise Sigmoid of input }]; - let arguments = (ins TensorOf<[AnyFloat, QI8, QUI8]>:$x); + let arguments = (ins TensorOf<[AnyFloat, QI8, QUI8, QI16, QUI16]>:$x); - let results = (outs TensorOf<[AnyFloat, QI8, QUI8]>:$y); + let results = (outs TensorOf<[AnyFloat, QI8, QUI8, QI16, QUI16]>:$y); } def TFL_LogOp: TFL_Op<"log", [NoSideEffect, SameOperandsAndResultType]> { @@ -1441,7 +1598,7 @@ def TFL_NegOp: TFL_Op<"neg", [NoSideEffect, SameOperandsAndResultType]> { let hasOptions = 0b1; } -def TFL_PackOp : TFL_Op<"pack", [NoSideEffect]> { +def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> { let summary = "Packs a list of tensors along a dimension into one tensor"; let description = [{ @@ -1472,14 +1629,14 @@ def TFL_PackOp : TFL_Op<"pack", [NoSideEffect]> { }]; let arguments = (ins - Variadic>:$values, + Variadic>:$values, I32Attr:$values_count, I32Attr:$axis ); let results = (outs - TensorOf<[F32, I8, I16, I32, I64]>:$output + TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8]>:$output ); let verifier = [{ return Verify(*this); }]; @@ -1777,8 +1934,7 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2", } def TFL_SelectOp : TFL_Op<"select", [NoSideEffect, - // TODO(jpienaar): This is too retrictive, rank 1 input is also allowed. - SameOperandsAndResultShape, + SelectShapeConstraints, PredOpTrait<"operands have same element type", TCopVTEtIsSameAs<1, 2>>, PredOpTrait<"operands and result have same element type", TCresVTEtIsSameAsOp<0, 1>>]> { @@ -1836,7 +1992,7 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [ let summary = "Softmax operator"; let description = [{ - Computes element-wise softmax activiations with the following formula + Computes element-wise softmax activations with the following formula exp(input) / tf.reduce_sum(exp(input * beta), dim) }]; @@ -1942,9 +2098,9 @@ def TFL_TanhOp: TFL_Op<"tanh", [ Computes element-wise Hyperbolic tangent of input }]; - let arguments = (ins TensorOf<[F32, I16, I8, QI8, QUI8, TFL_Uint8]>:$x); + let arguments = (ins TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$x); - let results = (outs TensorOf<[F32, I16, I8, QI8, QUI8, TFL_Uint8]>:$y); + let results = (outs TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$y); } def TFL_TileOp: TFL_Op<"tile", [NoSideEffect, @@ -1999,8 +2155,6 @@ def TFL_TopKV2Op: TFL_Op<"topk_v2", [NoSideEffect, TFL_OperandHasRank<1,0>, let hasOptions = 1; } -// TODO: Verify result shape a permutation of the first input shape's -// dimensions. 
def TFL_TransposeOp : TFL_Op<"transpose", [NoSideEffect, TFL_OperandHasRank<1,1>, @@ -2025,6 +2179,8 @@ def TFL_TransposeOp : TFL_Op<"transpose", AnyTensor:$y ); + let verifier = [{ return Verify(*this); }]; + let hasFolder = 1; } @@ -2342,7 +2498,8 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", let hasOptions = 1; } -def TFL_CastOp : TFL_Op<"cast", [NoSideEffect, SameOperandsAndResultShape]> { +def TFL_CastOp : TFL_Op<"cast", [ + NoSideEffect, SameOperandsAndResultShape, NoQuantizableResult]> { let summary = "Cast operator"; let description = [{ @@ -2629,6 +2786,10 @@ Ba et al. “Layer Normalization” let results = (outs AnyTensor:$output); + // TODO(fengliuai): customize printer and parser to not display + // empty region. + let regions = (region AnyRegion:$internal); + let hasOptions = 1; let verifier = [{ return Verify(*this); }]; diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc index e4158700713..4e3fda7771e 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc @@ -549,32 +549,42 @@ QuantParams QuantizationDriver::GetQuantParamsForSameScaleConstraint( void QuantizationDriver::PreprocessConstantOps() { fn_.walk([&](ConstantOp cst) { - // Non-float tensors are neither weights or require quantization. - if (!cst.getType().cast().getElementType().isa()) { - return; - } + // Non-float tensors are neither weights nor require quantization. + auto type = cst.getType().dyn_cast(); + if (!type || !type.getElementType().isa()) return; Value *value = cst.getResult(); SmallVector, 4> bias_users; + bool used_as_weight = false; for (auto &use : value->getUses()) { auto spec = GetQuantSpec(use.getOwner()); auto biases = spec->biases_params; Operation *user = use.getOwner(); int operand_num = use.getOperandNumber(); - // The user doesn't use this value as a bias operand nor require same - // scale. + // The user doesn't use this value as a bias operand or require same + // scale, then this constant is considered to be a weight. if (biases.find(operand_num) == biases.end() && !spec->requires_same_scale) { - weights_.insert(cst); + used_as_weight = true; } else { bias_users.push_back({user, operand_num}); } } - builder_.setInsertionPoint(cst); - for (int i = 1; i < bias_users.size(); ++i) { + + // If the constant is used as a weight, this constant will be duplicated for + // each bias user, so it isn't shared with the weight usage. Otherwise, the + // first bias user can use the original constant and the rest use the + // duplications, so we pop bias user from the set. 
+ if (used_as_weight) { + weights_.insert(cst); + } else { + bias_users.pop_back(); + builder_.setInsertionPoint(cst); + } + for (auto bias_user : bias_users) { auto copied = builder_.create(cst.getLoc(), cst.getValue()); - bias_users[i].first->setOperand(bias_users[i].second, copied.getResult()); + bias_user.first->setOperand(bias_user.second, copied.getResult()); } }); } diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc index 4919fbc74fe..c6355e123f1 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc @@ -32,7 +32,7 @@ static Type GetQuantizedType(Builder builder, Type input_type, double min, double max, int storage_type_width, bool narrow_range, bool is_signed) { auto converter = - quant::ExpressedToUniformQuantizedConverter::forInputType(input_type); + quant::ExpressedToQuantizedConverter::forInputType(input_type); quant::UniformQuantizedType quantizedEleType = quant::fakeQuantAttrsToType( builder.getUnknownLoc(), storage_type_width, min, max, narrow_range, diff --git a/tensorflow/compiler/mlir/lite/tests/extract-ophint.mlir b/tensorflow/compiler/mlir/lite/tests/extract-ophint.mlir index 5cbcb1e1cb8..a71e5cfae24 100644 --- a/tensorflow/compiler/mlir/lite/tests/extract-ophint.mlir +++ b/tensorflow/compiler/mlir/lite/tests/extract-ophint.mlir @@ -13,6 +13,11 @@ func @extractSimpleOphint() { return } +// CHECK: func @d4b1eb00b81211e99426dc4a3e957995(tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> +// CHECK: attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = "cool_activation"} + +// ----- + // CHECK-LABEL: extractPackedInputOphint func @extractPackedInputOphint() { // CHECK: %[[PACK:[0-9]*]] = "tfl.pack"(%0, %1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> tensor<2x1x16x1xf32> @@ -30,6 +35,11 @@ func @extractPackedInputOphint() { return } +// CHECK: func @47393154b9af11e99426dc4a3e957995(tensor<2x1x16x1xf32>) -> tensor<1x16x1xf32> +// CHECK: attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = "cool_activation_stack"} + +// ----- + // CHECK-LABEL: extractFirstInputOphint func @extractFirstInputOphint() { // CHECK: %[[OP_HINT_CALL:[0-9]*]] = call @b703f0f4b9ec11e99426dc4a3e957995(%0) : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> @@ -46,6 +56,11 @@ func @extractFirstInputOphint() { return } +// CHECK: func @b703f0f4b9ec11e99426dc4a3e957995(tensor<1x16x1xf32>) -> tensor<1x16x1xf32> +// CHECK: attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = "cool_activation_first"} + +// ----- + // CHECK-LABEL: extractLastInputOphint func @extractLastInputOphint() { // CHECK: %[[OP_HINT_CALL:[0-9]*]] = call @e31fcf90b9ed11e99426dc4a3e957995(%1) : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> @@ -62,6 +77,11 @@ func @extractLastInputOphint() { return } +// CHECK: func @e31fcf90b9ed11e99426dc4a3e957995(tensor<1x16x1xf32>) -> tensor<1x16x1xf32> +// CHECK: attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = "cool_activation_last"} + +// ----- + // CHECK-LABEL: extractPackOneInputOphint func @extractPackOneInputOphint() { // CHECK: %[[RESHAPE:[0-9]*]] = "tfl.reshape"(%0) : (tensor<1x16x1xf32>) -> tensor<1x1x16x1xf32> @@ -75,13 +95,16 @@ func @extractPackOneInputOphint() { return } +// CHECK: func @33fab028b9ef11e99426dc4a3e957995(tensor<1x1x16x1xf32>) -> tensor<1x16x1xf32> +// CHECK: 
attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = "cool_activation_pack_input_one"} + +// ----- + // CHECK-LABEL: extractStackInputOutputOphint func @extractStackInputOutputOphint() { // CHECK: %[[PACK:[0-9]*]] = "tfl.pack"(%0, %1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> tensor<2x1x16x1xf32> // CHECK: %[[OP_HINT_CALL:[0-9]*]] = call @b92ed354b9f011e99426dc4a3e957995(%[[PACK]]) : (tensor<2x1x16x1xf32>) -> tensor<2x1x16x1xf32> // CHECK: %[[UNPACK:[0-9]*]]:2 = "tfl.unpack"(%[[OP_HINT_CALL]]) {axis = 0 : i32, num = 2 : i32} : (tensor<2x1x16x1xf32>) -> (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -// CHECK: %[[OUTPUT1:[0-9]*]] = "tf.Identity"(%[[UNPACK]]#0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> -// CHECK: %[[OUTPUT2:[0-9]*]] = "tf.Identity"(%[[UNPACK]]#1) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 1 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-1-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> %0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32> %1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> @@ -98,11 +121,14 @@ func @extractStackInputOutputOphint() { return } +// CHECK: func @b92ed354b9f011e99426dc4a3e957995(tensor<2x1x16x1xf32>) -> tensor<2x1x16x1xf32> +// CHECK: attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = "cool_activation_stack_input_output"} + +// ----- + // CHECK-LABEL: extractMultipleInputsOutputsOphint func @extractMultipleInputsOutputsOphint() { -// CHECK: %[[OP_HINT_CALL:[0-9]*]]:2 = call @a6ca45beb9f411e99426dc4a3e957995(%0, %1) : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -// CHECK: %[[OUTPUT1:[0-9]*]] = "tf.Identity"(%[[OP_HINT_CALL]]#0) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation_multiple_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "a6ca45beb9f411e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_multiple_input_output-a6ca45beb9f411e99426dc4a3e957995-0-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> -// CHECK: %[[OUTPUT2:[0-9]*]] = "tf.Identity"(%[[OP_HINT_CALL]]#1) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation_multiple_input_output", _tflite_function_output_index = 1 : i64, _tflite_function_uuid = 
"a6ca45beb9f411e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_multiple_input_output-a6ca45beb9f411e99426dc4a3e957995-1-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> +// CHECK: %[[MULTI_INPUT_CALL:[0-9]*]]:2 = call @a6ca45beb9f411e99426dc4a3e957995(%0, %1) : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> (tensor<1x16x1xf32>, tensor<1x16x1xf32>) %0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32> %1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_multiple_input_output", _tflite_function_uuid = "a6ca45beb9f411e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_multiple_input_output-a6ca45beb9f411e99426dc4a3e957995-0-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> @@ -119,21 +145,33 @@ func @extractMultipleInputsOutputsOphint() { return } -// CHECK: func @d4b1eb00b81211e99426dc4a3e957995(tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> -// CHECK: attributes {_tflite_function_name = "cool_activation"} -// CHECK: func @47393154b9af11e99426dc4a3e957995(tensor<2x1x16x1xf32>) -> tensor<1x16x1xf32> -// CHECK: attributes {_tflite_function_name = "cool_activation_stack"} -// CHECK: func @b703f0f4b9ec11e99426dc4a3e957995(tensor<1x16x1xf32>) -> tensor<1x16x1xf32> -// CHECK: attributes {_tflite_function_name = "cool_activation_first"} -// CHECK: func @e31fcf90b9ed11e99426dc4a3e957995(tensor<1x16x1xf32>) -> tensor<1x16x1xf32> -// CHECK: attributes {_tflite_function_name = "cool_activation_last"} -// CHECK: func @33fab028b9ef11e99426dc4a3e957995(tensor<1x1x16x1xf32>) -> tensor<1x16x1xf32> -// CHECK: attributes {_tflite_function_name = "cool_activation_pack_input_one"} -// CHECK: func @b92ed354b9f011e99426dc4a3e957995(tensor<2x1x16x1xf32>) -> tensor<2x1x16x1xf32> -// CHECK: attributes {_tflite_function_name = "cool_activation_stack_input_output"} // CHECK: func @a6ca45beb9f411e99426dc4a3e957995(tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -// CHECK: attributes {_tflite_function_name = "cool_activation_multiple_input_output"} +// CHECK: attributes {_tflite_function_input_index = [0 : i32, 1 : i32], _tflite_function_name = "cool_activation_multiple_input_output"} +// ----- + +// CHECK-LABEL: inputsAfterOutputs +func @inputsAfterOutputs() { +// CHECK: %[[PLACE_HOLDER:[0-9]*]] = "tf.Placeholder"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "Placeholder_1", shape = "tfshape$dim { size: 2 } dim { size: 2 }"} : () -> tensor<2x2xf32> +// CHECK: %[[INPUT_PROCESS:[0-9]*]] = "tf.Sigmoid"(%[[PLACE_HOLDER]]) {T = "tfdtype$DT_FLOAT", device = "", name = "Sigmoid"} : (tensor<2x2xf32>) -> tensor<2x2xf32> +// CHECK: %[[OP_HINT_CALL:[0-9]*]]:2 = call @d6266124d2dd11e9b52cdc4a3e957995(%0, %1, %[[INPUT_PROCESS]]) : (tensor<2x2xf32>, tensor, tensor<2x2xf32>) -> (tensor<2x2xf32>, tensor<2x2xf32>) + + %0 = "tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "Const", value = dense<0.000000e+00> : tensor} : () -> tensor + %1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_input_index = 1 : i64, _tflite_function_name = "CustomOp", _tflite_function_uuid = "d6266124d2dd11e9b52cdc4a3e957995", _tflite_ophint_level = 1 : i64, device = "", name = "InputHint-CustomOp-d6266124d2dd11e9b52cdc4a3e957995-1-None-None"} : (tensor) -> tensor + %2 = "tf.Placeholder"() {device = "", dtype = 
"tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 2 } dim { size: 2 }"} : () -> tensor<2x2xf32> + %3 = "tf.Identity"(%2) {T = "tfdtype$DT_FLOAT", _tflite_function_input_index = 0 : i64, _tflite_function_name = "CustomOp", _tflite_function_uuid = "d6266124d2dd11e9b52cdc4a3e957995", _tflite_ophint_level = 1 : i64, device = "", name = "InputHint-CustomOp-d6266124d2dd11e9b52cdc4a3e957995-0-None-None"} : (tensor<2x2xf32>) -> tensor<2x2xf32> + %4 = "tf.Add"(%3, %1) {T = "tfdtype$DT_FLOAT", device = "", name = "Add"} : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> + %5 = "tf.Identity"(%4) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "CustomOp", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "d6266124d2dd11e9b52cdc4a3e957995", _tflite_ophint_level = 1 : i64, device = "", name = "OutputHint-CustomOp-d6266124d2dd11e9b52cdc4a3e957995-0-None-None"} : (tensor<2x2xf32>) -> tensor<2x2xf32> + %6 = "tf.Placeholder"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "Placeholder_1", shape = "tfshape$dim { size: 2 } dim { size: 2 }"} : () -> tensor<2x2xf32> + %7 = "tf.Sigmoid"(%6) {T = "tfdtype$DT_FLOAT", device = "", name = "Sigmoid"} : (tensor<2x2xf32>) -> tensor<2x2xf32> + %8 = "tf.Identity"(%7) {T = "tfdtype$DT_FLOAT", _tflite_function_input_index = 2 : i64, _tflite_function_name = "CustomOp", _tflite_function_uuid = "d6266124d2dd11e9b52cdc4a3e957995", _tflite_ophint_level = 1 : i64, device = "", name = "InputHint-CustomOp-d6266124d2dd11e9b52cdc4a3e957995-2-None-None"} : (tensor<2x2xf32>) -> tensor<2x2xf32> + %9 = "tf.Add"(%5, %8) {T = "tfdtype$DT_FLOAT", device = "", name = "Add_1"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + %10 = "tf.Identity"(%9) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "CustomOp", _tflite_function_output_index = 1 : i64, _tflite_function_uuid = "d6266124d2dd11e9b52cdc4a3e957995", _tflite_ophint_level = 1 : i64, device = "", name = "OutputHint-CustomOp-d6266124d2dd11e9b52cdc4a3e957995-1-None-None"} : (tensor<2x2xf32>) -> tensor<2x2xf32> + return +} + +// CHECK: func @d6266124d2dd11e9b52cdc4a3e957995(tensor<2x2xf32>, tensor, tensor<2x2xf32>) -> (tensor<2x2xf32>, tensor<2x2xf32>) +// CHECK: attributes {_tflite_function_input_index = [0 : i32, 1 : i32, 2 : i32], _tflite_function_name = "CustomOp"} // ----- diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index 5d265305796..45853817aec 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -50,7 +50,7 @@ func @biasAddInt(%arg0: tensor<1x10x10x32xi32>, %arg1: tensor<32xi32>) -> tensor func @squeezeAndReshape(%arg0: tensor<1x1x10xf32>, %arg1: tensor) -> i32 { %0 = "tf.Squeeze"(%arg0) {squeeze_dims = [0]} : (tensor<1x1x10xf32>) -> tensor<1x10xf32> %1 = "tf.Squeeze"(%arg1) : (tensor) -> tensor<*xf32> - %2 = constant dense<[2, 5]> : tensor<2xi32> + %2 = "tf.Const"() { value = dense<[2, 5]> : tensor<2xi32> } : () -> tensor<2xi32> %3 = "tf.Reshape" (%0, %2) : (tensor<1x10xf32>, tensor<2xi32>) -> tensor<2x5xf32> %4 = "some_op"(%1, %3) : (tensor<*xf32>, tensor<2x5xf32>) -> i32 return %4 : i32 @@ -119,8 +119,8 @@ func @fakeQuantArgsTrue(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> { } func @fakeQuantVarsFalse(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> { - %arg1 = constant dense<-0.1> : tensor - %arg2 = constant dense<0.2> : tensor + %arg1 = "tf.Const"() { value = dense<-0.1> : tensor } : () -> tensor + %arg2 = "tf.Const"() { 
value = dense<0.2> : tensor } : () -> tensor %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8x8x8x8xf32>, tensor, tensor) -> tensor<8x8x8x8xf32> return %0 : tensor<8x8x8x8xf32> @@ -153,6 +153,14 @@ func @placeholder(%arg0: tensor) -> tensor { // CHECK: %0 = "tfl.pseudo_input"(%arg0) : (tensor) -> tensor } +func @placeholder_int(%arg0: tensor) -> tensor { + %0 = "tf.Placeholder.input"(%arg0) {name = "Input"} : (tensor) -> tensor + return %0: tensor + +// CHECK-LABEL: @placeholder_int +// CHECK-NEXT: "tfl.pseudo_input"(%arg0) : (tensor) -> tensor +} + func @placeholder_min(%arg0: tensor) -> tensor { %0 = "tf.Placeholder.input"(%arg0) {name = "Input", min = -0.1 : f32} : (tensor) -> tensor return %0: tensor @@ -409,7 +417,7 @@ func @gatherNdHigherRankIndices(%arg0 : tensor<4x3x2xf32>, %arg1 : tensor<2x2xi3 } func @gatherV2VectorIndices(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x3x5x20xf32> { - %0 = constant dense<[1]> : tensor<1xi32> + %0 = "tf.Const"() { value = dense<[1]> : tensor<1xi32> } : () -> tensor<1xi32> %1 = "tf.GatherV2"(%arg0, %arg1, %0) : (tensor<1x2x20xf32>, tensor<3x5xi32>, tensor<1xi32>) -> tensor<1x3x5x20xf32> return %1 : tensor<1x3x5x20xf32> @@ -418,7 +426,7 @@ func @gatherV2VectorIndices(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) } func @gatherV2VectorIndicesNegAxis(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x2x3x5xf32> { - %0 = constant dense<[-1]> : tensor<1xi32> + %0 = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32> %1 = "tf.GatherV2"(%arg0, %arg1, %0) : (tensor<1x2x20xf32>, tensor<3x5xi32>, tensor<1xi32>) -> tensor<1x2x3x5xf32> return %1 : tensor<1x2x3x5xf32> @@ -427,7 +435,7 @@ func @gatherV2VectorIndicesNegAxis(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x } func @gatherV2NonZeroBatchDims(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x2x3x5xf32> { - %0 = constant dense<[1]> : tensor<1xi32> + %0 = "tf.Const"() { value = dense<[1]> : tensor<1xi32> } : () -> tensor<1xi32> %1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = 1 : i64} : (tensor<1x2x20xf32>, tensor<3x5xi32>, tensor<1xi32>) -> tensor<1x2x3x5xf32> return %1 : tensor<1x2x3x5xf32> @@ -509,6 +517,15 @@ func @select(%arg0: tensor<8xi1>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>) -> // CHECK: return %0 : tensor<8xf32> } +func @select_multidim(%arg0: tensor<8xi1>, %arg1: tensor<8x3xf32>, %arg2: tensor<8x3xf32>) -> tensor<8x3xf32> { + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<8xi1>, tensor<8x3xf32>, tensor<8x3xf32>) -> tensor<8x3xf32> + return %0: tensor<8x3xf32> + +// CHECK-LABEL: select_multidim +// CHECK: %0 = "tfl.select"(%arg0, %arg1, %arg2) +// CHECK: return %0 : tensor<8x3xf32> +} + func @select_v2(%arg0: tensor<8xi1>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>) -> tensor<8xf32> { %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<8xi1>, tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32> return %0: tensor<8xf32> @@ -518,6 +535,15 @@ func @select_v2(%arg0: tensor<8xi1>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>) // CHECK: return %0 : tensor<8xf32> } +func @select_v2_multidim(%arg0: tensor<8xi1>, %arg1: tensor<8x3xf32>, %arg2: tensor<8x3xf32>) -> tensor<8x3xf32> { + %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<8xi1>, tensor<8x3xf32>, tensor<8x3xf32>) -> tensor<8x3xf32> + return %0: tensor<8x3xf32> + +// CHECK-LABEL: select_v2_multidim +// CHECK: %0 = "tfl.select"(%arg0, %arg1, %arg2) +// CHECK: return %0 : tensor<8x3xf32> +} + func @sin(%arg0: tensor) -> 
tensor { %0 = "tf.Sin"(%arg0) : (tensor) -> tensor return %0 : tensor @@ -536,7 +562,7 @@ func @topk(%arg0: tensor<8xf32>, %arg1: tensor) -> (tensor, tensor) -> (tensor<2xf32>, tensor<2xi32>) { - %0 = constant dense<2> : tensor + %0 = "tf.Const"() { value = dense<2> : tensor } : () -> tensor %1:2 = "tf.TopKV2"(%arg0, %0) : (tensor<8xf32>, tensor) -> (tensor<2xf32>, tensor<2xi32>) return %1#0, %1#1: tensor<2xf32>, tensor<2xi32> @@ -546,7 +572,7 @@ func @topk_2(%arg0: tensor<8xf32>) -> (tensor<2xf32>, tensor<2xi32>) { } func @topk_3(%arg0: tensor) -> (tensor, tensor) { - %0 = constant dense<2> : tensor + %0 = "tf.Const"() { value = dense<2> : tensor } : () -> tensor %1:2 = "tf.TopKV2"(%arg0, %0) : (tensor, tensor) -> (tensor, tensor) return %1#0, %1#1: tensor, tensor @@ -556,7 +582,7 @@ func @topk_3(%arg0: tensor) -> (tensor, tensor) { } func @topk_4(%arg0: tensor<1x2x3x4xf32>) -> (tensor<1x2x3x2xf32>, tensor<1x2x3x2xi32>) { - %0 = constant dense<2> : tensor + %0 = "tf.Const"() { value = dense<2> : tensor } : () -> tensor %1:2 = "tf.TopKV2"(%arg0, %0) : (tensor<1x2x3x4xf32>, tensor) -> (tensor<1x2x3x2xf32>, tensor<1x2x3x2xi32>) return %1#0, %1#1: tensor<1x2x3x2xf32>, tensor<1x2x3x2xi32> @@ -566,7 +592,7 @@ func @topk_4(%arg0: tensor<1x2x3x4xf32>) -> (tensor<1x2x3x2xf32>, tensor<1x2x3x2 } func @topk_5(%arg0: tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi32>) { - %0 = constant dense<2> : tensor + %0 = "tf.Const"() { value = dense<2> : tensor } : () -> tensor %1:2 = "tf.TopKV2"(%arg0, %0) : (tensor<*xf32>, tensor) -> (tensor<*xf32>, tensor<*xi32>) return %1#0, %1#1: tensor<*xf32>, tensor<*xi32> @@ -671,7 +697,7 @@ func @pow(%arg0: tensor<2x1x3xf32>, %arg1: tensor<2x1x1xf32>) -> tensor<2x1x3xf3 func @tile(tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x6xf32> { ^bb0(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>): - %cst = constant dense<[1, 2]> : tensor<2xi32> + %cst = "tf.Const"() { value = dense<[1, 2]> : tensor<2xi32> } : () -> tensor<2xi32> %0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x6xf32> return %0 : tensor<2x6xf32> @@ -682,7 +708,7 @@ func @tile(tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x6xf32> { func @padv2(tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor { ^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<3x2xi32>): - %cst = constant dense<2.0> : tensor + %cst = "tf.Const"() { value = dense<2.0> : tensor } : () -> tensor %0 = "tf.PadV2"(%arg0, %arg1, %cst) : (tensor<2x1x3xf32>, tensor<3x2xi32>, tensor) -> tensor return %0#0 : tensor @@ -858,8 +884,8 @@ func @matmul_transposed(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> t } func @concat2Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> { - %0 = constant dense<[1]> : tensor<1xi32> - %1 = "tf.Concat"(%0, %arg0, %arg1) {N = 2 : i64} : (tensor<1xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32> + %0 = "tf.Const"() { value = dense<1> : tensor } : () -> tensor + %1 = "tf.Concat"(%0, %arg0, %arg1) {N = 2 : i64} : (tensor, tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32> return %1 : tensor<2x2xi32> // CHECK-LABEL: concat2Tensors @@ -867,8 +893,8 @@ func @concat2Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi } func @concat3Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2x3xi32> { - %0 = constant dense<[-1]> : tensor<1xi32> - %1 = "tf.Concat"(%0, %arg0, %arg1, %arg2) {N = 3 : i64} : (tensor<1xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32> + %0 = "tf.Const"() { value = dense<-1> : tensor } : () -> tensor + %1 = 
"tf.Concat"(%0, %arg0, %arg1, %arg2) {N = 3 : i64} : (tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32> return %1 : tensor<2x3xi32> // CHECK-LABEL: concat3Tensors @@ -876,8 +902,8 @@ func @concat3Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2: tensor<2 } func @concatv2With3Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2x3xi32> { - %0 = constant dense<[-1]> : tensor<1xi32> - %1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) {N = 3 : i64} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<1xi32>) -> tensor<2x3xi32> + %0 = "tf.Const"() { value = dense<-1> : tensor } : () -> tensor + %1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) {N = 3 : i64} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor) -> tensor<2x3xi32> return %1 : tensor<2x3xi32> // CHECK-LABEL: concatv2With3Tensors @@ -1093,3 +1119,35 @@ func @depth_to_space(%arg0: tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32> { // CHECK: %[[ARG:.*]]: tensor<1x1x1x4xf32> // CHECK: "tfl.depth_to_space"(%[[ARG]]) {block_size = 2 : i32} : (tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32> } + +func @non_max_suppression_v4(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> tensor<2xi32> { + %0:2 = "tf.NonMaxSuppressionV4"(%arg0, %arg1, %arg2, %arg3, %arg4) {pad_to_max_output_size = true}: (tensor<3x4xf32>, tensor<3xf32>, tensor, tensor, tensor) -> (tensor<2xi32>, tensor) + return %0#0 : tensor<2xi32> + + // CHECK-LABEL: non_max_suppression_v4 + // CHECK: %0:2 = "tfl.non_max_suppression_v4"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<3x4xf32>, tensor<3xf32>, tensor, tensor, tensor) -> (tensor<2xi32>, tensor) +} + +func @non_max_suppression_v4_no_pad(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> tensor<2xi32> { + %0:2 = "tf.NonMaxSuppressionV4"(%arg0, %arg1, %arg2, %arg3, %arg4) {pad_to_max_output_size = false}: (tensor<3x4xf32>, tensor<3xf32>, tensor, tensor, tensor) -> (tensor<2xi32>, tensor) + return %0#0 : tensor<2xi32> + + // CHECK-LABEL: non_max_suppression_v4_no_pad + // CHECK: %0:2 = "tfl.non_max_suppression_v4"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<3x4xf32>, tensor<3xf32>, tensor, tensor, tensor) -> (tensor<2xi32>, tensor) +} + +func @non_max_suppression_v5(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor) -> tensor<2xi32> { + %0:3 = "tf.NonMaxSuppressionV5"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {pad_to_max_output_size = true}: (tensor<3x4xf32>, tensor<3xf32>, tensor, tensor, tensor, tensor) -> (tensor<2xi32>, tensor<2xf32>, tensor) + return %0#0 : tensor<2xi32> + + // CHECK-LABEL: non_max_suppression_v5 + // CHECK: %0:3 = "tfl.non_max_suppression_v5"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (tensor<3x4xf32>, tensor<3xf32>, tensor, tensor, tensor, tensor) -> (tensor<2xi32>, tensor<2xf32>, tensor) +} + +func @non_max_suppression_v5_no_pad(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor) -> tensor<2xi32> { + %0:3 = "tf.NonMaxSuppressionV5"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {pad_to_max_output_size = false}: (tensor<3x4xf32>, tensor<3xf32>, tensor, tensor, tensor, tensor) -> (tensor<2xi32>, tensor<2xf32>, tensor) + return %0#0 : tensor<2xi32> + + // CHECK-LABEL: non_max_suppression_v5_no_pad + // CHECK: %0:3 = "tfl.non_max_suppression_v5"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (tensor<3x4xf32>, tensor<3xf32>, tensor, tensor, tensor, tensor) -> (tensor<2xi32>, 
tensor<2xf32>, tensor) +} diff --git a/tensorflow/compiler/mlir/lite/tests/load-quantization-recipe.mlir b/tensorflow/compiler/mlir/lite/tests/load-quantization-recipe.mlir new file mode 100644 index 00000000000..5c53d5e05e7 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/load-quantization-recipe.mlir @@ -0,0 +1,107 @@ +// RUN: tf-opt -tfl-load-recipe %s | FileCheck %s --dump-input-on-failure + +// CHECK-LABEL: testLstm +func @testLstm(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor, %arg16: tensor, %arg17: tensor, %arg18: tensor, %arg19: tensor, %arg20: tensor, %arg21: tensor, %arg22: tensor, %arg23: tensor) -> tensor { + %0 = "tfl.lstm"(%arg0, // input + %arg1, %arg2, %arg3, %arg4, // weights + %arg5, %arg6, %arg7, %arg8, // recurrent weights + %arg9, %arg10, %arg11, // cell weights + %arg12, %arg13, %arg14, %arg15, // bias + %arg16, %arg17, // projection weight and bias + %arg18, %arg19, // stateful + %arg20, %arg21, %arg22, %arg23 // layer norm coefficients + ) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor + +// CHECK-NEXT: "tfl.lstm" +// CHECK-NEXT: %[[cst:.*]] = constant unit + +// input gate +// CHECK-NEXT: %[[in1:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]]) +// CHECK-SAME: -> tensor> +// CHECK-NEXT: %[[in2:.*]] = "tfl.fully_connected"(%arg18, %arg5, %[[cst]]) +// CHECK-SAME: -> tensor> +// CHECK-NEXT: %[[in3:.*]] = "tfl.mul"(%arg19, %arg9) +// CHECK-SAME: -> tensor> +// CHECK-NEXT: %[[in4:.*]] = "tfl.add_n"(%[[in1]], %[[in2]], %[[in3]]) +// CHECK-SAME: -> tensor> +// CHECK-NEXT: %[[in5:.*]] = "tfl.l2_normalization"(%[[in4]]) +// CHECK-SAME: -> tensor> +// CHECK-NEXT: %[[in6:.*]] = tfl.add %[[in4]], %[[in5]] +// CHECK-SAME: tensor> +// CHECK-NEXT: %[[in7:.*]] = "tfl.fully_connected"(%[[in6]], %arg20, %arg12) +// CHECK-SAME: -> tensor> +// CHECK-NEXT: %[[in8:.*]] = "tfl.logistic"(%[[in7]]) +// CHECK-SAME: -> tensor> + +// forget gate +// CHECK-NEXT: %[[fo1:.*]] = "tfl.fully_connected"(%arg0, %arg2, %[[cst]]) +// CHECK-SAME: tensor> +// CHECK-NEXT: %[[fo2:.*]] = "tfl.fully_connected"(%arg18, %arg6, %[[cst]]) +// CHECK-SAME: -> tensor> +// CHECK-NEXT: %[[fo3:.*]] = "tfl.mul"(%arg19, %arg10) +// CHECK-SAME: -> tensor> +// CHECK-NEXT: %[[fo4:.*]] = "tfl.add_n"(%[[fo1]], %[[fo2]], %[[fo3]]) +// CHECK-SAME: -> tensor> +// CHECK-NEXT: %[[fo5:.*]] = "tfl.l2_normalization"(%[[fo4]]) +// CHECK-SAME: -> tensor> +// CHECK-NEXT: %[[fo6:.*]] = tfl.add %[[fo4]], %[[fo5]] +// CHECK-SAME: tensor> +// CHECK-NEXT: %[[fo7:.*]] = "tfl.fully_connected"(%[[fo6]], %arg21, %arg13) +// CHECK-SAME: -> tensor> +// CHECK-NEXT: %[[fo8:.*]] = "tfl.logistic"(%[[fo7]]) +// CHECK-SAME: -> tensor> + +// cell gate +// CHECK-NEXT: %[[ce1:.*]] = "tfl.fully_connected"(%arg0, %arg3, %[[cst]]) +// CHECK-SAME: -> tensor> +// CHECK-NEXT: %[[ce2:.*]] = "tfl.fully_connected"(%arg18, %arg7, %[[cst]]) +// CHECK-SAME: -> tensor> +// CHECK-NEXT: %[[ce3:.*]] = "tfl.add_n"(%[[ce1]], %[[ce2]]) +// CHECK-SAME: -> tensor> +// CHECK-NEXT: %[[ce4:.*]] = "tfl.l2_normalization"(%[[ce3]]) +// CHECK-SAME: -> tensor> +// CHECK-NEXT: %[[ce5:.*]] = tfl.add %[[ce3]], %[[ce4]] +// CHECK-SAME: tensor> +// CHECK-NEXT: 
%[[ce6:.*]] = "tfl.fully_connected"(%[[ce5]], %arg22, %arg14) +// CHECK-SAME: -> tensor> +// CHECK-NEXT: %[[ce7:.*]] = "tfl.tanh"(%[[ce6]]) +// CHECK-SAME: -> tensor> + +// CHECK-NEXT: %[[ac1:.*]] = "tfl.mul"(%[[fo8]], %arg19) +// CHECK-SAME: -> tensor> +// CHECK-NEXT: %[[ac2:.*]] = tfl.mul %[[in8]], %[[ce7]] +// CHECK-SAME: tensor> +// CHECK-NEXT: %[[ac3:.*]] = tfl.add %[[ac1]], %[[ac2]] +// CHECK-SAME: tensor> + +// output gate +// CHECK-NEXT: %[[ou1:.*]] = "tfl.fully_connected"(%arg0, %arg4, %[[cst]]) +// CHECK-SAME: -> tensor> +// CHECK-NEXT: %[[ou2:.*]] = "tfl.fully_connected"(%arg18, %arg8, %[[cst]]) +// CHECK-SAME: -> tensor> +// CHECK-NEXT: %[[ou3:.*]] = "tfl.mul"(%[[ac3]], %arg11) +// CHECK-SAME: -> tensor> +// CHECK-NEXT: %[[ou4:.*]] = "tfl.add_n"(%[[ou1]], %[[ou2]], %[[ou3]]) +// CHECK-SAME: -> tensor> +// CHECK-NEXT: %[[ou5:.*]] = "tfl.l2_normalization"(%[[ou4]]) +// CHECK-SAME: -> tensor> +// CHECK-NEXT: %[[ou6:.*]] = tfl.add %[[ou4]], %[[ou5]] +// CHECK-SAME: tensor> +// CHECK-NEXT: %[[ou7:.*]] = "tfl.fully_connected"(%[[ou6]], %arg23, %arg15) +// CHECK-SAME: -> tensor> +// CHECK-NEXT: %[[ou8:.*]] = "tfl.logistic"(%[[ou7]]) +// CHECK-SAME: -> tensor> + +// output activation +// CHECK-NEXT: %[[ac4:.*]] = "tfl.tanh"(%[[ac3]]) +// CHECK-SAME: -> tensor> +// CHECK-NEXT: %[[ac5:.*]] = tfl.mul %[[ac4]], %[[ou8]] +// CHECK-SAME: tensor> +// CHECK-NEXT: %[[ac6:.*]] = "tfl.fully_connected"(%[[ac5]], %arg16, %arg17) +// CHECK-SAME: (tensor>, tensor, tensor) -> tensor> +// CHECK-NEXT: %[[ac7:.*]] = "tf_quant.pseudo_return"(%[[ac6]]) : (tensor>) -> tensor> +// CHECK-NEXT: }) +// CHECK-NEXT: return + + return %0 : tensor +} diff --git a/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir b/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir index 817ced79ced..287958e905c 100644 --- a/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir +++ b/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir @@ -143,6 +143,19 @@ func @tensorlistPushBack(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: t // CHECK: return [[RESULT]] : tensor } +func @tensorlistLength(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>) -> (tensor) { + %0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor>> + %1 = "tf.TensorListLength"(%0) : (tensor>>) -> tensor + return %1: tensor + +// CHECK-LABEL: tensorlistLength +// CHECK-SAME: ([[INPUT:%.*]]: tensor<3x10xf32>, [[ELEM_SHAPE:%.*]]: tensor<1xi32>) +// CHECK-DAG: [[SHAPE:%.*]] = "tf.Shape"([[INPUT]]) {{.*}} -> tensor<2xi32> +// CHECK-DAG: [[ZERO:%cst.*]] = constant dense<0> : tensor +// CHECK: [[RESULT:%.*]] = "tf.Gather"([[SHAPE]], [[ZERO]]) {validate_indices = true} : (tensor<2xi32>, tensor) -> tensor +// CHECK: return [[RESULT]] : tensor +} + func @tensorlistWhileLoop(%arg0: tensor<2x3xf32>) -> tensor<*xf32> { %cst = constant dense<3> : tensor<1xi32> %cst_0 = constant dense<0> : tensor diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir index ddb122f6e37..23976dbb476 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir @@ -278,6 +278,6 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, t %21 = "tfl.pseudo_input" (%arg21) : (tensor<4 x f32>) -> tensor<4 x f32> %22 = "tfl.pseudo_input" (%arg22) : (tensor<4 x f32>) -> tensor<4 x f32> %23 = "tfl.pseudo_input" (%arg23) : (tensor<4 
x f32>) -> tensor<4 x f32> - %24 = "tfl.lstm"(%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %24 = "tfl.lstm"(%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %24 : tensor<4xf32> } \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/metadata.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/metadata.mlir new file mode 100644 index 00000000000..e89c2715c50 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/metadata.mlir @@ -0,0 +1,31 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s + +module attributes { + tfl.metadata = {key1 = "value1", key2 = "value2"} +} { + func @main(tensor<3x2xi32>) -> tensor<3x2xi32> + attributes {tf.entry_function = {inputs = "input", outputs = "SameNameAsOutput"}} { + ^bb0(%arg0: tensor<3x2xi32>): + %0 = "tfl.pseudo_input" (%arg0) : (tensor<3x2xi32>) -> tensor<3x2xi32> loc("Input") + %1 = "tfl.pseudo_const" () {value = dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> + %2 = "tfl.sub" (%0, %1) {fused_activation_function = "NONE"} : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> + return %2 : tensor<3x2xi32> + } +} + +// CHECK: buffers: [ { +// CHECK: }, { +// CHECK: }, { +// CHECK: }, { +// CHECK: }, { +// CHECK-NEXT: data: [ 118, 97, 108, 117, 101, 49 ] +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 118, 97, 108, 117, 101, 50 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: metadata: [ { +// CHECK-NEXT: name: "key1", +// CHECK-NEXT: buffer: 4 +// CHECK-NEXT: }, { +// CHECK-NEXT: name: "key2", +// CHECK-NEXT: buffer: 5 +// CHECK-NEXT: } ] diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/optional.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/optional.mlir index 97129df86a2..d62d0ac2c31 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/optional.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/optional.mlir @@ -1,5 +1,4 @@ -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - -// | FileCheck %s +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s func @main(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> { %cst = constant unit @@ -9,7 +8,7 @@ func @main(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf return %2 : tensor<40x40xf32> } -// CHECK-NEXT: operators: [ { +// CHECK: operators: [ { // CHECK-NEXT: inputs: 
[ 0, 1, -1 ], // CHECK-NEXT: outputs: [ 2, 3 ], // CHECK-NEXT: builtin_options_type: FullyConnectedOptions, diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir index a0d78c25297..3a051678664 100644 --- a/tensorflow/compiler/mlir/lite/tests/ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir @@ -103,7 +103,7 @@ func @testAddN(tensor, tensor, tensor) -> tensor, tensor, tensor) -> tensor { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor): - // expected-error @+1 {{'tfl.add_n' op operand #0 must be tensor of 32-bit float or 32-bit integer values}} + // expected-error @+1 {{'tfl.add_n' op operand #0 must be tensor of 32-bit float or 32-bit integer or QI16 type or QUI16 type values}} %0 = "tfl.add_n"(%arg0, %arg1, %arg2): (tensor, tensor, tensor) -> tensor return %0 : tensor } @@ -537,7 +537,7 @@ func @testLogistic(tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16> { // test invalid Logistic input func @testLogisticWithWrongInputType(tensor) -> tensor { ^bb0(%arg0: tensor): - // expected-error @+1 {{tfl.logistic' op operand #0 must be tensor of floating-point or QI8 type or QUI8 type values}} + // expected-error @+1 {{tfl.logistic' op operand #0 must be tensor of floating-point or QI8 type or QUI8 type or QI16 type or QUI16 type values}} %0 = "tfl.logistic"(%arg0): (tensor) -> tensor return %0#0 : tensor } @@ -591,8 +591,9 @@ func @testUnidirectionalSequenceLstmWithInvalidNoneType(%arg0: tensor, // CHECK-LABEL: testLstm func @testLstm(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor, %arg16: tensor, %arg17: tensor, %arg18: tensor, %arg19: tensor, %arg20: tensor, %arg21: tensor, %arg22: tensor, %arg23: tensor) -> tensor { - // CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor - %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor + // CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) + // CHECK-NEXT: {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor + %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {fused_activation_function = 
"NONE", kernel_type = "FULL"} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor return %0 : tensor } @@ -600,8 +601,9 @@ func @testLstm(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg1: none, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor, %arg16: tensor, %arg17: tensor, %arg18: tensor, %arg19: tensor, %arg20: tensor, %arg21: tensor, %arg22: tensor, %arg23: tensor) -> tensor { - // CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor, none, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor - %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor, none, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor + // CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) + // CHECK-NEXT: {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor, none, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor + %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor, none, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor return %0 : tensor } @@ -610,7 +612,7 @@ func @testLstmWithNoneTypeAndOverrideAttr(%arg0: tensor, %arg1: none, % // test invalid none type applied to a tensor type arg func @testLstmWithInvalidNoneType(%arg0: tensor, %arg1: tensor, %arg2: none, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor, %arg16: tensor, %arg17: tensor, %arg18: tensor, %arg19: tensor, %arg20: tensor, %arg21: tensor, %arg22: tensor, %arg23: tensor) -> tensor { // expected-error @+1 {{'tfl.lstm' op operand #2 must be tensor of 32-bit float or 8-bit integer values}} - %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, 
%arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE"} : (tensor, tensor, none, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor + %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {fused_activation_function = "NONE"} : (tensor, tensor, none, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor return %0 : tensor } @@ -619,7 +621,7 @@ func @testLstmWithInvalidNoneType(%arg0: tensor, %arg1: tensor // test violation of projection weight and projection bias pred op trait func @testLstmWithInvalidNoneType(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor, %arg16: none, %arg17: tensor, %arg18: tensor, %arg19: tensor, %arg20: tensor, %arg21: tensor, %arg22: tensor, %arg23: tensor) -> tensor { // expected-error @+1 {{'tfl.lstm' op failed to verify that either projection weight must be specified or both projection weight and projection bias must not be specified}} - %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE"} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, none, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor + %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {fused_activation_function = "NONE"} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, none, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor return %0 : tensor } @@ -628,7 +630,7 @@ func @testLstmWithInvalidNoneType(%arg0: tensor, %arg1: tensor // test invalid kernel type func @testLstmWithInvalidKernelType(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor, %arg16: tensor, %arg17: tensor, %arg18: tensor, %arg19: tensor, %arg20: tensor, %arg21: tensor, %arg22: tensor, %arg23: tensor) -> tensor { // expected-error @+1 {{'tfl.lstm' op attribute 'kernel_type' failed to satisfy constraint: lstm kernel type enum case FULL}} - %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "BASIC"} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, 
tensor) -> tensor + %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "BASIC"} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor return %0 : tensor } @@ -652,6 +654,15 @@ func @testSelect(%cond : tensor, %arg0 : tensor, %arg1 : tensor, %arg0 : tensor, %arg1 : tensor) -> tensor { + %0 = "tfl.select"(%cond, %arg0, %arg1): (tensor,tensor,tensor) -> tensor + return %0 : tensor +} + +// ----- + func @testSelectWithUnsupportedType(%cond : tensor, %arg0 : tensor, %arg1 : tensor) -> tensor { // expected-error @+1 {{op operand #0 must be tensor of 1-bit integer values}} %0 = "tfl.select"(%cond, %arg0, %arg1): (tensor,tensor,tensor) -> tensor @@ -660,6 +671,14 @@ func @testSelectWithUnsupportedType(%cond : tensor, %arg0 : tensor // ----- +func @testSelectWithUnsupportedShapes(%cond : tensor<2xi1>, %arg0 : tensor<3xi32>, %arg1 : tensor<3xi32>) -> tensor<3xi32> { + // expected-error @+1 {{failed to verify that Select operands meet shape criteria}} + %0 = "tfl.select"(%cond, %arg0, %arg1): (tensor<2xi1>,tensor<3xi32>,tensor<3xi32>) -> tensor<3xi32> + return %0 : tensor<3xi32> +} + +// ----- + func @testSelectWithUnsupportedType(%cond : tensor, %arg0 : tensor, %arg1 : tensor) -> tensor { // expected-error @+1 {{failed to verify that operands have same element type}} %0 = "tfl.select"(%cond, %arg0, %arg1): (tensor,tensor,tensor) -> tensor @@ -762,6 +781,21 @@ func @testPadWithInvalidPaddingsRank(tensor<2x1x3xf32>, tensor<1x3x2xi32>) -> te // ----- +// CHECK-LABEL: testPadQuantizedU8 +func @testPadQuantizedU8(%arg0: tensor<2x1x3x!quant.uniform>, %arg1: tensor<3x2xi32>) -> tensor> { + // CHECK: "tfl.pad"(%arg0, %arg1) + %0 = "tfl.pad"(%arg0, %arg1) : (tensor<2x1x3x!quant.uniform>, tensor<3x2xi32>) -> tensor> + return %0#0 : tensor> +} + +// CHECK-LABEL: testPadQuantizedI8 +func @testPadQuantizedI8(%arg0: tensor<2x1x3x!quant.uniform>, %arg1: tensor<3x2xi32>) -> tensor> { + // CHECK: "tfl.pad"(%arg0, %arg1) + %0 = "tfl.pad"(%arg0, %arg1) : (tensor<2x1x3x!quant.uniform>, tensor<3x2xi32>) -> tensor> + return %0#0 : tensor> +} +// ----- + // CHECK-LABEL: testPadV2 func @testPadV2(tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor { ^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<3x2xi32>): @@ -817,6 +851,20 @@ func @testPadV2WithInvalidConstantScalar(tensor<2x1x3xf32>, tensor<3x2xi32>) -> // ----- +func @packQuantizedU8(%arg0: tensor<2x!quant.uniform>, %arg1: tensor<2x!quant.uniform>) -> tensor<2x2x!quant.uniform> { + // CHECK: "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} + %0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<2x!quant.uniform>, tensor<2x!quant.uniform>) -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> +} + +func @packQuantizedI8(%arg0: tensor<2x!quant.uniform>, %arg1: tensor<2x!quant.uniform>) -> tensor<2x2x!quant.uniform> { + // CHECK: "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} + %0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<2x!quant.uniform>, tensor<2x!quant.uniform>) -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> +} + +// ----- + func @pack(%arg0: tensor<2xi32>, 
%arg1: tensor<2xi32>) -> tensor<2x2xi32> { // CHECK: "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} %0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32> @@ -1101,6 +1149,63 @@ func @transpose_perm_not_i32(%arg0 : tensor<2x2xi32>, %arg1 : tensor<2xf32>) -> } +// ----- + +func @transpose_perm_size(%arg0 : tensor<2x2xi32>, %arg1 : tensor<3xi32>) -> tensor<2x2xi32> { + // expected-error @+1 {{perm tensor elements size is not equal to input tensor rank}} + %0 = "tfl.transpose"(%arg0, %arg1) : (tensor<2x2xi32>, tensor<3xi32>) -> tensor<2x2xi32> + return %0 : tensor<2x2xi32> +} + + +// ----- + +func @transpose_unranked_shape(%arg0 : tensor<*xi32>) -> tensor<2x2xi32> { + %cst = constant dense<[1, 0]> : tensor<2xi32> + %0 = "tfl.transpose"(%arg0, %cst) : (tensor<*xi32>, tensor<2xi32>) -> tensor<2x2xi32> + return %0 : tensor<2x2xi32> +} + + +// ----- + +func @transpose_dynamic_shape(%arg0 : tensor<2x?xi32>) -> tensor { + %cst = constant dense<[1, 0]> : tensor<2xi32> + %0 = "tfl.transpose"(%arg0, %cst) : (tensor<2x?xi32>, tensor<2xi32>) -> tensor + return %0 : tensor +} + + +// ----- + +func @transpose_perm_axis_invalid(%arg0 : tensor<2x2xi32>) -> tensor<2x2xi32> { + %cst = constant dense<[1, -1]> : tensor<2xi32> + // expected-error @+1 {{perm[1] must be in [0, rank)}} + %0 = "tfl.transpose"(%arg0, %cst) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32> + return %0 : tensor<2x2xi32> +} + + +// ----- + +func @transpose_perm_axis_duplicated(%arg0 : tensor<2x2xi32>) -> tensor<2x2xi32> { + %cst = constant dense<[1, 1]> : tensor<2xi32> + // expected-error @+1 {{perm[1] cannot have duplicated axis}} + %0 = "tfl.transpose"(%arg0, %cst) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32> + return %0 : tensor<2x2xi32> +} + + +// ----- + +func @transpose_output_type_bad(%arg0 : tensor<3x4x5x6xi32>) -> tensor<3x4x5x6xi32> { + %cst = constant dense<[0, 3, 1, 2]> : tensor<4xi32> + // expected-error @+1 {{expect output type tensor<3x6x4x5xi32>, got tensor<3x4x5x6xi32>}} + %0 = "tfl.transpose"(%arg0, %cst) : (tensor<3x4x5x6xi32>, tensor<4xi32>) -> tensor<3x4x5x6xi32> + return %0 : tensor<3x4x5x6xi32> +} + + // ----- func @transpose_element_type(%arg0 : tensor<2x2xf32>, %arg1 : tensor<2xi32>) -> tensor<2x2xi32> { @@ -1643,3 +1748,33 @@ func @testSplitVOpWithValidSizeSplitsNegative(%arg0 : tensor<16x4xf32>) -> (tens return %0, %1, %2, %3, %4 : tensor<7x4xf32>, tensor<3x4xf32>, tensor<6x4xf32>, tensor<16x0xf32>, tensor<16x4xf32> } + +// ----- + +func @testNonMaxSuppressionV4WithCorrectBoxShape(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> (tensor<2xi32>, tensor) { + %0, %1 = "tfl.non_max_suppression_v4"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<3x4xf32>, tensor<3xf32>, tensor, tensor, tensor) -> (tensor<2xi32>, tensor) + return %0, %1 : tensor<2xi32>, tensor +} + +// ----- + +func @testNonMaxSuppressionV4WithWrongBoxShape(%arg0: tensor<3x2xf32>, %arg1: tensor<3xf32>, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> (tensor<2xi32>, tensor) { + // expected-error @+1 {{'tfl.non_max_suppression_v4' op failed to verify that boxes should have dim[1] == 4}} + %0, %1 = "tfl.non_max_suppression_v4"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<3x2xf32>, tensor<3xf32>, tensor, tensor, tensor) -> (tensor<2xi32>, tensor) + return %0, %1 : tensor<2xi32>, tensor +} + +// ----- + +func @testNonMaxSuppressionV5WithCorrectBoxShape(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor, 
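
A quick sanity check (not part of the patch) of the transpose verifier rule exercised above, out_shape[i] = in_shape[perm[i]], which is why perm [0, 3, 1, 2] applied to tensor<3x4x5x6xi32> must yield tensor<3x6x4x5xi32>:

    # NumPy illustration of the expected-output-type rule in the transpose tests.
    import numpy as np

    x = np.zeros((3, 4, 5, 6), dtype=np.int32)
    perm = [0, 3, 1, 2]
    print(np.transpose(x, perm).shape)  # (3, 6, 4, 5), i.e. tensor<3x6x4x5xi32>
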
%arg3: tensor, %arg4: tensor, %arg5: tensor) -> (tensor<2xi32>, tensor<2xf32>, tensor) { + %0, %1, %2 = "tfl.non_max_suppression_v5"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (tensor<3x4xf32>, tensor<3xf32>, tensor, tensor, tensor, tensor) -> (tensor<2xi32>, tensor<2xf32>, tensor) + return %0, %1, %2 : tensor<2xi32>, tensor<2xf32>, tensor +} + +// ----- + +func @testNonMaxSuppressionV5WithWrongBoxShape(%arg0: tensor<3x2xf32>, %arg1: tensor<3xf32>, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor) -> (tensor<2xi32>, tensor<2xf32>, tensor) { + // expected-error @+1 {{'tfl.non_max_suppression_v5' op failed to verify that boxes should have dim[1] == 4}} + %0, %1, %2 = "tfl.non_max_suppression_v5"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (tensor<3x2xf32>, tensor<3xf32>, tensor, tensor, tensor, tensor) -> (tensor<2xi32>, tensor<2xf32>, tensor) + return %0, %1, %2 : tensor<2xi32>, tensor<2xf32>, tensor +} diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index 138962d5fca..f1e556703e3 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -292,16 +292,3 @@ func @InvalidL2NormalizePattern(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> t // CHECK: %3 = "tfl.div"([[INPUT:%.*]], %2) {fused_activation_function = "NONE"} : (tensor<2xf32>, tensor) -> tensor<2xf32> // CHECK: return %3 } - -// CHECK-LABEL: @InvalidL2NormalizePatternMorethan1Dimension -// Input has higher rank, it should be limited to 1D only. -func @InvalidL2NormalizePatternMorethan1Dimension(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { - %cst = constant dense<[0]> : tensor<1xi32> - %0 = "tfl.square"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> - %1 = "tfl.sum"(%0, %cst) {keep_dims = false} : (tensor<2x2xf32>, tensor<1xi32>) -> tensor - %2 = "tfl.sqrt"(%1) : (tensor) -> tensor - %3 = "tfl.div"(%arg0, %2) {fused_activation_function = "NONE"} : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> - return %3: tensor<2x2xf32> - // CHECK: %3 = "tfl.div"([[INPUT:%.*]], %2) {fused_activation_function = "NONE"} : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> - // CHECK: return %3 -} diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir index 5fd57ab21b4..092cb1e52f9 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir @@ -373,21 +373,12 @@ func @QuantizeConstant() -> tensor<2x3xf32> { // CHECK: return %1 : tensor<2x3xf32> } -// CHECK-LABEL: NotQuantizeNonZeroSplat -func @NotQuantizeNonZeroSplat() -> tensor<2x3xf32> { - %cst = constant dense<2.0> : tensor<2x3xf32> - return %cst : tensor<2x3xf32> +// CHECK-LABEL: NotQuantizeNoneType +func @NotQuantizeNoneType() -> none { + %cst = constant unit + return %cst : none -// CHECK-NEXT: %[[cst:.*]] = constant dense<2.000000e+00> -// CHECK-NEXT: return %[[cst]] -} - -// CHECK-LABEL: NotQuantizeNonZeroScalar -func @NotQuantizeNonZeroScalar() -> tensor { - %cst = constant dense<2.0> : tensor - return %cst : tensor - -// CHECK-NEXT: %[[cst:.*]] = constant dense<2.000000e+00> +// CHECK-NEXT: %[[cst:.*]] = constant unit // CHECK-NEXT: return %[[cst]] } @@ -433,6 +424,32 @@ func @QuantizeSharedBiases( // CHECK: %[[cst_0:.*]] = constant dense<1.000000e+00> : tensor<32xf32> // CHECK: %[[q_0:.*]] = "tfl.quantize"(%[[cst_0]]) // CHECK: %[[dq_0:.*]] = "tfl.dequantize"(%[[q_0]]) -// CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, 
%{{.*}}, %[[dq]]) // CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq_0]]) +// CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq]]) +} + +// CHECK-LABEL: QuantizeSharedBiases2 +func @QuantizeSharedBiases2( + %arg0: tensor<32x!quant.uniform>, + %arg1: tensor<1x112x112x32x!quant.uniform>, + %arg2: tensor<32x3x3x3x!quant.uniform:f32, 2.0>>) -> (tensor<32x!quant.uniform>, tensor<1x56x56x32x!quant.uniform>) { + %cst = constant dense<1.0> : tensor<32xf32> + %1 = "tfl.dequantize"(%arg0) : (tensor<32x!quant.uniform>) -> tensor<32xf32> + %add = "tfl.add"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32> + %3 = "tfl.quantize"(%add) {qtype = tensor<32xf32>} : (tensor<32xf32>) -> tensor<32x!quant.uniform> + + %5 = "tfl.dequantize"(%arg1) : (tensor<1x112x112x32x!quant.uniform>) -> tensor<1x112x112x32xf32> + %6 = "tfl.dequantize"(%arg2) : (tensor<32x3x3x3x!quant.uniform:f32, 2.0>>) -> tensor<32x3x3x3xf32> + %conv2 = "tfl.conv_2d"(%5, %6, %cst) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x32xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x56x56x32xf32> + %7 = "tfl.quantize"(%conv2) {qtype = tensor<1x56x56x32x!quant.uniform>} : (tensor<1x56x56x32xf32>) -> tensor<1x56x56x32x!quant.uniform> + return %3, %7 : tensor<32x!quant.uniform>, tensor<1x56x56x32x!quant.uniform> + +// CHECK: %[[cst:.*]] = constant dense<1.000000e+00> : tensor<32xf32> +// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cst]]) +// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]]) +// CHECK: %[[cst_0:.*]] = constant dense<1.000000e+00> : tensor<32xf32> +// CHECK: %[[q_0:.*]] = "tfl.quantize"(%[[cst_0]]) {qtype = tensor<32x!quant.uniform:f32, 1.000000e+00:1>>} +// CHECK: %[[dq_0:.*]] = "tfl.dequantize"(%[[q_0]]) +// CHECK: %{{.*}} = tfl.add %{{.*}}, %[[dq_0]] +// CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq]]) } diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir index 6a7d883ce50..8d6d7ab513e 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir @@ -63,7 +63,7 @@ func @fusedBatchNorm(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8 return %2, %2#1 : tensor<8x8x8x8xf32>, tensor<8xf32> // CHECK-LABEL: fusedBatchNorm -// CHECK: %[[CONSTANT:.*]] = "tf.Const"{{.*}} dense<1.000000e-03> +// CHECK: %[[CONSTANT:.*]] = constant dense<1.000000e-03> // variance + epsilon // CHECK: %[[ADD1:.*]] = "tf.Add"(%[[ARG4:.*]], %[[CONSTANT]]) // rsqrt(variance + epsilon) @@ -96,7 +96,7 @@ func @fusedBatchNormV3(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor return %2, %2#1 : tensor<8x8x8x8xf32>, tensor<8xf32> // CHECK-LABEL: fusedBatchNormV3 -// CHECK: %[[CONSTANT:.*]] = "tf.Const"{{.*}} dense<1.000000e-03> +// CHECK: %[[CONSTANT:.*]] = constant dense<1.000000e-03> // variance + epsilon // CHECK: %[[ADD1:.*]] = "tf.Add"(%[[ARG4:.*]], %[[CONSTANT]]) // rsqrt(variance + epsilon) @@ -155,7 +155,7 @@ func @fakeQuantFolded() -> (tensor<8xf32>) { %rst = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor, tensor) -> tensor<8xf32> return %rst : tensor<8xf32> -// CHECK: %[[CONSTANT:.*]] = "tf.Const"{{.*}} dense<0.000000e+00> : tensor<8xf32> +// CHECK: %[[CONSTANT:.*]] = constant dense<0.000000e+00> : tensor<8xf32> // CHECK: %[[QUANTIZE:.*]] = 
"tfl.quantize"(%[[CONSTANT]]) {qtype = tensor<8x!quant.uniform>} // CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) // CHECK: return %[[DEQUANTIZE]] : tensor<8xf32> @@ -262,7 +262,7 @@ func @fakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) return %rst : tensor<256x30x30x16xf32> // CHECK: %[[CONSTANT:.*]] = constant dense<0.000000e+00> : tensor<16xf32> -// CHECK: %[[CONSTANT0:.*]] = "tf.Const"{{.*}} dense<0.000000e+00> : tensor<16x3x3x3xf32> +// CHECK: %[[CONSTANT0:.*]] = constant dense<0.000000e+00> : tensor<16x3x3x3xf32> // CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT0]]) {qtype = tensor<16x3x3x3x!quant.uniform>} // CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) // CHECK: %[[CONV:.*]] = "tfl.conv_2d"(%arg0, %[[DEQUANTIZE]], %[[CONSTANT]]) @@ -282,7 +282,7 @@ func @fakeQuantWithDepthwiseConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30 return %rst : tensor<256x30x30x16xf32> // CHECK: %[[CONSTANT:.*]] = constant dense<0.000000e+00> : tensor<48xf32> -// CHECK: %[[CONSTANT0:.*]] = "tf.Const"{{.*}} dense<0.000000e+00> : tensor<1x3x3x48xf32> +// CHECK: %[[CONSTANT0:.*]] = constant dense<0.000000e+00> : tensor<1x3x3x48xf32> // CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT0]]) {qtype = tensor<1x3x3x48x!quant.uniform>} // CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) // CHECK: %[[CONV:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[DEQUANTIZE]], %[[CONSTANT]]) @@ -348,3 +348,11 @@ func @stop_gradient(%arg0: tensor<3xi32>) -> tensor<3xi32> { // CHECK-LABEL: stop_gradient // CHECK: return %arg0 : tensor<3xi32> } + +func @CheckNumerics(%arg0: tensor<3xf32>) -> tensor<3xf32> { + %0 = "tf.CheckNumerics"(%arg0) {message = ""}: (tensor<3xf32>) -> tensor<3xf32> + return %0 : tensor<3xf32> + // Should be converted to Identity and then from Identity to value + // CHECK-LABEL: CheckNumerics + // CHECK: return %arg0 : tensor<3xf32> +} diff --git a/tensorflow/compiler/mlir/lite/tests/unroll-batch-matmul.mlir b/tensorflow/compiler/mlir/lite/tests/unroll-batch-matmul.mlir new file mode 100644 index 00000000000..09f1dfc9133 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/unroll-batch-matmul.mlir @@ -0,0 +1,223 @@ +// RUN: tf-opt -tfl-unroll-batch-matmul %s | FileCheck %s + +func @batchMatMulV2TwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32> { + %0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<2x3x4x5xf32>, tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32> + return %0 : tensor<2x3x4x6xf32> + + // CHECK-LABEL: batchMatMulV2TwoDim + // CHECK: %[[cst:.*]] = constant dense<[6, 4, 5]> : tensor<3xi64> + // CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64> + // CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64> + // CHECK: %[[cst_2:.*]] = constant dense<[6, 5, 6]> : tensor<3xi64> + // CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64> + // CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64> + // CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64> + // CHECK: %[[cst_6:.*]] = constant dense<[3, 0, 0]> : tensor<3xi64> + // CHECK: %[[cst_7:.*]] = constant dense<[4, 0, 0]> : tensor<3xi64> + // CHECK: %[[cst_8:.*]] = constant dense<[5, 0, 0]> : tensor<3xi64> + // CHECK: %[[cst_9:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64> + // CHECK: %[[cst_10:.*]] = constant dense<[5, 6]> : tensor<2xi64> + // CHECK: %[[cst_11:.*]] = constant dense<[2, 3, 4, 6]> : tensor<4xi64> + + // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x3x4x5xf32>, 
tensor<3xi64>) -> tensor<6x4x5xf32> + // CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + // CHECK: %[[v3:.*]] = "tf.Slice"(%[[v0]], %[[cst_4]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v4:.*]] = "tf.Reshape"(%[[v3]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + // CHECK: %[[v5:.*]] = "tf.Slice"(%[[v0]], %[[cst_5]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v6:.*]] = "tf.Reshape"(%[[v5]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + // CHECK: %[[v7:.*]] = "tf.Slice"(%[[v0]], %[[cst_6]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v8:.*]] = "tf.Reshape"(%[[v7]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + // CHECK: %[[v9:.*]] = "tf.Slice"(%[[v0]], %[[cst_7]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v10:.*]] = "tf.Reshape"(%[[v9]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + // CHECK: %[[v11:.*]] = "tf.Slice"(%[[v0]], %[[cst_8]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v12:.*]] = "tf.Reshape"(%[[v11]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + + // CHECK: %[[v13:.*]] = "tf.Reshape"(%arg1, %[[cst_2]]) : (tensor<2x3x5x6xf32>, tensor<3xi64>) -> tensor<6x5x6xf32> + // CHECK: %[[v14:.*]] = "tf.Slice"(%[[v13]], %[[cst_3]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v15:.*]] = "tf.Reshape"(%[[v14]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + // CHECK: %[[v16:.*]] = "tf.Slice"(%[[v13]], %[[cst_4]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v17:.*]] = "tf.Reshape"(%[[v16]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + // CHECK: %[[v18:.*]] = "tf.Slice"(%[[v13]], %[[cst_5]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v19:.*]] = "tf.Reshape"(%[[v18]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + // CHECK: %[[v20:.*]] = "tf.Slice"(%[[v13]], %[[cst_6]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v21:.*]] = "tf.Reshape"(%[[v20]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + // CHECK: %[[v22:.*]] = "tf.Slice"(%[[v13]], %[[cst_7]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v23:.*]] = "tf.Reshape"(%[[v22]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + // CHECK: %[[v24:.*]] = "tf.Slice"(%[[v13]], %[[cst_8]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v25:.*]] = "tf.Reshape"(%[[v24]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + + // CHECK: %[[v26:.*]] = "tf.MatMul"(%[[v2]], %[[v15]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[v27:.*]] = "tf.MatMul"(%[[v4]], %[[v17]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, 
tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[v28:.*]] = "tf.MatMul"(%[[v6]], %[[v19]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[v29:.*]] = "tf.MatMul"(%[[v8]], %[[v21]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[v30:.*]] = "tf.MatMul"(%[[v10]], %[[v23]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[v31:.*]] = "tf.MatMul"(%[[v12]], %[[v25]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + + // CHECK: %[[v32:.*]] = "tf.Pack"(%[[v26]], %[[v27]], %[[v28]], %[[v29]], %[[v30]], %[[v31]]) {N = 6 : i64, axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32> + // CHECK: %[[v33:.*]] = "tf.Reshape"(%[[v32]], %[[cst_11]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32> + + // CHECK: return %[[v33]] : tensor<2x3x4x6xf32> +} + +func @batchMatMulV2FlatInput(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> { + %0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32> + return %0 : tensor<3x4x6xf32> + + // CHECK-LABEL: batchMatMulV2FlatInput + // CHECK: %[[cst:.*]] = constant dense<[3, 4, 5]> : tensor<3xi64> + // CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64> + // CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64> + // CHECK: %[[cst_2:.*]] = constant dense<[3, 5, 6]> : tensor<3xi64> + // CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64> + // CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64> + // CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64> + // CHECK: %[[cst_6:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64> + // CHECK: %[[cst_7:.*]] = constant dense<[5, 6]> : tensor<2xi64> + // CHECK: %[[cst_8:.*]] = constant dense<[3, 4, 6]> : tensor<3xi64> + + // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<3x4x5xf32>, tensor<3xi64>) -> tensor<3x4x5xf32> + // CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + // CHECK: %[[v3:.*]] = "tf.Slice"(%[[v0]], %[[cst_4]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v4:.*]] = "tf.Reshape"(%[[v3]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + // CHECK: %[[v5:.*]] = "tf.Slice"(%[[v0]], %[[cst_5]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v6:.*]] = "tf.Reshape"(%[[v5]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + + // CHECK: %[[v7:.*]] = "tf.Reshape"(%arg1, %[[cst_2]]) : (tensor<3x5x6xf32>, tensor<3xi64>) -> tensor<3x5x6xf32> + // CHECK: %[[v8:.*]] = "tf.Slice"(%[[v7]], %[[cst_3]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v9:.*]] = "tf.Reshape"(%[[v8]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + // CHECK: %[[v10:.*]] = "tf.Slice"(%[[v7]], %[[cst_4]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v11:.*]] = "tf.Reshape"(%[[v10]], %[[cst_7]]) : (tensor<1x5x6xf32>, 
tensor<2xi64>) -> tensor<5x6xf32> + // CHECK: %[[v12:.*]] = "tf.Slice"(%[[v7]], %[[cst_5]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v13:.*]] = "tf.Reshape"(%[[v12]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + + // CHECK: %[[v14:.*]] = "tf.MatMul"(%[[v2]], %[[v9]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[v15:.*]] = "tf.MatMul"(%[[v4]], %[[v11]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[v16:.*]] = "tf.MatMul"(%[[v6]], %[[v13]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + + // CHECK: %[[v17:.*]] = "tf.Pack"(%[[v14]], %[[v15]], %[[v16]]) {N = 3 : i64, axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32> + // CHECK: %[[v18:.*]] = "tf.Reshape"(%[[v17]], %[[cst_8]]) : (tensor<3x4x6xf32>, tensor<3xi64>) -> tensor<3x4x6xf32> + + // CHECK: return %[[v18]] : tensor<3x4x6xf32> +} + +func @batchMatMulV2Matrix(%arg0: tensor<4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<4x6xf32> { + %0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + return %0 : tensor<4x6xf32> + + // CHECK-LABEL: batchMatMulV2Matrix + // CHECK: %[[v0:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: return %[[v0]] : tensor<4x6xf32> +} + +func @batchMatMulTwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32> { + %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<2x3x4x5xf32>, tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32> + return %0 : tensor<2x3x4x6xf32> + + // CHECK-LABEL: batchMatMulTwoDim + // CHECK: %[[cst:.*]] = constant dense<[6, 4, 5]> : tensor<3xi64> + // CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64> + // CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64> + // CHECK: %[[cst_2:.*]] = constant dense<[6, 5, 6]> : tensor<3xi64> + // CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64> + // CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64> + // CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64> + // CHECK: %[[cst_6:.*]] = constant dense<[3, 0, 0]> : tensor<3xi64> + // CHECK: %[[cst_7:.*]] = constant dense<[4, 0, 0]> : tensor<3xi64> + // CHECK: %[[cst_8:.*]] = constant dense<[5, 0, 0]> : tensor<3xi64> + // CHECK: %[[cst_9:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64> + // CHECK: %[[cst_10:.*]] = constant dense<[5, 6]> : tensor<2xi64> + // CHECK: %[[cst_11:.*]] = constant dense<[2, 3, 4, 6]> : tensor<4xi64> + + // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32> + // CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + // CHECK: %[[v3:.*]] = "tf.Slice"(%[[v0]], %[[cst_4]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v4:.*]] = "tf.Reshape"(%[[v3]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + // CHECK: %[[v5:.*]] = "tf.Slice"(%[[v0]], %[[cst_5]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v6:.*]] = 
"tf.Reshape"(%[[v5]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + // CHECK: %[[v7:.*]] = "tf.Slice"(%[[v0]], %[[cst_6]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v8:.*]] = "tf.Reshape"(%[[v7]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + // CHECK: %[[v9:.*]] = "tf.Slice"(%[[v0]], %[[cst_7]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v10:.*]] = "tf.Reshape"(%[[v9]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + // CHECK: %[[v11:.*]] = "tf.Slice"(%[[v0]], %[[cst_8]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v12:.*]] = "tf.Reshape"(%[[v11]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + + // CHECK: %[[v13:.*]] = "tf.Reshape"(%arg1, %[[cst_2]]) : (tensor<2x3x5x6xf32>, tensor<3xi64>) -> tensor<6x5x6xf32> + // CHECK: %[[v14:.*]] = "tf.Slice"(%[[v13]], %[[cst_3]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v15:.*]] = "tf.Reshape"(%[[v14]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + // CHECK: %[[v16:.*]] = "tf.Slice"(%[[v13]], %[[cst_4]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v17:.*]] = "tf.Reshape"(%[[v16]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + // CHECK: %[[v18:.*]] = "tf.Slice"(%[[v13]], %[[cst_5]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v19:.*]] = "tf.Reshape"(%[[v18]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + // CHECK: %[[v20:.*]] = "tf.Slice"(%[[v13]], %[[cst_6]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v21:.*]] = "tf.Reshape"(%[[v20]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + // CHECK: %[[v22:.*]] = "tf.Slice"(%[[v13]], %[[cst_7]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v23:.*]] = "tf.Reshape"(%[[v22]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + // CHECK: %[[v24:.*]] = "tf.Slice"(%[[v13]], %[[cst_8]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v25:.*]] = "tf.Reshape"(%[[v24]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + + // CHECK: %[[v26:.*]] = "tf.MatMul"(%[[v2]], %[[v15]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[v27:.*]] = "tf.MatMul"(%[[v4]], %[[v17]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[v28:.*]] = "tf.MatMul"(%[[v6]], %[[v19]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[v29:.*]] = "tf.MatMul"(%[[v8]], %[[v21]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[v30:.*]] = "tf.MatMul"(%[[v10]], %[[v23]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[v31:.*]] = "tf.MatMul"(%[[v12]], %[[v25]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + + // CHECK: %[[v32:.*]] = "tf.Pack"(%[[v26]], %[[v27]], 
%[[v28]], %[[v29]], %[[v30]], %[[v31]]) {N = 6 : i64, axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32> + // CHECK: %[[v33:.*]] = "tf.Reshape"(%[[v32]], %[[cst_11]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32> + + // CHECK: return %[[v33]] : tensor<2x3x4x6xf32> +} + +func @batchMatMulFlatInput(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> { + %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32> + return %0 : tensor<3x4x6xf32> + + // CHECK-LABEL: batchMatMulFlatInput + // CHECK: %[[cst:.*]] = constant dense<[3, 4, 5]> : tensor<3xi64> + // CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64> + // CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64> + // CHECK: %[[cst_2:.*]] = constant dense<[3, 5, 6]> : tensor<3xi64> + // CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64> + // CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64> + // CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64> + // CHECK: %[[cst_6:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64> + // CHECK: %[[cst_7:.*]] = constant dense<[5, 6]> : tensor<2xi64> + // CHECK: %[[cst_8:.*]] = constant dense<[3, 4, 6]> : tensor<3xi64> + + // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<3x4x5xf32>, tensor<3xi64>) -> tensor<3x4x5xf32> + // CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + // CHECK: %[[v3:.*]] = "tf.Slice"(%[[v0]], %[[cst_4]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v4:.*]] = "tf.Reshape"(%[[v3]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + // CHECK: %[[v5:.*]] = "tf.Slice"(%[[v0]], %[[cst_5]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v6:.*]] = "tf.Reshape"(%[[v5]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + + // CHECK: %[[v7:.*]] = "tf.Reshape"(%arg1, %[[cst_2]]) : (tensor<3x5x6xf32>, tensor<3xi64>) -> tensor<3x5x6xf32> + // CHECK: %[[v8:.*]] = "tf.Slice"(%[[v7]], %[[cst_3]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v9:.*]] = "tf.Reshape"(%[[v8]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + // CHECK: %[[v10:.*]] = "tf.Slice"(%[[v7]], %[[cst_4]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v11:.*]] = "tf.Reshape"(%[[v10]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + // CHECK: %[[v12:.*]] = "tf.Slice"(%[[v7]], %[[cst_5]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v13:.*]] = "tf.Reshape"(%[[v12]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + + // CHECK: %[[v14:.*]] = "tf.MatMul"(%[[v2]], %[[v9]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[v15:.*]] = "tf.MatMul"(%[[v4]], %[[v11]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[v16:.*]] = "tf.MatMul"(%[[v6]], %[[v13]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, 
tensor<5x6xf32>) -> tensor<4x6xf32> + + // CHECK: %[[v17:.*]] = "tf.Pack"(%[[v14]], %[[v15]], %[[v16]]) {N = 3 : i64, axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32> + // CHECK: %[[v18:.*]] = "tf.Reshape"(%[[v17]], %[[cst_8]]) : (tensor<3x4x6xf32>, tensor<3xi64>) -> tensor<3x4x6xf32> + + // CHECK: return %[[v18]] : tensor<3x4x6xf32> +} + +func @batchMatMulMatrix(%arg0: tensor<4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<4x6xf32> { + %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + return %0 : tensor<4x6xf32> + + // CHECK-LABEL: batchMatMulMatrix + // CHECK: %[[v0:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: return %[[v0]] : tensor<4x6xf32> +} diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc index fc86eb63753..99d88ba9b93 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc @@ -149,7 +149,12 @@ int main(int argc, char **argv) { lower_tensor_list_ops, &result, &pm); if (!status.ok()) return kTrFailure; - auto output = mlir::openOutputFile(output_file_name); + std::string error_msg; + auto output = mlir::openOutputFile(output_file_name, &error_msg); + if (output == nullptr) { + llvm::errs() << error_msg << '\n'; + return kTrFailure; + } output->os() << result; output->keep(); diff --git a/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc b/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc index b6a898e6cda..29fc88462cb 100644 --- a/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc +++ b/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc @@ -14,7 +14,9 @@ limitations under the License. ==============================================================================*/ #include #include +#include +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" @@ -353,6 +355,127 @@ struct OphintCompositeOp { std::map outputs; }; +// Preprocess the graph for topological sorting (each operation is a node, +// while inputs/outputs indicate edges). Assume the graph is acyclic. The +// preprocessing does the following: +// Compute each operation's in-degree (how many distinct operations feed it). +// Get all consumer operations for every operation (operation_to_ouputs). +// Get the init_queue (those operations will be processed first). +void PreprocessTopoSortGraph( + Block* block, std::queue* init_queue, + llvm::DenseMap>* operation_to_ouputs, + llvm::DenseMap* operation_to_in_degrees) { + for (auto& op : *block) { + if (&op == block->getTerminator()) continue; + if (op.getNumOperands() == 0) { + init_queue->push(&op); + } else { + // The operand count of an op is not a direct indication of its in-degree: + // a pack op that follows an unpack op takes multiple operands from the + // same producer, which should only count as one edge.
+ llvm::DenseSet input_ops; + for (int i = 0; i < op.getNumOperands(); ++i) { + Operation* input_op = op.getOperand(i)->getDefiningOp(); + if (input_op) input_ops.insert(input_op); + } + if (input_ops.empty()) { + init_queue->push(&op); + continue; + } + operation_to_in_degrees->try_emplace(&op, input_ops.size()); + for (auto* input_op : input_ops) { + auto preceeding_op_it = operation_to_ouputs->find(input_op); + if (preceeding_op_it == operation_to_ouputs->end()) { + auto result = operation_to_ouputs->try_emplace( + input_op, llvm::DenseSet()); + preceeding_op_it = result.first; + } + preceeding_op_it->second.insert(&op); + } + } + } +} + +bool IsSideEffectOp(Operation* op) { + if (op->hasNoSideEffect()) return false; + + // Identity op has no side effect. + // Checking the OperationName might be more elegant here. + auto tf_identity_op = dyn_cast_or_null(op); + if (tf_identity_op) return false; + return true; +} + +// It's possible that other transformations could benefit from this utility +// function, but since there are currently none, we limit it to the ophint +// extraction pass. We may refactor it to extend its usage in the future. +// +// Assume the graph is disconnected from outside. +// Also assume the block has no arguments. +LogicalResult TopoSortOperations(OpBuilder* builder) { + std::queue init_queue; + llvm::DenseMap> operation_to_ouputs; + llvm::DenseMap operation_to_in_degrees; + std::vector sorted_ops; + + PreprocessTopoSortGraph(builder->getBlock(), &init_queue, + &operation_to_ouputs, &operation_to_in_degrees); + while (!init_queue.empty()) { + Operation* current_op = init_queue.front(); + init_queue.pop(); + sorted_ops.push_back(current_op); + + auto current_op_to_output_it = operation_to_ouputs.find(current_op); + if (current_op_to_output_it == operation_to_ouputs.end()) { + continue; + } + for (Operation* output_op : current_op_to_output_it->second) { + auto output_op_it = operation_to_in_degrees.find(output_op); + if (output_op_it == operation_to_in_degrees.end()) return failure(); + + output_op_it->second -= 1; + if (output_op_it->second == 0) { + init_queue.push(output_op); + operation_to_in_degrees.erase(output_op_it); + } + } + operation_to_ouputs.erase(current_op_to_output_it); + } + + // Before we perform the sort, we need to make sure we didn't mess up the + // ordering of the original side-effect operations. + // It's possible those side-effect operations have no topological relations + // at all! + std::vector original_side_effect_ops; + std::vector after_sort_side_effect_ops; + for (auto& op : *builder->getBlock()) { + if (IsSideEffectOp(&op) && (&op != builder->getBlock()->getTerminator())) + original_side_effect_ops.push_back(&op); + } + for (auto* op : sorted_ops) { + if (IsSideEffectOp(op)) after_sort_side_effect_ops.push_back(op); + } + if (original_side_effect_ops.size() != after_sort_side_effect_ops.size()) + return failure(); + for (int i = 0; i < original_side_effect_ops.size(); ++i) { + if (original_side_effect_ops[i] != after_sort_side_effect_ops[i]) + return failure(); + } + + // Perform the sort. + // Ideally we would just clear the block and then write out the sorted ops, + // but unfortunately that's hard to do.
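+  // Instead, walk sorted_ops from back to front and move every earlier op in
+  // sorted_ops directly before the current one; once the loop below finishes,
+  // the ops in the block appear in sorted_ops order.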
+ for (int i = sorted_ops.size() - 1; i > 0; --i) { + Operation* current_op = sorted_ops[i]; + for (int j = i - 1; j >= 0; --j) { + Operation* prev_op = sorted_ops[j]; + prev_op->moveBefore(current_op); + } + } + + return success(); +} + Operation* BuildFusedFuncOp(StringRef func_name, StringRef fused_func_type, Operation* insert_before_op, const std::map& inputs, @@ -360,10 +483,12 @@ Operation* BuildFusedFuncOp(StringRef func_name, StringRef fused_func_type, OpBuilder* builder, ModuleOp* module_op) { SmallVector input_types; SmallVector input_values; + SmallVector input_indexes; for (const auto& kv : inputs) { Value* input = kv.second; input_types.push_back(input->getType()); input_values.push_back(input); + input_indexes.push_back(kv.first); } SmallVector func_output_types; @@ -378,6 +503,8 @@ Operation* BuildFusedFuncOp(StringRef func_name, StringRef fused_func_type, SmallVector attrs; attrs.push_back(builder->getNamedAttr( kTfLiteFunctionName, builder->getStringAttr(fused_func_type))); + attrs.push_back(builder->getNamedAttr( + kTfLiteFunctionInputIndex, builder->getI32ArrayAttr(input_indexes))); FuncOp func_op = FuncOp::create(insert_before_op->getLoc(), func_name, function_type, llvm::makeArrayRef(attrs)); module_op->push_back(func_op); @@ -507,6 +634,10 @@ LogicalResult ConvertOphintToStub(StringRef stub_name, }; builder->getBlock()->walk(removeRemovableOps); + + // Step 8: Topo sort to fix any invalid temporary IRs. + if (failed(TopoSortOperations(builder))) return failure(); + return success(); } diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td index 65c4eb76a77..ec328304d92 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td @@ -20,6 +20,10 @@ include "mlir/Dialect/StandardOps/Ops.td" include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" +def NonOpaqueElementsAttr : ElementsAttrBase< + CPred<"!$_self.isa()">, + "non-opaque constant tensor">; + def F32ElementsAttr : ElementsAttrBase< CPred<"$_self.cast().getType().getElementType().isF32()">, "float constant tensor">; @@ -56,8 +60,13 @@ def ExtractSingleElementAsInteger : NativeCodeCall< //===----------------------------------------------------------------------===// // Nullary ops patterns. //===----------------------------------------------------------------------===// + def : Pat<(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>; +// Convert to std constant for statically shaped, non-opaque constants. +def : Pat<(TF_ConstOp:$res NonOpaqueElementsAttr:$value), (ConstantOp $value), + [(AnyStaticShapeTensor $res)], (addBenefit 10)>; + //===----------------------------------------------------------------------===// // Unary ops patterns. //===----------------------------------------------------------------------===// @@ -157,7 +166,8 @@ def : Pat<(TF_ZerosLikeOp $arg), (TFL_ZerosLikeOp $arg)>; // The following two rules can both match an tf.Placeholder.input node with // min/max/type attributes, so we increase the benefit of the first rule by one // so the tfl.quantize and tfl.dequantize ops will be inserted if it matches. 
-def : Pat<(TF_PlaceholderInputOp $inputs, $min, $max, $type), +def : Pat<(TF_PlaceholderInputOp TensorOf<[F16, F32, F64]>:$inputs, + $min, $max, $type), (TFL_DequantizeOp (TFL_QuantizeOp (TFL_InputOp $inputs), @@ -191,7 +201,8 @@ def : Pat<(TF_GatherV2Op $params, $indices, def : Pat<(TF_FloorDivOp $l, $r), (TFL_FloorDivOp $l, $r)>; -def : Pat<(TF_NotEqualOp $l, $r), (TFL_NotEqualOp $l, $r)>; +def : Pat<(TF_NotEqualOp $l, $r, /*incompatible_shape_error=*/ConstBoolAttrTrue), + (TFL_NotEqualOp $l, $r)>; def : Pat<(TF_LogicalAndOp $l, $r), (TFL_LogicalAndOp $l, $r)>; @@ -252,7 +263,7 @@ def : Pat<(TF_ReluOp (TF_SquaredDifferenceOp $l, $r)), def : Pat<(TF_ReverseV2Op $arg0, $arg1), (TFL_ReverseV2Op $arg0, $arg1)>; -def : Pat<(TF_EqualOp $arg0, $arg1), (TFL_EqualOp $arg0, $arg1)>; +def : Pat<(TF_EqualOp $arg0, $arg1, /*incompatible_shape_error=*/ConstBoolAttrTrue), (TFL_EqualOp $arg0, $arg1)>; def : Pat<(TF_PadOp $arg0, $arg1), (TFL_PadOp $arg0, $arg1)>; @@ -308,3 +319,11 @@ def : Pat<(TF_FloorModOp $arg0, $arg1), (TFL_FloorModOp $arg0, $arg1)>; def : Pat<(TF_ExpOp $arg0), (TFL_ExpOp $arg0)>; def : Pat<(TF_LRNOp $arg0, $radius, F32Attr:$bias, F32Attr:$alpha, F32Attr:$beta), (TFL_LocalResponseNormalizationOp $arg0, (convertIntAttrTo32Bit $radius), $bias, $alpha, $beta)>; + +def : Pat< + (TF_NonMaxSuppressionV4Op $boxes, $scores, $max_output_size, $iou_threshold, $score_threshold, $pad_to_max_output_size), + (TFL_NonMaxSuppressionV4Op $boxes, $scores, $max_output_size, $iou_threshold, $score_threshold)>; + +def : Pat< + (TF_NonMaxSuppressionV5Op $boxes, $scores, $max_output_size, $iou_threshold, $score_threshold, $soft_nms_sigma, $pad_to_max_output_size), + (TFL_NonMaxSuppressionV5Op $boxes, $scores, $max_output_size, $iou_threshold, $score_threshold, $soft_nms_sigma)>; diff --git a/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc b/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc new file mode 100644 index 00000000000..01e54da1a61 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc @@ -0,0 +1,228 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This transformation pass prepare the tflite fused ops for quantization. 
+ +#include "absl/memory/memory.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:local_config_mlir +#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir +#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" +#include "tensorflow/compiler/mlir/lite/transforms/passes.h" + +//===----------------------------------------------------------------------===// +// The LoadQuantizationRecipe Pass. +// +namespace mlir { +namespace TFL { + +namespace { + +// This pass loads the quantization recipe for the TFLite ops to be quantized. +// Specifically, it extends the fused ops with their internal implementation as +// op regions. Each ops in the region produces results with element type +// AnyQuantizedType, thus bitwidth, narrow_range, etc are included. The op also +// defines the op quantization traits, which are used to propgate the +// quantization parameters by the following passes. +struct LoadQuantizationRecipe : public FunctionPass { + void runOnFunction() override; + + private: + void Initialize(LSTMOp lstm, OpBuilder* builder); + + // Create LSTM gates with different weights for input, recurrent and + // cell state, and also the layer normalization parameters. + Operation* CreateGate(Location loc, Value* in, Value* in_w, Value* rec, + Value* rec_w, + llvm::Optional> cell, + Value* ln_w, Value* ln_bias, OpBuilder* builder); + + Operation* CreateLayerNorm(Location loc, Value* in, Value* ln_w, + Value* ln_bias, OpBuilder* builder); + + // Add the internal implementation of the LSTM to its regions. + void LoadForLSTMOp(LSTMOp lstm, OpBuilder* builder); + + StringAttr none_af; + StringAttr fc_format; + BoolAttr keep_dims; + Type int8; + Type int16; + ConstantOp none_cst; +}; + +void LoadQuantizationRecipe::Initialize(LSTMOp lstm, OpBuilder* builder) { + Type expressed_type = + lstm.input()->getType().cast().getElementType(); + Type int8_storage_type = builder->getIntegerType(8); + Type int16_storage_type = builder->getIntegerType(16); + auto flag = quant::QuantizationFlags::FlagValue::Signed; + int64_t int8_min = quant::QuantizedType::getDefaultMininumForInteger( + flag, /*integralWidth=*/8); + int64_t int8_max = quant::QuantizedType::getDefaultMaxinumForInteger( + flag, /*integralWidth=*/8); + int64_t int16_min = quant::QuantizedType::getDefaultMininumForInteger( + flag, /*integralWidth=*/16); + int64_t int16_max = quant::QuantizedType::getDefaultMaxinumForInteger( + flag, /*integralWidth=*/16); + auto any_int8 = quant::AnyQuantizedType::get( + flag, int8_storage_type, expressed_type, int8_min, int8_max); + auto any_int16 = quant::AnyQuantizedType::get( + flag, int16_storage_type, expressed_type, int16_min, int16_max); + + int8 = any_int8.castFromExpressedType(lstm.input()->getType()); + int16 = any_int16.castFromExpressedType(lstm.input()->getType()); +} + +Operation* LoadQuantizationRecipe::CreateLayerNorm(Location loc, Value* in, + Value* ln_w, Value* ln_bias, + OpBuilder* builder) { + // Note that l2_normalization and add ops here are not the execution kernle + // implementation for layer_normalization and we just want to use them to + // model the quantization requirement. 
+ auto l2_norm = builder->create(loc, int16, in, none_af); + auto add = builder->create(loc, int16, in, l2_norm, none_af); + return builder->create(loc, int16, add, ln_w, ln_bias, + none_af, fc_format, keep_dims); +} + +Operation* LoadQuantizationRecipe::CreateGate( + Location loc, Value* in, Value* in_w, Value* rec, Value* rec_w, + llvm::Optional> cell, Value* ln_w, Value* ln_bias, + OpBuilder* builder) { + auto s1 = builder->create(loc, int16, in, in_w, none_cst, + none_af, fc_format, keep_dims); + auto s2 = builder->create(loc, int16, rec, rec_w, none_cst, + none_af, fc_format, keep_dims); + + AddNOp s4; + if (cell.hasValue()) { + auto s3 = builder->create(loc, int16, cell.getValue().first, + cell.getValue().second, none_af); + s4 = builder->create( + loc, int16, + llvm::ArrayRef( + {*s1.output().begin(), *s2.output().begin(), s3.output()})); + + } else { + s4 = builder->create( + loc, int16, + llvm::ArrayRef({*s1.output().begin(), *s2.output().begin()})); + } + + auto s5 = CreateLayerNorm(loc, s4.sum(), ln_w, ln_bias, builder); + + if (cell.hasValue()) { + return builder->create(loc, int16, s5->getResult(0)); + } else { + return builder->create(loc, int16, s5->getResult(0)); + } +} + +void LoadQuantizationRecipe::LoadForLSTMOp(LSTMOp lstm, OpBuilder* builder) { + Initialize(lstm, builder); + + Region region; + region.push_back(new Block); + builder->setInsertionPointToEnd(®ion.front()); + Location loc = lstm.getLoc(); + Type int32_type = builder->getIntegerType(32); + Type int32_tensor = builder->getTensorType(int32_type); + none_cst = builder->create(loc, builder->getNoneType(), + builder->getUnitAttr()); + + auto input_gate = CreateGate( + loc, lstm.input(), lstm.input_to_input_weights(), + lstm.input_activation_state(), lstm.recurrent_to_input_weights(), + llvm::Optional>( + {lstm.input_cell_state(), lstm.cell_to_input_weights()}), + lstm.input_layer_norm_coefficients(), lstm.input_gate_bias(), builder); + + auto forget_gate = CreateGate( + loc, lstm.input(), lstm.input_to_forget_weights(), + lstm.input_activation_state(), lstm.recurrent_to_forget_weights(), + llvm::Optional>( + {lstm.input_cell_state(), lstm.cell_to_forget_weights()}), + lstm.forget_layer_norm_coefficients(), lstm.forget_gate_bias(), builder); + + auto cell_gate = CreateGate(loc, lstm.input(), lstm.input_to_cell_weights(), + lstm.input_activation_state(), + lstm.recurrent_to_cell_weights(), llvm::None, + lstm.cell_layer_norm_coefficients(), + lstm.cell_bias(), builder); + + auto forget_cell_state = builder->create( + loc, int16, forget_gate->getResult(0), lstm.input_cell_state(), none_af); + auto input_cell_state = builder->create( + loc, int16, input_gate->getResult(0), cell_gate->getResult(0), none_af); + auto new_cell = builder->create(loc, int16, forget_cell_state.output(), + input_cell_state.output(), none_af); + + auto output_gate = CreateGate( + loc, lstm.input(), lstm.input_to_output_weights(), + lstm.input_activation_state(), lstm.recurrent_to_output_weights(), + llvm::Optional>( + {new_cell, lstm.cell_to_output_weights()}), + lstm.output_layer_norm_coefficients(), lstm.output_gate_bias(), builder); + + auto new_cell_tanh = builder->create(loc, int16, new_cell); + auto hidden_state = builder->create( + loc, int16, new_cell_tanh.y(), output_gate->getResult(0), none_af); + auto act = builder->create( + loc, int8, hidden_state.output(), lstm.projection_weights(), + lstm.projection_bias(), none_af, fc_format, keep_dims); + + // TODO(fengliuai): define and register the op in the QuantOps Dialect. 
+ OperationState return_state(loc, "tf_quant.pseudo_return", act.getResult(0), + {int8}, {}); + builder->createOperation(return_state); + + lstm.internal().takeBody(region); +} + +void LoadQuantizationRecipe::runOnFunction() { + FuncOp func = getFunction(); + OpBuilder builder(func); + none_af = builder.getStringAttr("NONE"); + fc_format = builder.getStringAttr("DEFAULT"); + keep_dims = builder.getBoolAttr(false); + + func.walk([&](Operation* op) { + if (auto lstm = llvm::dyn_cast(op)) { + LoadForLSTMOp(lstm, &builder); + } + // Handles other ops. + }); +} + +} // namespace + +// Creates an instance of the TensorFlow Lite dialect LoadQuantizationRecipe +// pass. +std::unique_ptr CreateLoadQuantizationRecipePass() { + return absl::make_unique(); +} + +static PassRegistration pass( + "tfl-load-recipe", "Load TFL op quantization recipe"); + +} // namespace TFL +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc index 716c8216433..35dd5e0a75d 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc @@ -429,12 +429,14 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction( } else if (auto tf_op = llvm::dyn_cast(op)) { if (!(tf_op.element_dtype().isF16() || tf_op.element_dtype().isF32() || tf_op.element_dtype().isF64() || + tf_op.element_dtype().isInteger(1) || tf_op.element_dtype().isInteger(8) || tf_op.element_dtype().isInteger(16) || tf_op.element_dtype().isInteger(32) || tf_op.element_dtype().isInteger(64))) { return tf_op.emitError( - "requires element_dtype to be 8-bit/16-bit/32-bit/64-bit integer " + "requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit " + "integer " "or 16-bit/32-bit/64-bit " "float type during TF Lite transformation pass"); } @@ -461,6 +463,10 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction( auto c = ConvertTFTensorListPushBack(context); rewriter->setInsertionPoint(op); c.matchAndRewrite(op, *rewriter); + } else if (auto tf_op = llvm::dyn_cast(op)) { + auto c = TFL::ConvertTFTensorListLength(context); + rewriter->setInsertionPoint(op); + c.matchAndRewrite(op, *rewriter); } else if (auto tf_op = llvm::dyn_cast(op)) { if (op->getAttr("T")) op->removeAttr(Identifier::get("T", context)); UpdateWhileFunctionType(tf_op); diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 51610832db6..9de40eb3cd6 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -122,6 +122,8 @@ class OperandHasRank : Constraint< // Mul->Rsqrt->Sum->Square // Currently L2Normalization doesn't support activation function // in TFLite. +// TODO(karimnosseir): Add constraints that the kernel code assumes. +// constraint on axis and depth. def : Pat<(TFL_MulOp $operand1, (TFL_RsqrtOp (TFL_SumOp @@ -130,13 +132,14 @@ def : Pat<(TFL_MulOp $operand1, $keep_dims)), TFL_AF_None), (TFL_L2NormalizationOp $operand1, TFL_AF_None), - [(EqualOperands $operand1, $square_operand), - (OperandHasRank<1> $operand1)]>; + [(EqualOperands $operand1, $square_operand)]>; // This pattern constructs L2NormalizationOp from // Div->sqrt->Sum->Square // Currently L2Normalization doesn't support activation function // in TFLite. +// TODO(karimnosseir): Add constraints that the kernel code assumes. 
+// constraint on axis and depth. def : Pat<(TFL_DivOp $operand1, (TFL_SqrtOp (TFL_SumOp @@ -145,5 +148,4 @@ def : Pat<(TFL_DivOp $operand1, $keep_dims)), TFL_AF_None), (TFL_L2NormalizationOp $operand1, TFL_AF_None), - [(EqualOperands $operand1, $square_operand), - (OperandHasRank<1> $operand1)]>; + [(EqualOperands $operand1, $square_operand)]>; diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td index e3dabb7a48d..7cb89c4219c 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td @@ -18,6 +18,14 @@ include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" def FalseBoolAttr : AttrConstraint>; +def NonOpaqueElementsAttr : ElementsAttrBase< + CPred<"!$_self.isa()">, + "non-opaque constant tensor">; + +// Convert to std constant for statically shaped, non-opaque constants. +def : Pat<(TF_ConstOp:$res NonOpaqueElementsAttr:$value), (ConstantOp $value), + [(AnyStaticShapeTensor $res)]>; + // Converts tf.FusedBatchNorm & tf.FusedBatchNormV3 into a sequence of more primitive arithmetic // operations. Specifically, performs the following calculation: // @@ -81,8 +89,8 @@ class TFi32 : ConstantAttr(v)>; def : Pat<(TF_MatMulOp $a, $b, ConstBoolAttrFalse:$at, ConstBoolAttrFalse), (TF_MatMulOp $a, (TF_TransposeOp $b, (TF_SubOp (TF_RangeOp /*start=*/(TF_RankOp $b), - /*limit=*/(ConstantOp TFi32<0>), - /*delta=*/(ConstantOp TFi32<-1>)), (ConstantOp TFi32<1>))), + /*limit=*/(TF_ConstOp TFi32<0>), + /*delta=*/(TF_ConstOp TFi32<-1>)), (TF_ConstOp TFi32<1>))), $at, ConstBoolAttrTrue)>; // Matmul with transpose on a to matmul with explicit transpose op and a not @@ -90,10 +98,12 @@ def : Pat<(TF_MatMulOp $a, $b, ConstBoolAttrFalse:$at, ConstBoolAttrFalse), def : Pat<(TF_MatMulOp $a, $b, ConstBoolAttrTrue, $bt), (TF_MatMulOp (TF_TransposeOp $a, (TF_SubOp (TF_RangeOp /*start=*/(TF_RankOp $a), - /*limit=*/(ConstantOp TFi32<0>), - /*delta=*/(ConstantOp TFi32<-1>)), (ConstantOp TFi32<1>))), $b, + /*limit=*/(TF_ConstOp TFi32<0>), + /*delta=*/(TF_ConstOp TFi32<-1>)), (TF_ConstOp TFi32<1>))), $b, ConstBoolAttrFalse, $bt)>; +// Partially supported in TFLite, treated as passthrough IdentityOp +def : Pat<(TF_CheckNumericsOp $arg, $msg), (TF_IdentityOp $arg)>; def : Pat<(TF_SnapshotOp $arg), (TF_IdentityOp $arg)>; def : Pat<(TF_StopGradientOp $arg), (TF_IdentityOp $arg)>; diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index 7c7983ae254..2b91b2f4177 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -50,6 +50,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h" #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/lite/utils/validators.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -246,7 +247,8 @@ struct ConvertTFConvOp : public RewritePattern { filter_type.getShape()); auto bias_type = rewriter.getTensorType({bias_dim}, elem_type); auto bias_attr = rewriter.getZeroAttr(bias_type); - auto bias = rewriter.create(op->getLoc(), bias_type, bias_attr); + auto bias = + rewriter.create(op->getLoc(), bias_type, bias_attr); auto *conv_state = static_cast(state.get()); auto conv_op = static_cast(this)->createTFLOp( @@ -297,7 +299,7 @@ class ConvertTFConv2D : public ConvertTFConvOp { rewriter.getIntegerType(32)); auto perm_attr = DenseElementsAttr::get(perm_type, llvm::makeArrayRef(perm)); - auto perm_op = rewriter.create(loc, perm_type, perm_attr); + auto perm_op = rewriter.create(loc, perm_type, perm_attr); // Create tensor type for the transpose result. auto filter_type = filter->getType().cast(); @@ -366,7 +368,7 @@ class ConvertTFDepthwiseConv2dNative auto shape_type = rewriter.getTensorType({4}, rewriter.getIntegerType(64)); auto shape_attr = DenseElementsAttr::get(shape_type, llvm::makeArrayRef(result_shape)); - auto shape = rewriter.create(loc, shape_type, shape_attr); + auto shape = rewriter.create(loc, shape_type, shape_attr); return rewriter.create(loc, result_type, filter, shape); } @@ -377,6 +379,11 @@ class ConvertTFDepthwiseConv2dNative void PrepareTFPass::runOnFunction() { OwningRewritePatternList patterns; auto func = getFunction(); + + patterns.insert, + ConvertTFBatchMatMulOp>(&getContext()); + applyPatternsGreedily(func, patterns); + // This pattern was intented to uses TFL QDQs to preserve the quantization // parameters from the TF Quant ops, thus this pattern should run with the // first `applyPatternsGreedily` method, which would otherwise removes the diff --git a/tensorflow/compiler/mlir/lite/transforms/tensorlist_patterns.td b/tensorflow/compiler/mlir/lite/transforms/tensorlist_patterns.td index 764f8e95f55..167a4be3579 100644 --- a/tensorflow/compiler/mlir/lite/transforms/tensorlist_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/tensorlist_patterns.td @@ -14,9 +14,13 @@ limitations under the License. ==============================================================================*/ include "mlir/IR/OpBase.td" +include "mlir/Dialect/StandardOps/Ops.td" include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" +def CreateTFShapeOp : NativeCodeCall< + "$_builder.create($0->getLoc(), $1, $2)">; + //===----------------------------------------------------------------------===// // TensorList transformation patterns. // Note that the pattern below rewrites `TensorList` tensors (which has type DT_VARIANT) @@ -34,3 +38,11 @@ def ConvertTFTensorListStack : Pat< def ConvertTFTensorListGetItem : Pat< (TF_TensorListGetItemOp $input, $index, $element_shape), (TF_GatherOp $input, $index, (NativeCodeCall<"$_builder.getBoolAttr(true)">))>; + +// TensorListLength is equivalent to the size of the first dimension of the +// input tensorlist, rewrite it to a combination of Gather and Shape op. 
+def ConvertTFTensorListLength: Pat< + (TF_TensorListLengthOp:$old_value $input), + (TF_GatherOp + (CreateTFShapeOp $old_value, $input, /*use 32bit*/ConstBoolAttrTrue), + (ConstantOp ConstantAttr), ConstBoolAttrTrue)>; diff --git a/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.cc b/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.cc new file mode 100644 index 00000000000..50b644f9635 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.cc @@ -0,0 +1,309 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h" + +#include +#include + +#include "absl/memory/memory.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "mlir/Analysis/LoopAnalysis.h" // TF:local_config_mlir +#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir +#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir +#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir +#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir +#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/Support/Functional.h" // TF:local_config_mlir +#include "mlir/Support/LLVM.h" // TF:local_config_mlir +#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" +#include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h" +#include "tensorflow/compiler/mlir/lite/utils/validators.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/core/util/matmul_bcast.h" + +namespace mlir { +namespace TFL { + +namespace { +// Unrolls a BatchMatMul on the batch dimension. We need to slice each batch out +// of the inputs, matmul them individually, then stack them all back together at +// the end. 
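+// For example, mirroring the batchMatMulTwoDim test above: multiplying a
+// [2, 3, 4, 5] tensor by a [2, 3, 5, 6] tensor reshapes the operands to
+// [6, 4, 5] and [6, 5, 6], slices out six [4, 5] and six [5, 6] matrices,
+// runs six tf.MatMul ops that each produce a [4, 6] result, packs them into a
+// [6, 4, 6] tensor, and reshapes that back to the output shape [2, 3, 4, 6].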
+struct UnrollBatchMatMulPass : public FunctionPass { + void runOnFunction() override; +}; + +void UnrollBatchMatMulPass::runOnFunction() { + OwningRewritePatternList patterns; + auto func = getFunction(); + + patterns.insert, + ConvertTFBatchMatMulOp>(&getContext()); + applyPatternsGreedily(func, patterns); +} + +} // namespace + +template +TF::ReshapeOp ConvertTFBatchMatMulOp::createReshapeOp( + Value* value, ArrayRef shape, Type elementType, Location loc, + PatternRewriter& rewriter) { + int64_t shape_rank = shape.size(); + auto shapeSpecType = + rewriter.getTensorType({shape_rank}, rewriter.getIntegerType(64)); + Type resultType = rewriter.getTensorType(shape, elementType); + auto constant_attr = DenseElementsAttr::get(shapeSpecType, shape); + auto shapeTensor = + rewriter.create(loc, shapeSpecType, constant_attr); + return rewriter.create(loc, resultType, /*tensor=*/value, + /*shape=*/shapeTensor); +} + +template +std::vector ConvertTFBatchMatMulOp::sliceInput( + Value* value, int batch_size, Location loc, PatternRewriter& rewriter) { + RankedTensorType tensorType = value->getType().cast(); + Type elementType = tensorType.getElementType(); + + int rank = tensorType.getShape().size(); + int num_rows = tensorType.getShape()[rank - 2]; + int num_cols = tensorType.getShape()[rank - 1]; + + // Reshape to rank-3 Tensor with first dimension as the batch size. + auto reshapeOp = createReshapeOp(value, {batch_size, num_rows, num_cols}, + elementType, loc, rewriter); + + SmallVector sliceSize = {1, num_rows, num_cols}; + + std::vector sliced; + Type int64Type = rewriter.getIntegerType(64); + Type sliceResultType = rewriter.getTensorType(sliceSize, elementType); + + // Slice along each batch index and remember the slice output for future + // use. + for (int batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + auto vector3Type = rewriter.getTensorType({3}, int64Type); + + auto begin_attr = + DenseElementsAttr::get(vector3Type, {batch_idx, 0, 0}); + auto size_attr = DenseElementsAttr::get(vector3Type, sliceSize); + auto begin = rewriter.create(loc, vector3Type, begin_attr); + auto size = rewriter.create(loc, vector3Type, size_attr); + auto sliceOp = + rewriter.create(loc, sliceResultType, + /*input=*/reshapeOp.output(), begin, size); + + // Squeeze matrix, i.e. 
reshape [1, num_rows, num_cols] -> [num_rows, + // num_cols] + auto squeezeOp = createReshapeOp(sliceOp.output(), {num_rows, num_cols}, + elementType, loc, rewriter); + + sliced.emplace_back(squeezeOp.output()); + } + return sliced; +} + +template +TF::TransposeOp ConvertTFBatchMatMulOp::createTransposeOp( + Value* value, Location loc, PatternRewriter& rewriter) { + auto valueType = value->getType().cast(); + auto shape = valueType.getShape(); + int dims = shape.size(); + + std::vector perm(dims); + for (int i = 0; i < dims - 2; i++) { + perm[i] = i; + } + perm[dims - 2] = dims - 1; + perm[dims - 1] = dims - 2; + + auto perm_type = rewriter.getTensorType({static_cast(perm.size())}, + rewriter.getIntegerType(32)); + + auto perm_attr = DenseElementsAttr::get(perm_type, llvm::makeArrayRef(perm)); + auto perm_op = rewriter.create(loc, perm_type, perm_attr); + + std::vector transposed_shape(shape.begin(), shape.end()); + int64_t r = transposed_shape[dims - 1]; + int64_t c = transposed_shape[dims - 2]; + + transposed_shape[dims - 1] = c; + transposed_shape[dims - 2] = r; + + auto transposed_type = + rewriter.getTensorType(transposed_shape, valueType.getElementType()); + return rewriter.create(loc, transposed_type, value, perm_op); +} + +template +TF::PackOp ConvertTFBatchMatMulOp::createMatMulOps( + const std::vector& sliced_lhs, + const std::vector& sliced_rhs, const tensorflow::MatMulBCast& bcast, + int rows, int cols, Type elementType, Location loc, + PatternRewriter& rewriter) { + auto matmulType = rewriter.getTensorType({rows, cols}, elementType); + + std::vector matmuls; + for (int batch_idx = 0; batch_idx < bcast.output_batch_size(); ++batch_idx) { + int lhs_batch_idx, rhs_batch_idx; + if (bcast.IsBroadcastingRequired()) { + lhs_batch_idx = bcast.x_batch_indices()[batch_idx]; + rhs_batch_idx = bcast.y_batch_indices()[batch_idx]; + } else { + lhs_batch_idx = batch_idx; + rhs_batch_idx = batch_idx; + } + auto false_attr = rewriter.getBoolAttr(false); + auto matmul = rewriter.create(loc, matmulType, + /*a=*/sliced_lhs[lhs_batch_idx], + /*b=*/sliced_rhs[rhs_batch_idx], + /*transpose_a=*/false_attr, + /*transpose_b=*/false_attr); + matmuls.emplace_back(matmul.product()); + } + + // Combine the result of each individual MatMul into a rank-3 Tensor. + Type packedType = rewriter.getTensorType( + {bcast.output_batch_size(), rows, cols}, elementType); + + auto N = rewriter.getI64IntegerAttr(matmuls.size()); + auto axis = rewriter.getI64IntegerAttr(0); + return rewriter.create(loc, packedType, + /*values=*/matmuls, N, axis); +} + +template +PatternMatchResult ConvertTFBatchMatMulOp::matchAndRewrite( + BatchMatMulOpType op, PatternRewriter& rewriter) const { + Value* input_lhs = op.x(); + Value* input_rhs = op.y(); + + if (!input_lhs->getType().isa()) { + // LHS must be a ranked tensor type + return this->matchFailure(); + } + if (!input_rhs->getType().isa()) { + // RHS must be a ranked tensor type + return this->matchFailure(); + } + + auto lhs_type = input_lhs->getType().cast(); + auto rhs_type = input_rhs->getType().cast(); + + auto elementType = lhs_type.getElementType(); + + if (elementType != rhs_type.getElementType()) { + // The element type of LHS must be the same with element type of RHS + return this->matchFailure(); + } + + auto lhs_shape = lhs_type.getShape(); + auto rhs_shape = rhs_type.getShape(); + + Location loc = op.getLoc(); + + // Transpose LHS input if necessary. 
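+  // The adj_x/adj_y attributes are lowered to explicit tf.Transpose ops
+  // (via createTransposeOp) that swap the two innermost dimensions before
+  // the inputs are sliced.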
+ if (op.adj_x()) { + input_lhs = createTransposeOp(input_lhs, loc, rewriter); + + lhs_type = input_lhs->getType().cast(); + lhs_shape = lhs_type.getShape(); + } + + // Transpose RHS input if necessary. + if (op.adj_y()) { + input_rhs = createTransposeOp(input_rhs, loc, rewriter); + + rhs_type = input_rhs->getType().cast(); + rhs_shape = rhs_type.getShape(); + } + + // Ensure that input ranks are at least 2 and batch shapes are + // broadcastable. + const int dims_a = lhs_shape.size(); + const int dims_b = rhs_shape.size(); + if (dims_a < 2 || dims_b < 2) { + // Both inputs must have rank >= 2 + return this->matchFailure(); + } + + if (lhs_shape[dims_a - 1] != rhs_shape[dims_b - 2]) { + // Input dimensions must be compatible for multipication. + return this->matchFailure(); + } + + if (dims_a == 2 && dims_b == 2) { + // When both inputs are matrices, just replace the op to a matmul op. + Type resultType = + rewriter.getTensorType({lhs_shape[0], rhs_shape[1]}, elementType); + auto false_attr = rewriter.getBoolAttr(false); + rewriter.replaceOpWithNewOp(op, resultType, + /*a=*/input_lhs, + /*b=*/input_rhs, + /*transpose_a=*/false_attr, + /*transpose_b=*/false_attr); + return this->matchSuccess(); + } + + tensorflow::MatMulBCast bcast(absl::InlinedVector( + lhs_shape.begin(), lhs_shape.end()), + absl::InlinedVector( + rhs_shape.begin(), rhs_shape.end())); + + if (!bcast.IsValid()) { + // Input batch dimensions must be broadcastable + return this->matchFailure(); + } + + // Compute slices for each batch in the LHS and RHS. + std::vector sliced_lhs = + sliceInput(input_lhs, bcast.x_batch_size(), loc, rewriter); + std::vector sliced_rhs = + sliceInput(input_rhs, bcast.y_batch_size(), loc, rewriter); + + // Compute (single batch) MatMul for each output batch. The MatMul outputs + // are then packed together into one output Tensor. + auto packOp = + createMatMulOps(sliced_lhs, sliced_rhs, bcast, lhs_shape[dims_a - 2], + rhs_shape[dims_b - 1], elementType, loc, rewriter); + + // Reshape the rank-3 Tensor into the correct output shape. + const auto& resultBatchShape = bcast.output_batch_shape().dim_sizes(); + std::vector resultShape(resultBatchShape.begin(), + resultBatchShape.end()); + resultShape.push_back(lhs_shape[dims_a - 2]); + resultShape.push_back(rhs_shape[dims_b - 1]); + + auto reshapeOp = + createReshapeOp(packOp.output(), resultShape, elementType, loc, rewriter); + rewriter.replaceOp(op, reshapeOp.output()); + return this->matchSuccess(); +} + +static PassRegistration pass( + "tfl-unroll-batch-matmul", + "Unroll TF BatchMatMul op into Reshape, Slice, MatMul, Pack ops."); + +} // namespace TFL +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h b/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h new file mode 100644 index 00000000000..d4b46eabf7d --- /dev/null +++ b/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h @@ -0,0 +1,60 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNROLL_BATCH_MATMUL_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNROLL_BATCH_MATMUL_H_ + +#include "llvm/ADT/ArrayRef.h" +#include "mlir/IR/Location.h" // TF:local_config_mlir +#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir +#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/core/util/matmul_bcast.h" + +namespace mlir { +namespace TFL { + +// Unroll tf.BatchMatMulV2 op into a sequence of TF ops. Since TFLite does not +// support BatchMatMul operation, it unrolls a BatchMatMul op into tf.Reshape, +// tf.Slice, tf.MatMul, tf.Pack, and tf.Reshape ops. +template +class ConvertTFBatchMatMulOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + static TF::ReshapeOp createReshapeOp(Value* value, ArrayRef shape, + Type elementType, Location loc, + PatternRewriter& rewriter); + + static std::vector sliceInput(Value* value, int batch_size, + Location loc, + PatternRewriter& rewriter); + + static TF::TransposeOp createTransposeOp(Value* value, Location loc, + PatternRewriter& rewriter); + + static TF::PackOp createMatMulOps(const std::vector& sliced_lhs, + const std::vector& sliced_rhs, + const tensorflow::MatMulBCast& bcast, + int rows, int cols, Type elementType, + Location loc, PatternRewriter& rewriter); + + PatternMatchResult matchAndRewrite(BatchMatMulOpType op, + PatternRewriter& rewriter) const override; +}; + +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNROLL_BATCH_MATMUL_H_ diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc new file mode 100644 index 00000000000..d98101bd4cb --- /dev/null +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc @@ -0,0 +1,456 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:local_config_mlir +#include "mlir/IR/Function.h" // TF:local_config_mlir +#include "mlir/IR/Identifier.h" // TF:local_config_mlir +#include "mlir/IR/Location.h" // TF:local_config_mlir +#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir +#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir +#include "mlir/IR/Operation.h" // TF:local_config_mlir +#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir +#include "mlir/IR/Types.h" // TF:local_config_mlir +#include "mlir/IR/Value.h" // TF:local_config_mlir +#include "mlir/Support/LLVM.h" // TF:local_config_mlir +#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace TFL { + +namespace { + +Value* CreateI32SplatConst(OpBuilder* builder, ArrayRef shape, + int32_t val, mlir::Location location) { + auto type = builder->getTensorType(shape, builder->getIntegerType(32)); + auto attr = DenseElementsAttr::get(type, val); + return builder->create(location, type, attr); +} + +Value* CreateF32SplatConst(OpBuilder* builder, ArrayRef shape, + float val, mlir::Location location) { + auto type = builder->getTensorType(shape, builder->getF32Type()); + auto attr = DenseElementsAttr::get(type, val); + return builder->create(location, type, attr); +} + +Value* CreateI64DenseConst(OpBuilder* builder, ArrayRef shape, + ArrayRef values, mlir::Location location) { + auto type = builder->getTensorType(static_cast(shape.size()), + builder->getIntegerType(64)); + auto attr = DenseElementsAttr::get(type, values); + return builder->create(location, type, attr); +} + +Value* CreateNoneValue(OpBuilder* builder, mlir::Location location) { + return builder->create(location, builder->getNoneType(), + builder->getUnitAttr()); +} + +Value* Transpose2D(OpBuilder* builder, Value* value_to_transpose, + RankedTensorType type, mlir::Location location) { + // Create a constant op for transpose permutation. + SmallVector perm = {1, 0}; + auto perm_op = CreateI64DenseConst(builder, perm, perm, location); + + // Create tensor type for the transpose result. 
+ auto transpose_type = type; + auto transpose_shape = functional::map( + [transpose_type](int64_t dim) { return transpose_type.getDimSize(dim); }, + perm); + auto elem_type = transpose_type.getElementType(); + auto result_type = builder->getTensorType(transpose_shape, elem_type); + + return builder->create(location, result_type, + value_to_transpose, perm_op); +} + +Value* SliceRankedTensor(OpBuilder* builder, Value* input, + ArrayRef begin_shape, + ArrayRef begin_values, + ArrayRef size_shape, + ArrayRef size_values, + mlir::Location location) { + // Create a dense constant op for slice's begin + auto slice_i2c_begin = + CreateI64DenseConst(builder, begin_shape, begin_values, location); + + // Create a dense constant op for slice's size + auto slice_i2c_size = + CreateI64DenseConst(builder, size_shape, size_values, location); + + return builder->create( + location, + builder->getTensorType( + size_values, + input->getType().cast().getElementType()), + input, slice_i2c_begin, slice_i2c_size); +} + +} // namespace + +void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToCellGate() { + SmallVector begin_i2c_values = {0, 0}; + input2cell_ = SliceRankedTensor( + &builder_, weight_transposed_, weight_slice_shape_, begin_i2c_values, + weight_slice_shape_, weight_slice_size_input_values_, + fused_func_op_.getLoc()); +} + +void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToInputGate() { + SmallVector begin_i2i_values = {n_cell_, 0}; + input2input_ = couple_input_forget_gates_ + ? none_ + : SliceRankedTensor(&builder_, weight_transposed_, + weight_slice_shape_, begin_i2i_values, + weight_slice_shape_, + weight_slice_size_input_values_, + fused_func_op_.getLoc()); +} + +void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToForgetGate() { + int input_forget_start = couple_input_forget_gates_ ? n_cell_ : 2 * n_cell_; + SmallVector begin_i2f_values = {input_forget_start, 0}; + input2forget_ = SliceRankedTensor( + &builder_, weight_transposed_, weight_slice_shape_, begin_i2f_values, + weight_slice_shape_, weight_slice_size_input_values_, + fused_func_op_.getLoc()); +} + +void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToOutputGate() { + int input_output_start = + couple_input_forget_gates_ ? 2 * n_cell_ : 3 * n_cell_; + SmallVector begin_i2o_values = {input_output_start, 0}; + input2output_ = SliceRankedTensor( + &builder_, weight_transposed_, weight_slice_shape_, begin_i2o_values, + weight_slice_shape_, weight_slice_size_input_values_, + fused_func_op_.getLoc()); +} + +void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToCellGate() { + SmallVector begin_rec2c_values = {0, n_input_}; + rec2cell_ = SliceRankedTensor( + &builder_, weight_transposed_, weight_slice_shape_, begin_rec2c_values, + weight_slice_shape_, weight_slice_size_recurrent_values_, + fused_func_op_.getLoc()); +} + +void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToInputGate() { + SmallVector begin_rec2i_values = {n_cell_, n_input_}; + rec2input_ = couple_input_forget_gates_ + ? none_ + : SliceRankedTensor(&builder_, weight_transposed_, + weight_slice_shape_, begin_rec2i_values, + weight_slice_shape_, + weight_slice_size_recurrent_values_, + fused_func_op_.getLoc()); +} + +void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToForgetGate() { + int rec_forget_start = couple_input_forget_gates_ ? 
n_cell_ : 2 * n_cell_; + SmallVector begin_rec2f_values = {rec_forget_start, n_input_}; + rec2forget_ = SliceRankedTensor( + &builder_, weight_transposed_, weight_slice_shape_, begin_rec2f_values, + weight_slice_shape_, weight_slice_size_recurrent_values_, + fused_func_op_.getLoc()); +} + +void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToOutputGate() { + int rec_output_start = couple_input_forget_gates_ ? 2 * n_cell_ : 3 * n_cell_; + SmallVector begin_rec2o_values = {rec_output_start, n_input_}; + rec2output_ = SliceRankedTensor( + &builder_, weight_transposed_, weight_slice_shape_, begin_rec2o_values, + weight_slice_shape_, weight_slice_size_recurrent_values_, + fused_func_op_.getLoc()); +} + +void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToCellGate() { + SmallVector begin_bias2c_values = {0}; + bias2cell_ = SliceRankedTensor(&builder_, bias_, bias_slice_shape_, + begin_bias2c_values, bias_slice_shape_, + bias_size_values_, fused_func_op_.getLoc()); +} + +void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToInputGate() { + SmallVector begin_bias2i_values = {n_cell_}; + bias2input_ = + couple_input_forget_gates_ + ? none_ + : SliceRankedTensor(&builder_, bias_, bias_slice_shape_, + begin_bias2i_values, bias_slice_shape_, + bias_size_values_, fused_func_op_.getLoc()); +} + +void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToForgetGate() { + int bias_forget_start = couple_input_forget_gates_ ? n_cell_ : 2 * n_cell_; + SmallVector begin_bias2f_values = {bias_forget_start}; + bias2forget_ = SliceRankedTensor(&builder_, bias_, bias_slice_shape_, + begin_bias2f_values, bias_slice_shape_, + bias_size_values_, fused_func_op_.getLoc()); +} + +void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToOutputGate() { + int bias_output_start = + couple_input_forget_gates_ ? 2 * n_cell_ : 3 * n_cell_; + SmallVector begin_bias2o_values = {bias_output_start}; + bias2output_ = SliceRankedTensor(&builder_, bias_, bias_slice_shape_, + begin_bias2o_values, bias_slice_shape_, + bias_size_values_, fused_func_op_.getLoc()); +} + +void ConvertLSTMCellSimpleToFusedLSTM::SetProjection() { + SmallVector projection_slice_shape = { + 1, num_cols_projection_transposed_}; + SmallVector projection_slice_size_values = {n_output_, n_cell_}; + SmallVector projection_slice_begin_values = {0, 0}; + proj_weight_ = + !projection_ + ? none_ + : SliceRankedTensor( + &builder_, projection_transposed_, projection_slice_shape, + projection_slice_begin_values, projection_slice_shape, + projection_slice_size_values, fused_func_op_.getLoc()); +} + +void ConvertLSTMCellSimpleToFusedLSTM::SetProjectionBias() { + proj_bias_ = !projection_type_ + ? 
none_ + : CreateF32SplatConst(&builder_, {n_output_}, 0, + fused_func_op_.getLoc()); +} + +void ConvertLSTMCellSimpleToFusedLSTM::SetInputActivationState() { + input_activation_state_ = CreateF32SplatConst(&builder_, {1, n_output_}, 0, + fused_func_op_.getLoc()); +} + +void ConvertLSTMCellSimpleToFusedLSTM::SetInputCellState() { + input_cell_state_ = + CreateF32SplatConst(&builder_, {1, n_cell_}, 0, fused_func_op_.getLoc()); +} + +void ConvertLSTMCellSimpleToFusedLSTM::SetCellLayerNormCoefficients() { + cell_layer_norm_coefficients_ = none_; +} + +void ConvertLSTMCellSimpleToFusedLSTM::SetInputLayerNormCoefficients() { + input_layer_norm_coefficients_ = none_; +} + +void ConvertLSTMCellSimpleToFusedLSTM::SetForgetLayerNormCoefficients() { + forget_layer_norm_coefficients_ = none_; +} +void ConvertLSTMCellSimpleToFusedLSTM::SetOutputLayerNormCoefficients() { + output_layer_norm_coefficients_ = none_; +} + +void ConvertLSTMCellSimpleToFusedLSTM::GenerateFusedOpOperands() { + // Transpose both weight and projection. + weight_transposed_ = + Transpose2D(&builder_, weight_, weight_type_, fused_func_op_.getLoc()); + projection_transposed_ = Transpose2D(&builder_, projection_, projection_type_, + fused_func_op_.getLoc()); + + none_ = CreateNoneValue(&builder_, fused_func_op_.getLoc()); + // Extract input to cifg gates via slicing the weight tensor + SetWeightForInputToCellGate(); + SetWeightForInputToInputGate(); + SetWeightForInputToForgetGate(); + SetWeightForInputToOutputGate(); + + // Extract recurrent to cifg gates via slicing the weight tensor + SetWeightForRecurrentToCellGate(); + SetWeightForRecurrentToInputGate(); + SetWeightForRecurrentToForgetGate(); + SetWeightForRecurrentToOutputGate(); + + // Extract bias to cifg gates via slicing the bias tensor + SetBiasToCellGate(); + SetBiasToInputGate(); + SetBiasToForgetGate(); + SetBiasToOutputGate(); + + // Extract projection and set an empty projection bias + SetProjection(); + SetProjectionBias(); + + // Set the variable tensors + SetInputActivationState(); + SetInputCellState(); + + // Extract the layer norm coefficients + SetCellLayerNormCoefficients(); + SetInputLayerNormCoefficients(); + SetForgetLayerNormCoefficients(); + SetOutputLayerNormCoefficients(); +} + +void ConvertLSTMCellSimpleToFusedLSTM::UpdateFuncSignature() { + // https://github.com/tensorflow/community/pull/113 + auto attr = fused_func_op_.getAttrOfType("tf_.implements"); + if (!attr) { + fused_func_op_.setAttr("tf._implements", + builder_.getStringAttr(GetCompositeOpName())); + } + SmallVector output_shape{1, n_output_}; + auto input_types = fused_func_op_.getType().getInputs(); + auto output_type = builder_.getTensorType( + output_shape, + input_->getType().cast().getElementType()); + fused_func_op_.setType(mlir::FunctionType::get(input_types, output_type, + fused_func_op_.getContext())); +} + +void ConvertLSTMCellSimpleToFusedLSTM::RewriteFunc() { + // Update the func signature, based on output shape. + // The func will ultimately return the output of the fused + // LSTM op. + UpdateFuncSignature(); + + // Transoform the weights, projection, bias and layer norm coefficients + // to generate operands for the TFL fused LSTM op. + GenerateFusedOpOperands(); + + // Create the fused LSTM op. 
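+  // The builder call below passes the 24 operands in the order expected by the
+  // TFL fused LSTM op: input, the four input-to-gate weights, the four
+  // recurrent-to-gate weights, the three peephole (cell-to-gate) weights, here
+  // all none_, the four gate biases, the projection weight and bias, the two
+  // state tensors, and the four layer-norm coefficient tensors. The trailing
+  // attributes select the TANH activation, a cell clip of 10.0, a disabled
+  // projection clip (0.0), and the FULL kernel type.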
+ SmallVector output_shape = {1, n_output_}; + auto result_type = builder_.getTensorType( + output_shape, + input_->getType().cast().getElementType()); + lstm_ = builder_.create( + fused_func_op_.getLoc(), result_type, input_, input2input_, input2forget_, + input2cell_, input2output_, rec2input_, rec2forget_, rec2cell_, + rec2output_, /*cell_to_input_weights*/ none_, + /*cell_to_forget_weights*/ none_, + /*cell_to_output_weights*/ none_, bias2input_, bias2forget_, bias2cell_, + bias2output_, proj_weight_, proj_bias_, input_activation_state_, + input_cell_state_, input_layer_norm_coefficients_, + forget_layer_norm_coefficients_, cell_layer_norm_coefficients_, + output_layer_norm_coefficients_, builder_.getStringAttr("TANH"), + builder_.getF32FloatAttr(10.0), builder_.getF32FloatAttr(0.0), + builder_.getStringAttr("FULL")); + + builder_.create(fused_func_op_.getLoc(), lstm_.getResult()); +} + +LogicalResult ConvertLSTMCellSimpleToFusedLSTM::Initialize() { + num_gates_ = couple_input_forget_gates_ ? 3 : 4; + + input_ = fused_func_op_.getArgument(0); + bias_ = fused_func_op_.getArgument(2); + + weight_ = fused_func_op_.getArgument(1); + weight_type_ = weight_->getType().cast(); + + if (weight_type_.getRank() != 2) { + return fused_func_op_.emitError() << "The weight tensor was not of rank 2"; + } + + if (weight_type_.getDimSize(1) % num_gates_ != 0) { + return fused_func_op_.emitError() + << "Invalid dimension 1 of weight tensor, " + "should be divisible by the number of gates"; + } + n_cell_ = weight_type_.getDimSize(1) / num_gates_; + + projection_ = fused_func_op_.getArgument(3); + projection_type_ = projection_->getType().cast(); + if (projection_type_.getRank() != 2) { + n_output_ = n_cell_; + } else { + n_output_ = projection_type_.getDimSize(1); + } + n_input_ = weight_type_.getDimSize(0) - n_output_; + num_cols_weight_transposed_ = weight_type_.getDimSize(0); + num_cols_projection_transposed_ = projection_type_.getDimSize(0); + + bias_slice_shape_ = {n_cell_}; + bias_size_values_ = {n_cell_}; + weight_slice_shape_ = {1, num_cols_weight_transposed_}; + weight_slice_size_input_values_ = {n_cell_, n_input_}; + weight_slice_size_recurrent_values_ = {n_cell_, n_output_}; + + return success(); +} + +LogicalResult ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::Initialize() { + if (failed(ConvertLSTMCellSimpleToFusedLSTM::Initialize())) { + return fused_func_op_.emitError() + << "Specified LayerNormalizedLSTMCellSimple was not of the expected " + "interface and cannot not be converted to the fused LSTM op"; + } + + layer_norm_scale_ = fused_func_op_.getArgument(4); + layer_norm_scale_type_ = + layer_norm_scale_->getType().cast(); + if (layer_norm_scale_type_.getRank() != 1) { + return fused_func_op_.emitError() + << "The layer_norm_scale tensor was not of rank 1"; + } + layer_norm_slice_shape_ = {n_cell_}; + layer_norm_size_values_ = {n_cell_}; + + return success(); +} + +void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM:: + SetCellLayerNormCoefficients() { + SmallVector begin_cell_layer_norm_values = {0}; + cell_layer_norm_coefficients_ = + SliceRankedTensor(&builder_, layer_norm_scale_, layer_norm_slice_shape_, + begin_cell_layer_norm_values, layer_norm_slice_shape_, + layer_norm_size_values_, fused_func_op_.getLoc()); +} + +void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM:: + SetInputLayerNormCoefficients() { + SmallVector begin_input_layer_norm_values = {n_cell_}; + input_layer_norm_coefficients_ = + couple_input_forget_gates_ + ? 
none_ + : SliceRankedTensor( + &builder_, layer_norm_scale_, layer_norm_slice_shape_, + begin_input_layer_norm_values, layer_norm_slice_shape_, + layer_norm_size_values_, fused_func_op_.getLoc()); +} + +void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM:: + SetForgetLayerNormCoefficients() { + SmallVector begin_forget_layer_norm_values = {2 * n_cell_}; + forget_layer_norm_coefficients_ = + SliceRankedTensor(&builder_, layer_norm_scale_, layer_norm_slice_shape_, + begin_forget_layer_norm_values, layer_norm_slice_shape_, + layer_norm_size_values_, fused_func_op_.getLoc()); +} + +void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM:: + SetOutputLayerNormCoefficients() { + SmallVector begin_output_layer_norm_values = {3 * n_cell_}; + output_layer_norm_coefficients_ = + SliceRankedTensor(&builder_, layer_norm_scale_, layer_norm_slice_shape_, + begin_output_layer_norm_values, layer_norm_slice_shape_, + layer_norm_size_values_, fused_func_op_.getLoc()); +} + +} // namespace TFL +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils.h b/tensorflow/compiler/mlir/lite/utils/lstm_utils.h new file mode 100644 index 00000000000..e59b2b662dd --- /dev/null +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils.h @@ -0,0 +1,214 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This header file defines common utils used by TFLite transformation +// passes to work with op attributes. + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_ + +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Builders.h" // TF:local_config_mlir +#include "mlir/IR/Function.h" // TF:local_config_mlir +#include "mlir/IR/Location.h" // TF:local_config_mlir +#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir +#include "mlir/IR/Value.h" // TF:local_config_mlir +#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" + +namespace mlir { +namespace TFL { + +constexpr char kLstmCellSimple[] = "LSTMCellSimple"; +constexpr char kLayerNormalizedLstmCellSimple[] = + "LayerNormalizedLstmCellSimple"; + +// A utility class that enables the conversion of the LSTMCellSimple composite +// op into a fused TFL LSTM op. The fused op is contained within a FuncOp +// that also contains other supporting ops needed to construct the operands for +// the fused op. The caller provides the containing FuncOp as input with +// arguments specifying the input, weight, projection and bias. +// The weight, pprojection, bias and layer norm scale all need to be +// RankedTensorType. +// This class sets the layer norm coefficients to NoneType. +class ConvertLSTMCellSimpleToFusedLSTM { + public: + // TODO(b/140053256): The couple_input_forget_gates should be specified on + // FuncOp as an attribute. 
+ explicit ConvertLSTMCellSimpleToFusedLSTM(mlir::FuncOp fused_func_op, + bool couple_input_forget_gates) + : fused_func_op_(fused_func_op), + couple_input_forget_gates_(couple_input_forget_gates), + builder_(fused_func_op.getBody()) {} + + // not copyable. + ConvertLSTMCellSimpleToFusedLSTM(const ConvertLSTMCellSimpleToFusedLSTM&) = + delete; + ConvertLSTMCellSimpleToFusedLSTM& operator=( + const ConvertLSTMCellSimpleToFusedLSTM&) = delete; + virtual ~ConvertLSTMCellSimpleToFusedLSTM() {} + + // verify input func op arguments and initialize internal state. + virtual LogicalResult Initialize(); + + virtual llvm::StringRef GetCompositeOpName() { return kLstmCellSimple; } + + // Rewrite the func body with constructed fused lstm. + void RewriteFunc(); + + protected: + void UpdateFuncSignature(); + void GenerateFusedOpOperands(); + + void SetWeightForInputToCellGate(); + void SetWeightForInputToInputGate(); + void SetWeightForInputToForgetGate(); + void SetWeightForInputToOutputGate(); + + void SetWeightForRecurrentToCellGate(); + void SetWeightForRecurrentToInputGate(); + void SetWeightForRecurrentToForgetGate(); + void SetWeightForRecurrentToOutputGate(); + + void SetBiasToCellGate(); + void SetBiasToInputGate(); + void SetBiasToForgetGate(); + void SetBiasToOutputGate(); + + void SetProjection(); + void SetProjectionBias(); + + void SetInputActivationState(); + void SetInputCellState(); + + virtual void SetCellLayerNormCoefficients(); + virtual void SetInputLayerNormCoefficients(); + virtual void SetForgetLayerNormCoefficients(); + virtual void SetOutputLayerNormCoefficients(); + + // specified state + FuncOp fused_func_op_; + Value* input_; + Value* weight_; + Value* bias_; + Value* projection_; + bool couple_input_forget_gates_; + + // internal state + Value* weight_transposed_; + Value* projection_transposed_; + RankedTensorType weight_type_; + RankedTensorType projection_type_; + int num_gates_; + int n_cell_; + int n_output_; + int n_input_; + int num_cols_weight_transposed_; + int num_cols_projection_transposed_; + + // input -> cifg + Value* input2input_; + Value* input2forget_; + Value* input2cell_; + Value* input2output_; + + // reccurrent -> cifg + Value* rec2input_; + Value* rec2forget_; + Value* rec2cell_; + Value* rec2output_; + + // bias -> cifg + Value* bias2input_; + Value* bias2forget_; + Value* bias2cell_; + Value* bias2output_; + + // projection + Value* proj_weight_; + Value* proj_bias_; + + // state + Value* input_activation_state_; + Value* input_cell_state_; + + // layer norm coefficients + Value* input_layer_norm_coefficients_; + Value* forget_layer_norm_coefficients_; + Value* cell_layer_norm_coefficients_; + Value* output_layer_norm_coefficients_; + + mlir::TFL::LSTMOp lstm_; + + Value* none_; + SmallVector bias_slice_shape_; + SmallVector bias_size_values_; + SmallVector weight_slice_shape_; + SmallVector weight_slice_size_input_values_; + SmallVector weight_slice_size_recurrent_values_; + OpBuilder builder_; +}; + +// A utility class that enables the conversion of the +// LayerNormalizedLSTMCellSimple composite op into a fused TFL LSTM op. The +// fused op is contained within a FuncOp that also contains other supporting ops +// needed to construct the operands for the fused op. The caller provides the +// containing FuncOp as input with arguments specifying the input, weight, +// projection, bias and layer norm scale. The weight, pprojection, bias and +// layer norm scale all need to be RankedTensorType. 
+// This class overrides the layer norm coefficient setters from the base class. +class ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM + : public ConvertLSTMCellSimpleToFusedLSTM { + public: + // TODO(b/140053256): The couple_input_forget_gates should be specified on + // FuncOp as an attribute. + explicit ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM( + mlir::FuncOp fused_func_op, bool couple_input_forget_gates) + : ConvertLSTMCellSimpleToFusedLSTM(fused_func_op, + couple_input_forget_gates) {} + + // not copyable. + ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM( + const ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM&) = delete; + ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM& operator=( + const ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM&) = delete; + ~ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM() override {} + + llvm::StringRef GetCompositeOpName() override { + return kLayerNormalizedLstmCellSimple; + } + + LogicalResult Initialize() override; + + protected: + void SetCellLayerNormCoefficients() override; + void SetInputLayerNormCoefficients() override; + void SetForgetLayerNormCoefficients() override; + void SetOutputLayerNormCoefficients() override; + + private: + // specified state + Value* layer_norm_scale_; + + // internal state + RankedTensorType layer_norm_scale_type_; + SmallVector layer_norm_slice_shape_; + SmallVector layer_norm_size_values_; +}; + +} // end namespace TFL +} // end namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_ diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc new file mode 100644 index 00000000000..56d6ab1f8ab --- /dev/null +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc @@ -0,0 +1,222 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"
+
+#include <memory>
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
+#include "mlir/IR/Builders.h" // TF:local_config_mlir
+#include "mlir/IR/Function.h" // TF:local_config_mlir
+#include "mlir/IR/Location.h" // TF:local_config_mlir
+#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
+#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
+#include "mlir/IR/Types.h" // TF:local_config_mlir
+#include "mlir/IR/Value.h" // TF:local_config_mlir
+#include "mlir/Support/LLVM.h" // TF:local_config_mlir
+#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
+#include "tensorflow/core/platform/test.h"
+
+namespace mlir {
+namespace TFL {
+
+FuncOp createFusedFunc(mlir::Builder* builder) {
+  SmallVector<int64_t, 2> input_shape{1, 2};
+  SmallVector<int64_t, 2> weight_shape{3, 12};
+  SmallVector<int64_t, 1> bias_shape{2};
+  SmallVector<int64_t, 2> projection_shape{1, 2};
+  SmallVector<int64_t, 1> layer_norm_scale{4};
+  SmallVector<int64_t, 2> output_shape{1, 2};
+  auto input_type = builder->getTensorType(input_shape, builder->getF32Type());
+  auto weight_type =
+      builder->getTensorType(weight_shape, builder->getF32Type());
+  auto bias_type = builder->getTensorType(bias_shape, builder->getF32Type());
+  auto projection_type =
+      builder->getTensorType(projection_shape, builder->getF32Type());
+  auto layer_norm_scale_type =
+      builder->getTensorType(layer_norm_scale, builder->getF32Type());
+  auto output_type =
+      builder->getTensorType(output_shape, builder->getF32Type());
+  SmallVector<mlir::Type, 4> input_types{input_type, weight_type, bias_type,
+                                         projection_type,
+                                         layer_norm_scale_type};
+  auto func_type = builder->getFunctionType(input_types, output_type);
+
+  auto func =
+      FuncOp::create(mlir::NameLoc::get(builder->getIdentifier("fused_func"),
+                                        builder->getContext()),
+                     "fused_func", func_type, {});
+  func.addEntryBlock();
+  return func;
+}
+
+// TODO(ashwinm): Revisit whether this test should be moved to a FileCheck-based
+// test pass once the pass that consumes lstm_utils to stack the layers exists.
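+// A minimal usage sketch of the converter (assuming a composite function laid
+// out like the one createFusedFunc builds above, with arguments input, weight,
+// bias, projection, layer_norm_scale):
+//
+//   mlir::TFL::ConvertLSTMCellSimpleToFusedLSTM convert(
+//       func, /*couple_input_forget_gates=*/false);
+//   if (succeeded(convert.Initialize())) convert.RewriteFunc();
+//
+// The tests below exercise exactly this sequence and then inspect the rewritten
+// function body.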
+class LstmUtilsTest : public ::testing::Test {
+ protected:
+  LstmUtilsTest() {}
+
+  void SetUp() override {
+    builder_ = std::unique_ptr<mlir::Builder>(new Builder(&context_));
+    fused_lstm_func_ = createFusedFunc(builder_.get());
+  }
+
+  void TearDown() override {
+    fused_lstm_func_.erase();
+    builder_.reset();
+  }
+  FuncOp fused_lstm_func_;
+  mlir::MLIRContext context_;
+  std::unique_ptr<mlir::Builder> builder_;
+};
+
+TEST_F(LstmUtilsTest, ConvertLSTMCellSimple) {
+  mlir::TFL::ConvertLSTMCellSimpleToFusedLSTM convert(fused_lstm_func_, false);
+
+  auto result = convert.Initialize();
+  EXPECT_FALSE(failed(result));
+
+  convert.RewriteFunc();
+  fused_lstm_func_.dump();
+
+  // Verify the rewritten function signature and the transposed weight.
+  EXPECT_EQ(
+      fused_lstm_func_.getAttrOfType<StringAttr>("tf._implements").getValue(),
+      convert.GetCompositeOpName());
+  EXPECT_EQ(fused_lstm_func_.getNumArguments(), 5);
+  EXPECT_EQ(fused_lstm_func_.getType().getNumResults(), 1);
+
+  auto transpose_op = fused_lstm_func_.getBody().front().begin();
+  transpose_op++;
+  EXPECT_EQ(transpose_op->getOperand(0)
+                ->getType()
+                .cast<RankedTensorType>()
+                .getDimSize(0),
+            3);
+  EXPECT_EQ(transpose_op->getOperand(0)
+                ->getType()
+                .cast<RankedTensorType>()
+                .getDimSize(1),
+            12);
+  EXPECT_EQ(
+      transpose_op->getResult(0)->getType().cast<RankedTensorType>().getDimSize(
+          0),
+      12);
+  EXPECT_EQ(
+      transpose_op->getResult(0)->getType().cast<RankedTensorType>().getDimSize(
+          1),
+      3);
+
+  auto return_op = fused_lstm_func_.getBody().back().rbegin();
+  EXPECT_EQ(return_op->getName().getStringRef(),
+            mlir::ReturnOp::getOperationName());
+  return_op++;
+  EXPECT_EQ(return_op->getName().getStringRef(),
+            mlir::TFL::LSTMOp::getOperationName());
+  EXPECT_EQ(return_op->getNumOperands(), 24);
+  EXPECT_EQ(return_op->getNumResults(), 1);
+  // cifg = false, so input2input is not None.
+  EXPECT_FALSE(return_op->getOperand(1)->getType().isa<NoneType>());
+  // input layer norm is None
+  EXPECT_TRUE(return_op->getOperand(20)->getType().isa<NoneType>());
+  // proj_bias is F32
+  EXPECT_TRUE(return_op->getOperand(17)
+                  ->getType()
+                  .cast<RankedTensorType>()
+                  .getElementType()
+                  .isF32());
+
+  EXPECT_EQ(fused_lstm_func_.getType().getNumResults(), 1);
+  auto output_types = fused_lstm_func_.getType().getResults();
+  SmallVector<int64_t, 2> output_shape{1, 2};
+  EXPECT_EQ(output_types[0].cast<RankedTensorType>().getShape().size(),
+            output_shape.size());
+  for (int i = 0; i < output_shape.size(); i++) {
+    EXPECT_EQ(output_types[0].cast<RankedTensorType>().getDimSize(i),
+              output_shape[i]);
+  }
+}
+
+TEST_F(LstmUtilsTest, ConvertLSTMCellSimpleToFusedLSTMCoupleInputForget) {
+  mlir::TFL::ConvertLSTMCellSimpleToFusedLSTM convert(fused_lstm_func_, true);
+
+  auto result = convert.Initialize();
+  EXPECT_FALSE(failed(result));
+
+  convert.RewriteFunc();
+  fused_lstm_func_.dump();
+
+  auto it = fused_lstm_func_.getBody().back().rbegin();
+  EXPECT_EQ(it->getName().getStringRef(), mlir::ReturnOp::getOperationName());
+  it++;
+  EXPECT_EQ(it->getName().getStringRef(),
+            mlir::TFL::LSTMOp::getOperationName());
+  EXPECT_EQ(it->getNumOperands(), 24);
+  EXPECT_EQ(it->getNumResults(), 1);
+  // cifg = true, so input2input is None.
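+  // With couple_input_forget_gates = true, Initialize() uses three gates instead
+  // of four, so n_cell_ becomes 12 / 3 = 4 and the input-gate weight, recurrent,
+  // and bias slices are all replaced by the shared none_ value.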
+ EXPECT_TRUE(it->getOperand(1)->getType().isa()); +} + +TEST_F(LstmUtilsTest, ConvertLayerNormLSTMCellSimpleToFusedLSTM) { + mlir::TFL::ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM convert( + fused_lstm_func_, false); + + auto result = convert.Initialize(); + EXPECT_FALSE(failed(result)); + + convert.RewriteFunc(); + fused_lstm_func_.dump(); + + EXPECT_EQ( + fused_lstm_func_.getAttrOfType("tf._implements").getValue(), + convert.GetCompositeOpName()); + EXPECT_EQ(fused_lstm_func_.getNumArguments(), 5); + EXPECT_EQ(fused_lstm_func_.getType().getNumResults(), 1); + + auto it = fused_lstm_func_.getBody().back().rbegin(); + EXPECT_EQ(it->getName().getStringRef(), mlir::ReturnOp::getOperationName()); + it++; + EXPECT_EQ(it->getName().getStringRef(), + mlir::TFL::LSTMOp::getOperationName()); + EXPECT_EQ(it->getNumOperands(), 24); + EXPECT_EQ(it->getNumResults(), 1); + // cifg = false, so input2input is not None. + EXPECT_FALSE(it->getOperand(1)->getType().isa()); + + // input layer norm + EXPECT_FALSE(it->getOperand(20)->getType().isa()); + EXPECT_EQ( + it->getOperand(20)->getType().cast().getShape().size(), + 1); + EXPECT_EQ( + it->getOperand(20)->getType().cast().getDimSize(0), 3); + + EXPECT_EQ(fused_lstm_func_.getType().getNumResults(), 1); + auto output_types = fused_lstm_func_.getType().getResults(); + SmallVector output_shape{1, 2}; + EXPECT_EQ(output_types[0].cast().getShape().size(), + output_shape.size()); + for (int i = 0; i < output_shape.size(); i++) { + EXPECT_EQ(output_types[0].cast().getDimSize(i), + output_shape[i]); + } +} + +} // namespace TFL +} // namespace mlir diff --git a/tensorflow/compiler/mlir/python/BUILD b/tensorflow/compiler/mlir/python/BUILD new file mode 100644 index 00000000000..5291cf3b141 --- /dev/null +++ b/tensorflow/compiler/mlir/python/BUILD @@ -0,0 +1,11 @@ +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +exports_files( + ["mlir.i"], + visibility = [ + "//tensorflow/python:__subpackages__", + ], +) diff --git a/tensorflow/compiler/mlir/python/mlir.i b/tensorflow/compiler/mlir/python/mlir.i new file mode 100644 index 00000000000..03273357b2b --- /dev/null +++ b/tensorflow/compiler/mlir/python/mlir.i @@ -0,0 +1,74 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +%include "tensorflow/python/platform/base.i" + +%{ + +#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h" + +namespace tensorflow { +namespace swig { + +// Simple wrapper to support tf.mlir.experimental.convert_graph_def. +// Load a .pbptx, convert to MLIR, and (optionally) optimize the module before +// returning it as a string. +// This is an early experimental API, ideally we should return a wrapper object +// around a Python binding to the MLIR module. 
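+// On failure, the TF_Status is populated from the underlying Status and the
+// literal string "// error" is returned in place of the MLIR module.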
+string ImportGraphDef(const string &proto, TF_Status* status) { + GraphDef graphdef; + auto s = tensorflow::LoadProtoFromBuffer(proto, &graphdef); + if (!s.ok()) { + Set_TF_Status_from_Status(status, s); + return "// error"; + } + GraphDebugInfo debug_info; + NodeSpecs specs; + mlir::MLIRContext context; + auto module = ConvertGraphdefToMlir(graphdef, debug_info, specs, &context); + if (!module.ok()) { + Set_TF_Status_from_Status(status, module.status()); + return "// error"; + } + + return MlirModuleToString(*module.ConsumeValueOrDie()); +} + +} // namespace swig +} // namespace tensorflow + +%} + +%ignoreall + +%unignore tensorflow; +%unignore tensorflow::swig; +%unignore tensorflow::swig::ImportGraphDef; + +// Wrap this function +namespace tensorflow { +namespace swig { +static string ImportGraphDef(const string &graphdef, TF_Status* status); +} // namespace swig +} // namespace tensorflow + +%insert("python") %{ +def import_graphdef(graphdef): + return str(ImportGraphDef(str(graphdef).encode('utf-8'))); +%} + +%unignoreall diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 4b64dfcb9dd..b54aef1e42a 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -1,5 +1,5 @@ load("@local_config_mlir//:tblgen.bzl", "gentbl") -load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_native_cc_binary") +load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_gen_op_wrapper_py", "tf_native_cc_binary") package( default_visibility = [":friends"], @@ -123,21 +123,6 @@ cc_library( "ir/tf_ops.cc.inc", "ir/tf_ops.h.inc", "ir/tf_types.cc", - "transforms/bridge.cc", - "transforms/bridge_pass.cc", - "transforms/cluster_formation.cc", - "transforms/cluster_outlining.cc", - "transforms/executor_island_coarsening.cc", - "transforms/functional_control_flow_to_cfg.cc", - "transforms/generated_canonicalize.inc", - "transforms/generated_optimize.inc", - "transforms/graph_pruning.cc", - "transforms/optimize.cc", - "transforms/raise_control_flow.cc", - "transforms/tpu_cluster_formation.cc", - "transforms/tpu_rewrite_pass.cc", - "translate/control_to_executor_dialect.cc", - "translate/executor_to_control_dialect.cc", ], hdrs = [ "ir/control_flow_ops.h", @@ -153,11 +138,11 @@ cc_library( includes = ["include"], deps = [ ":error_util", + ":mlir_passthrough_op", ":tensorflow_canonicalize_inc_gen", ":tensorflow_device_ops_inc_gen", ":tensorflow_executor_inc_gen", ":tensorflow_ops_inc_gen", - ":tensorflow_optimize_inc_gen", "//tensorflow/compiler/mlir/lite:validators", "//tensorflow/core:lib", "@llvm//:support", @@ -175,12 +160,74 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "tensorflow_passes", + srcs = [ + "transforms/bridge.cc", + "transforms/bridge_pass.cc", + "transforms/cluster_formation.cc", + "transforms/cluster_outlining.cc", + "transforms/executor_island_coarsening.cc", + "transforms/fold_switch.cc", + "transforms/functional_control_flow_to_cfg.cc", + "transforms/generated_canonicalize.inc", + "transforms/generated_optimize.inc", + "transforms/graph_pruning.cc", + "transforms/materialize_mlir_passthrough_op.cc", + "transforms/optimize.cc", + "transforms/raise_control_flow.cc", + "transforms/sink_constant.cc", + "transforms/tpu_cluster_formation.cc", + "transforms/tpu_rewrite_pass.cc", + "translate/control_to_executor_dialect.cc", + "translate/executor_to_control_dialect.cc", + ], + hdrs = [ + "transforms/bridge.h", + "transforms/passes.h", + ], + includes = ["include"], + deps = [ + ":error_util", + 
":tensorflow", + ":tensorflow_optimize_inc_gen", + "//tensorflow/compiler/mlir/lite:validators", + "//tensorflow/core:lib", + "//tensorflow/core/platform:logging", + "@llvm//:support", + "@local_config_mlir//:Analysis", + "@local_config_mlir//:IR", + "@local_config_mlir//:Parser", + "@local_config_mlir//:Pass", + "@local_config_mlir//:StandardOps", + "@local_config_mlir//:Support", + "@local_config_mlir//:TransformUtils", + "@local_config_mlir//:Transforms", + ], + # TODO(jpienaar): Merge in the dialect registration. + alwayslink = 1, +) + +cc_library( + name = "tensorflow_test_passes", + srcs = [ + "transforms/lower_tf_pass.cc", + ], + deps = [ + ":lower_tf_lib", + "@local_config_mlir//:IR", + "@local_config_mlir//:Pass", + ], + alwayslink = 1, +) + # Library with TensorFlow dialect static initialization. cc_library( name = "tensorflow_dialect_registration", srcs = ["ir/dialect_registration.cc"], deps = [ ":tensorflow", + ":tensorflow_passes", "@local_config_mlir//:IR", ], alwayslink = 1, @@ -204,6 +251,7 @@ cc_library( ":mangling_util", ":mlir_roundtrip_flags", ":tensorflow", + ":tensorflow_passes", "//tensorflow/cc/saved_model:loader_lite", "//tensorflow/compiler/jit:shape_inference_helpers", "//tensorflow/compiler/xla:status_macros", @@ -252,6 +300,8 @@ cc_library( "utils/export_utils.cc", ], hdrs = [ + "ir/tf_types.def", + "ir/tf_types.h", "utils/export_utils.h", ], deps = [ @@ -639,26 +689,73 @@ gentbl( ) cc_library( - name = "tensorflow_fold_switch", - srcs = [ - "transforms/fold_switch.cc", - ], - hdrs = [ - "transforms/passes.h", - ], - copts = ["-std=c++14"], + name = "compile_mlir_util", + srcs = ["utils/compile_mlir_util.cc"], + hdrs = ["utils/compile_mlir_util.h"], deps = [ - ":tensorflow", + ":convert_type", + ":error_util", + "//tensorflow/compiler/mlir/xla:hlo", + "//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo", + "//tensorflow/compiler/mlir/xla:type_to_shape", + "//tensorflow/compiler/mlir/xla:xla_legalize_tf", + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/core:framework", - "//tensorflow/core:lib", - "@com_google_absl//absl/memory", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/types:span", "@llvm//:support", - "@local_config_mlir//:Analysis", "@local_config_mlir//:IR", - "@local_config_mlir//:Pass", - "@local_config_mlir//:QuantOps", - "@local_config_mlir//:StandardOps", - "@local_config_mlir//:Support", + "@local_config_mlir//:Parser", + ], +) + +tf_cc_test( + name = "compile_mlir_util_test", + size = "small", + srcs = ["utils/compile_mlir_util_test.cc"], + deps = [ + ":compile_mlir_util", + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:test_main", + "//tensorflow/stream_executor/lib", + ], +) + +cc_library( + name = "mlir_passthrough_op", + srcs = ["ops/mlir_passthrough_op.cc"], + deps = [ + "//tensorflow/core:framework", + ], + alwayslink = 1, +) + +tf_gen_op_wrapper_py( + name = "gen_mlir_passthrough_op_py", + out = "gen_mlir_passthrough_op.py", + deps = [":mlir_passthrough_op"], +) + +# Library to get rewrite patterns lowering within TensorFlow. +# +# This is a separate library so that external passes can link only this library +# without linking any of the other tensorflow passes. 
+cc_library( + name = "lower_tf_lib", + srcs = [ + "transforms/lower_tf.cc", + ], + hdrs = [ + "transforms/lower_tf.h", + ], + deps = [ + ":tensorflow", + "@local_config_mlir//:IR", ], alwayslink = 1, ) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc index 77d412f02c9..72799e19a0d 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc @@ -40,6 +40,7 @@ limitations under the License. #include "mlir/IR/Types.h" // TF:local_config_mlir #include "mlir/IR/Value.h" // TF:local_config_mlir #include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "mlir/Transforms/FoldUtils.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" namespace mlir { @@ -48,31 +49,65 @@ namespace { // If the given tensor has elements of type variant, then returns a new type // after dropping subtypes info. Otherwise, returns the original type as is. -Type DropVariantSubTypes(Type ty) { - ShapedType shaped_ty = ty.cast(); - Type element_ty = shaped_ty.getElementType(); +ShapedType DropVariantSubTypes(ShapedType ty) { + Type element_ty = ty.getElementType(); if (!element_ty.isa()) return ty; Type variant_ty = TF::VariantType::get(ty.getContext()); - if (shaped_ty.hasRank()) { - return RankedTensorType::get(shaped_ty.getShape(), variant_ty); + if (ty.hasRank()) { + return RankedTensorType::get(ty.getShape(), variant_ty); } return UnrankedTensorType::get(variant_ty); } +// If the given tensor has elements of type ref, then returns a new type +// of the shape, but corresponding non-ref type as element type. Otherwise, +// returns the original type as is. +ShapedType DropRefType(ShapedType type) { + Type element_ty = type.getElementType(); + TF::TensorFlowRefType ref_type = element_ty.dyn_cast(); + if (!ref_type) return type; + + if (type.hasRank()) { + return RankedTensorType::get(type.getShape(), ref_type.RemoveRef()); + } + return UnrankedTensorType::get(ref_type.RemoveRef()); +} + } // namespace //===----------------------------------------------------------------------===// // TF Executor Dialect //===----------------------------------------------------------------------===// +namespace { + +struct TensorFlowExecutorOpFolderDialectInterface + : public OpFolderDialectInterface { + using OpFolderDialectInterface::OpFolderDialectInterface; + + // Registered hook to check if the given region, which is attached to an + // operation that is *not* isolated from above (i.e. no internal regions + // reference values defined in an enclosing region), should be used when + // materializing constants. + // In the executor dialect we materialize inside an island. + bool shouldMaterializeInto(Region *region) const final { + return isa(region->getParentOp()); + } +}; + +} // namespace + TensorFlowExecutorDialect::TensorFlowExecutorDialect(MLIRContext *context) : Dialect(/*name=*/"tf_executor", context) { addOperations< #define GET_OP_LIST #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc.inc" >(); + + addInterfaces(); + addTypes(); } @@ -296,6 +331,23 @@ void Print(IslandOp op, OpAsmPrinter *p) { p->printOperands(op.getOperands()); *p << ')'; } + + // Check if we can print the short "wraps" form: that is if the island + // contains a single operation and the result of this operation are perfectly + // forwarded to the yield. 
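+  // For example, an island wrapping a single tf.Identity whose result feeds the
+  // yield would print roughly as (exact SSA names depend on the printer):
+  //   %outputs, %control = tf_executor.island wraps "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>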
+ if (op.getAttrs().empty() && + std::next(op.GetBody().begin(), 2) == op.GetBody().end()) { + Operation &wrapped_op = op.GetBody().front(); + Operation &yield_op = op.GetBody().back(); + if (wrapped_op.getNumResults() == yield_op.getNumOperands() && + std::equal(wrapped_op.getResults().begin(), + wrapped_op.getResults().end(), + yield_op.getOperands().begin())) { + *p << " wraps "; + p->printGenericOp(&op.GetBody().front()); + return; + } + } p->printRegion(op.getOperation()->getRegion(0)); p->printOptionalAttrDict(op.getAttrs()); } @@ -316,17 +368,22 @@ ParseResult ParseIslandOp(OpAsmParser *parser, OperationState *result) { // Parse the body region. Region &body = *result->addRegion(); - // TODO(b/134773778): the custom parser is missing support to implement to - // short syntax right now. - // if (!parser->parseOptionalKeyword("wraps")) { - // body.push_back(new Block); - // Block &block = body.back(); - // parser->getBuilder().setInsertionPointToEnd(&block); - // if (parser->parseOperation()) - // return failure(); - // } - - if (parser->parseRegion(body, llvm::None, llvm::None)) return failure(); + if (succeeded(parser->parseOptionalKeyword("wraps"))) { + // If we parse the short version of the island, we have an operation in the + // generic form that follows the "wraps" keyword. Parse it inside the region + // and forward all of its results as-is to the yield operation. + body.push_back(new Block); + Block &block = body.back(); + Operation *wrapped_op = + parser->parseGenericOperation(&block, block.begin()); + if (!wrapped_op) return failure(); + OpBuilder builder(parser->getBuilder().getContext()); + builder.setInsertionPointToEnd(&block); + builder.create(result->location, + llvm::to_vector<8>(wrapped_op->getResults())); + } else if (parser->parseRegion(body, llvm::None, llvm::None)) { + return failure(); + } IslandOp::ensureTerminator(body, parser->getBuilder(), result->location); @@ -536,35 +593,43 @@ LogicalResult Verify(MergeOp merge) { if (data_type.isa()) return merge.emitOpError() << "expects a non-control input"; - // Check that all operands can be broadcasted to a common type compatible with - // the result type. - Type broadcasted_type = merge.output()->getType(); + // Check that each operand can be individually broadcasted to the output type. + Type output_type = merge.output()->getType(); + TensorType output_tensor_ty = output_type.dyn_cast(); + if (!output_tensor_ty) { + return merge.emitOpError() + << "expects output to have tensor type but got " << output_type; + } + bool is_output_ref = + output_tensor_ty.getElementType().isa(); for (Type operand_type : merge.getOperandTypes()) { if (operand_type.isa()) break; // TODO(hinsu): Update ControlOperandsAfterAllData trait to verify this // constraint. - if (!operand_type.isa()) - return merge.emitOpError("expects data operands to have tensor type"); - - // Variant types may have opaque subtypes information that need not match - // between the two types so drop them before computing the broadcasted type. - Type new_broadcasted_type = - OpTrait::util::getBroadcastedType(DropVariantSubTypes(broadcasted_type), - DropVariantSubTypes(operand_type)); - if (!new_broadcasted_type) + TensorType operand_tensor_ty = operand_type.dyn_cast(); + if (!operand_tensor_ty) return merge.emitOpError() - << "expects all operands to be broadcastable" - << " but got " << broadcasted_type << " vs " << operand_type; - // Use the broadcasted type unless we're losing the rank information here. 
- // This is because for example starting with a result of tensor<4xf32>, if - // the first operand is unranked, the broadcasted type will be unranked. - // Then any tensor operand will be broadcastable to this unranked type. - if (!broadcasted_type.cast().hasRank() || - new_broadcasted_type.cast().hasRank()) - broadcasted_type = new_broadcasted_type; - } + << "expects data operands to have tensor type but got " + << operand_type; + // If output type is a ref type then all operand types should also be of the + // same ref type. However, if the output type is a non-ref type T, operands + // can be tensor of type T or T_REF. + if (is_output_ref && + !operand_tensor_ty.getElementType().isa()) { + return merge.emitOpError() + << "expects same operand and output element type but got " + << operand_tensor_ty << " vs " << output_tensor_ty; + } + Type broadcasted_type = OpTrait::util::getBroadcastedType( + DropRefType(DropVariantSubTypes(output_tensor_ty)), + DropRefType(DropVariantSubTypes(operand_tensor_ty))); + if (!broadcasted_type) + return merge.emitOpError() + << "expects all operands to be broadcastable with output type" + << " but got " << operand_tensor_ty << " vs " << output_tensor_ty; + } return success(); } @@ -1088,6 +1153,35 @@ void IslandOp::getCanonicalizationPatterns(OwningRewritePatternList &results, DropEmptyIslandNoOperandOneDataResult>(context); } +//===----------------------------------------------------------------------===// +// tf_executor.ControlTrigger +//===----------------------------------------------------------------------===// + +namespace { +// This pattern matches and removes ControlTriggerOps with no control operands. +// Control result users will have their relevant operands removed. +struct DropEmptyControlTrigger : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(ControlTriggerOp op, + PatternRewriter &rewriter) const override { + if (op.getNumOperands() != 0) return matchFailure(); + + for (auto &use : llvm::make_early_inc_range(op.control()->getUses())) + use.getOwner()->eraseOperand(use.getOperandNumber()); + + rewriter.replaceOp(op, {nullptr}); + + return matchSuccess(); + } +}; +} // anonymous namespace + +void ControlTriggerOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // Folders //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td index e8843c7d64f..eb3b9797192 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td @@ -594,6 +594,8 @@ def TfExecutor_ControlTriggerOp : TfExecutor_Op<"ControlTrigger", let verifier = ?; + let hasCanonicalizer = 1; + let builders = [OpBuilder< "Builder *builder, OperationState *result, " "ArrayRef operands, ArrayRef attributes = {}", diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 153ac5346b9..f01ff57c41d 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -88,13 +88,13 @@ Inputs must be of same size and shape. 
}]; let arguments = (ins - Variadic>:$inputs, + Variadic>:$inputs, Confined]>:$N ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Variant]>:$sum + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8, TF_Variant]>:$sum ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -110,12 +110,12 @@ def TF_AddV2Op : TF_Op<"AddV2", [Broadcastable, Commutative, NoSideEffect]>, }]; let arguments = (ins - TF_NumberTensor:$x, - TF_NumberTensor:$y + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint8]>:$x, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint8]>:$y ); let results = (outs - TF_NumberTensor:$z + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint8]>:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -123,6 +123,32 @@ def TF_AddV2Op : TF_Op<"AddV2", [Broadcastable, Commutative, NoSideEffect]>, let hasCanonicalizer = 1; } +def TF_AllOp : TF_Op<"All", [NoSideEffect]> { + let summary = [{ +Computes the "logical and" of elements across dimensions of a tensor. + }]; + + let description = [{ +Reduces `input` along the dimensions given in `axis`. Unless +`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +`axis`. If `keep_dims` is true, the reduced dimensions are +retained with length 1. + }]; + + let arguments = (ins + I1Tensor:$input, + TF_I32OrI64Tensor:$reduction_indices, + + DefaultValuedAttr:$keep_dims + ); + + let results = (outs + I1Tensor:$output + ); + + TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; +} + def TF_AnyOp : TF_Op<"Any", [NoSideEffect]> { let summary = [{ Computes the "logical or" of elements across dimensions of a tensor. @@ -169,7 +195,7 @@ Usage: }]; let arguments = (ins - TF_NumberTensor:$input, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, TF_I32OrI64Tensor:$dimension ); @@ -202,7 +228,7 @@ Usage: }]; let arguments = (ins - TF_NumberTensor:$input, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, TF_I32OrI64Tensor:$dimension ); @@ -261,6 +287,88 @@ window in `value`. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_BatchMatMulOp : TF_Op<"BatchMatMul", [NoSideEffect]> { + let summary = "Multiplies slices of two tensors in batches."; + + let description = [{ +Multiplies all slices of `Tensor` `x` and `y` (each slice can be +viewed as an element of a batch), and arranges the individual results +in a single output tensor of the same batch size. Each of the +individual slices can optionally be adjointed (to adjoint a matrix +means to transpose and conjugate it) before multiplication by setting +the `adj_x` or `adj_y` flag to `True`, which are by default `False`. + +The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` +and `[..., r_y, c_y]`. 
+ +The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where: + + r_o = c_x if adj_x else r_x + c_o = r_y if adj_y else c_y + +It is computed as: + + output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x, + TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y, + + DefaultValuedAttr:$adj_x, + DefaultValuedAttr:$adj_y + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_BatchMatMulV2Op : TF_Op<"BatchMatMulV2", [NoSideEffect]> { + let summary = "Multiplies slices of two tensors in batches."; + + let description = [{ +Multiplies all slices of `Tensor` `x` and `y` (each slice can be +viewed as an element of a batch), and arranges the individual results +in a single output tensor of the same batch size. Each of the +individual slices can optionally be adjointed (to adjoint a matrix +means to transpose and conjugate it) before multiplication by setting +the `adj_x` or `adj_y` flag to `True`, which are by default `False`. + +The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` +and `[..., r_y, c_y]`. + +The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where: + + r_o = c_x if adj_x else r_x + c_o = r_y if adj_y else c_y + +It is computed as: + + output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) + +*NOTE*: `BatchMatMulV2` supports broadcasting in the batch dimensions. More +about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html). + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x, + TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y, + + DefaultValuedAttr:$adj_x, + DefaultValuedAttr:$adj_y + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_BatchToSpaceNDOp : TF_Op<"BatchToSpaceND", [NoSideEffect]> { let summary = "BatchToSpace for N-D tensors of type T."; @@ -297,14 +405,14 @@ Broadcasting is supported, so `value` may have any number of dimensions. }]; let arguments = (ins - TF_NumberTensor:$value, - TF_NumberTensor:$bias, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$value, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$bias, DefaultValuedAttr:$data_format ); let results = (outs - TF_NumberTensor:$output + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -332,21 +440,23 @@ gives module error. For example, Example 1: -```python + >>> a = [1., 2., 3.] ->>> equality_bitcast = tf.bitcast(a,tf.complex128) -tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot bitcast from float to complex128: shape [3] [Op:Bitcast] ->>> equality_cast = tf.cast(a,tf.complex128) +>>> equality_bitcast = tf.bitcast(a, tf.complex128) +Traceback (most recent call last): +... 
+InvalidArgumentError: Cannot bitcast from 1 to 18 [Op:Bitcast] +>>> equality_cast = tf.cast(a, tf.complex128) >>> print(equality_cast) tf.Tensor([1.+0.j 2.+0.j 3.+0.j], shape=(3,), dtype=complex128) -``` + Example 2: -```python + >>> tf.bitcast(tf.constant(0xffffffff, dtype=tf.uint32), tf.uint8) -``` + Example 3: -```python + >>> x = [1., 2., 3.] >>> y = [0., 2., 3.] >>> equality= tf.equal(x,y) @@ -358,10 +468,9 @@ tf.Tensor([False True True], shape=(3,), dtype=bool) tf.Tensor([0. 1. 1.], shape=(3,), dtype=float32) >>> print(equality_bitcast) tf.Tensor( -[[ 0 0 0 0] - [ 0 0 128 63] - [ 0 0 128 63]], shape=(3, 4), dtype=uint8) -``` + [[ 0 0 0 0] + [ 0 0 128 63] + [ 0 0 128 63]], shape=(3, 4), dtype=uint8) *NOTE*: Bitcast is implemented as a low-level cast, so machines with different endian orderings will give different results. @@ -393,14 +502,13 @@ and works its way forward. For example, -```python >>> x = tf.constant([1, 2, 3]) >>> y = tf.broadcast_to(x, [3, 3]) ->>> sess.run(y) -array([[1, 2, 3], - [1, 2, 3], - [1, 2, 3]], dtype=int32) -``` +>>> print(y) +tf.Tensor( + [[1 2 3] + [1 2 3] + [1 2 3]], shape=(3, 3), dtype=int32) In the above example, the input Tensor with the shape of `[1, 3]` is broadcasted to output Tensor with shape of `[3, 3]`. @@ -462,6 +570,27 @@ def TF_CeilOp : TF_Op<"Ceil", [NoSideEffect, SameOperandsAndResultType]> { TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_CheckNumericsOp : TF_Op<"CheckNumerics", [SameOperandsAndResultType]> { + let summary = "Checks a tensor for NaN and Inf values."; + + let description = [{ +When run, reports an `InvalidArgument` error if `tensor` has any values +that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is. + }]; + + let arguments = (ins + TF_FpTensor:$tensor, + + StrAttr:$message + ); + + let results = (outs + TF_FpTensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_ConcatOp : TF_Op<"Concat", [NoSideEffect]> { let summary = "Concatenates tensors along one dimension."; @@ -480,6 +609,10 @@ def TF_ConcatOp : TF_Op<"Concat", [NoSideEffect]> { ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; + + let verifier = [{ + return Verify(*this); + }]; } def TF_ConcatV2Op : TF_Op<"ConcatV2", [NoSideEffect]> { @@ -501,6 +634,10 @@ def TF_ConcatV2Op : TF_Op<"ConcatV2", [NoSideEffect]> { TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; + + let verifier = [{ + return Verify(*this); + }]; } def TF_ConjOp : TF_Op<"Conj", [NoSideEffect]> { @@ -771,12 +908,12 @@ def TF_DivOp : TF_Op<"Div", [Broadcastable, NoSideEffect]>, }]; let arguments = (ins - TF_NumberTensor:$x, - TF_NumberTensor:$y + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$x, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$y ); let results = (outs - TF_NumberTensor:$z + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -805,8 +942,7 @@ See [Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs) TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_EqualOp : TF_Op<"Equal", [Broadcastable, Commutative, NoSideEffect]>, - WithBroadcastableCmpOpBuilder { +def TF_EqualOp : TF_Op<"Equal", [Commutative, NoSideEffect]> { let summary = "Returns the truth 
value of (x == y) element-wise."; let description = [{ @@ -825,8 +961,10 @@ tf.math.equal(x, y) ==> array([True, True]) }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Str]>:$x, - TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Str]>:$y + TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint8]>:$x, + TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint8]>:$y, + + DefaultValuedAttr:$incompatible_shape_error ); let results = (outs @@ -834,6 +972,15 @@ tf.math.equal(x, y) ==> array([True, True]) ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + + let builders = [ + OpBuilder<"Builder* builder, OperationState* result, Value* x, " + "Value* y, BoolAttr incompatible_shape_error"> + ]; + + let verifier = [{ + return Verify(*this); + }]; } def TF_ExpOp : TF_Op<"Exp", [NoSideEffect, SameOperandsAndResultType]> { @@ -1017,6 +1164,52 @@ values. }]; } +def TF_FakeQuantWithMinMaxVarsPerChannelOp : TF_Op<"FakeQuantWithMinMaxVarsPerChannel", [NoSideEffect]> { + let summary = [{ +Fake-quantize the 'inputs' tensor of type float and one of the shapes: `[d]`, + }]; + + let description = [{ +`[b, d]` `[b, h, w, d]` via per-channel floats `min` and `max` of shape `[d]` +to 'outputs' tensor of same shape as `inputs`. + +`[min; max]` define the clamping range for the `inputs` data. +`inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]` +when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and +then de-quantized and output as floats in `[min; max]` interval. +`num_bits` is the bitwidth of the quantization; between 2 and 16, inclusive. + +Before quantization, `min` and `max` values are adjusted with the following +logic. +It is suggested to have `min <= 0 <= max`. If `0` is not in the range of values, +the behavior can be unexpected: +If `0 < min < max`: `min_adj = 0` and `max_adj = max - min`. +If `min < max < 0`: `min_adj = min - max` and `max_adj = 0`. +If `min <= 0 <= max`: `scale = (max - min) / (2^num_bits - 1) `, +`min_adj = scale * round(min / scale)` and `max_adj = max + min_adj - min`. + +This operation has a gradient and thus allows for training `min` and `max` +values. + }]; + + let arguments = (ins + F32Tensor:$inputs, + F32Tensor:$min, + F32Tensor:$max, + + DefaultValuedAttr:$num_bits, + DefaultValuedAttr:$narrow_range + ); + + let results = (outs + F32Tensor:$outputs + ); + + let verifier = [{ + return Verify(*this); + }]; +} + def TF_FillOp : TF_Op<"Fill", [NoSideEffect]> { let summary = "Creates a tensor filled with a scalar value."; @@ -1082,12 +1275,12 @@ def TF_FloorDivOp : TF_Op<"FloorDiv", [Broadcastable, NoSideEffect]>, }]; let arguments = (ins - TF_NumberTensor:$x, - TF_NumberTensor:$y + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$x, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$y ); let results = (outs - TF_NumberTensor:$z + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -1780,14 +1973,14 @@ retained with length 1. 
}]; let arguments = (ins - TF_NumberTensor:$input, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, TF_I32OrI64Tensor:$reduction_indices, DefaultValuedAttr:$keep_dims ); let results = (outs - TF_NumberTensor:$output + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -1801,7 +1994,7 @@ def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect]> { }]; let arguments = (ins - TF_IntOrFpTensor:$input, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint8, TF_Uint16, TF_Uint8]>:$input, Confined]>:$ksize, Confined]>:$strides, @@ -1810,7 +2003,7 @@ def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect]> { ); let results = (outs - TF_IntOrFpTensor:$output + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint8, TF_Uint16, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -1850,14 +2043,14 @@ retained with length 1. }]; let arguments = (ins - TF_NumberTensor:$input, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, TF_I32OrI64Tensor:$reduction_indices, DefaultValuedAttr:$keep_dims ); let results = (outs - TF_NumberTensor:$output + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -1931,6 +2124,57 @@ pad(t, paddings) ==> [[2, 1, 1, 2, 3, 3, 2] TF_DerivedOperandTypeAttr Tpaddings = TF_DerivedOperandTypeAttr<1>; } +def TF_MlirPassthroughOpOp : TF_Op<"MlirPassthroughOp", [NoSideEffect]> { + let summary = [{ +Wraps an arbitrary MLIR computation expressed as a module with a main() function. + }]; + + let description = [{ +This operation does not have an associated kernel and is not intended to be +executed in a regular TensorFlow session. Instead it is intended to be used for +testing or for special case where a user intends to pass custom MLIR computation +through a TensorFlow graph with the intent of having custom tooling processing +it downstream (when targeting a different environment, like TensorFlow lite for +example). +The MLIR module is expected to have a main() function that will be used as an +entry point. The inputs to the operations will be passed as argument to the +main() function and the returned values of the main function mapped to the +outputs. 
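The mapping described above (op inputs become `main()` arguments, `main()` results become op outputs) can be spot-checked from Python once the usage example below has been run. The sketch that follows is illustrative only and assumes the example's `graph_def` variable; it simply inspects the raw NodeDef attributes.

```python
# Illustrative sketch, assuming `graph_def` was built as in the usage example
# below: the passthrough node carries the embedded module as a string
# attribute and the operand/result dtypes as derived type lists.
for node in graph_def.node:
    if node.op == "MlirPassthroughOp":
        print(node.attr["mlir_module"].s.decode())    # embedded MLIR text
        print(list(node.attr["Tinputs"].list.type))   # derived operand dtypes
        print(list(node.attr["Toutputs"].list.type))  # derived result dtypes
```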
+Example usage: + +``` +import tensorflow as tf +from tensorflow.compiler.mlir.tensorflow.gen_mlir_passthrough_op import mlir_passthrough_op + +mlir_module = ''' +func @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32> { + %add = "magic.op"(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32> + return %add : tensor<10x10xf32> +} +''' + +@tf.function +def foo(x, y): + return mlir_passthrough_op([x, y], mlir_module, Toutputs=[tf.float32]) + +graph_def = foo.get_concrete_function(tf.TensorSpec([10], tf.float32), tf.TensorSpec([10], tf.float32)).graph.as_graph_def() +``` + }]; + + let arguments = (ins + Variadic:$inputs, + + StrAttr:$mlir_module + ); + + let results = (outs + Variadic:$outputs + ); + + TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>; + TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>; +} + def TF_MulOp : TF_Op<"Mul", [Broadcastable, Commutative, NoSideEffect]>, WithBroadcastableBinOpBuilder { let summary = "Returns x * y element-wise."; @@ -1941,12 +2185,12 @@ def TF_MulOp : TF_Op<"Mul", [Broadcastable, Commutative, NoSideEffect]>, }]; let arguments = (ins - TF_NumberTensor:$x, - TF_NumberTensor:$y + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$x, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$y ); let results = (outs - TF_NumberTensor:$z + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -2006,8 +2250,101 @@ def TF_NoOp : TF_Op<"NoOp", [NoSideEffect]> { let results = (outs); } -def TF_NotEqualOp : TF_Op<"NotEqual", [Broadcastable, Commutative, NoSideEffect]>, - WithBroadcastableCmpOpBuilder { +def TF_NonMaxSuppressionV4Op : TF_Op<"NonMaxSuppressionV4", [NoSideEffect]> { + let summary = [{ +Greedily selects a subset of bounding boxes in descending order of score, + }]; + + let description = [{ +pruning away boxes that have high intersection-over-union (IOU) overlap +with previously selected boxes. Bounding boxes with score less than +`score_threshold` are removed. Bounding boxes are supplied as +[y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any +diagonal pair of box corners and the coordinates can be provided as normalized +(i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm +is agnostic to where the origin is in the coordinate system and more +generally is invariant to orthogonal transformations and translations +of the coordinate system; thus translating or reflections of the coordinate +system result in the same boxes being selected by the algorithm. +The output of this operation is a set of integers indexing into the input +collection of bounding boxes representing the selected boxes. The bounding +box coordinates corresponding to the selected indices can then be obtained +using the `tf.gather operation`.
For example: + selected_indices = tf.image.non_max_suppression_v2( + boxes, scores, max_output_size, iou_threshold, score_threshold) + selected_boxes = tf.gather(boxes, selected_indices) + }]; + + let arguments = (ins + TensorOf<[F16, F32]>:$boxes, + TensorOf<[F16, F32]>:$scores, + I32Tensor:$max_output_size, + TensorOf<[F16, F32]>:$iou_threshold, + TensorOf<[F16, F32]>:$score_threshold, + + DefaultValuedAttr:$pad_to_max_output_size + ); + + let results = (outs + I32Tensor:$selected_indices, + I32Tensor:$valid_outputs + ); + + TF_DerivedOperandTypeAttr T_threshold = TF_DerivedOperandTypeAttr<3>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_NonMaxSuppressionV5Op : TF_Op<"NonMaxSuppressionV5", [NoSideEffect]> { + let summary = [{ +Greedily selects a subset of bounding boxes in descending order of score, + }]; + + let description = [{ +pruning away boxes that have high intersection-over-union (IOU) overlap +with previously selected boxes. Bounding boxes with score less than +`score_threshold` are removed. Bounding boxes are supplied as +[y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any +diagonal pair of box corners and the coordinates can be provided as normalized +(i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm +is agnostic to where the origin is in the coordinate system and more +generally is invariant to orthogonal transformations and translations +of the coordinate system; thus translating or reflections of the coordinate +system result in the same boxes being selected by the algorithm. +The output of this operation is a set of integers indexing into the input +collection of bounding boxes representing the selected boxes. The bounding +box coordinates corresponding to the selected indices can then be obtained +using the `tf.gather operation`. For example: + selected_indices = tf.image.non_max_suppression_v2( + boxes, scores, max_output_size, iou_threshold, score_threshold) + selected_boxes = tf.gather(boxes, selected_indices) +This op also supports a Soft-NMS (with Gaussian weighting) mode (c.f. +Bodla et al, https://arxiv.org/abs/1704.04503) where boxes reduce the score +of other overlapping boxes instead of directly causing them to be pruned. +To enable this Soft-NMS mode, set the `soft_nms_sigma` parameter to be +larger than 0. 
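On the Python side this kernel surfaces as `tf.image.non_max_suppression_with_scores` in TensorFlow builds that include it (the wrapper name is an assumption about the public API, not something stated in this patch); it returns both the selected indices and the rescored values. A minimal sketch:

```python
import tensorflow as tf

# Three candidate boxes in [y1, x1, y2, x2] form; the first two overlap heavily.
boxes = tf.constant([[0.0, 0.0, 1.0, 1.0],
                     [0.0, 0.1, 1.0, 1.1],
                     [0.0, 0.5, 1.0, 1.5]])
scores = tf.constant([0.9, 0.8, 0.3])

# soft_nms_sigma > 0 enables the Soft-NMS (Gaussian weighting) mode described
# above: overlapping boxes have their scores decayed instead of being dropped.
selected_indices, selected_scores = tf.image.non_max_suppression_with_scores(
    boxes, scores, max_output_size=3, iou_threshold=0.5,
    score_threshold=0.1, soft_nms_sigma=0.5)

selected_boxes = tf.gather(boxes, selected_indices)
```

With `soft_nms_sigma=0.0` the call degenerates to ordinary hard NMS, matching the behaviour of the V4 op above.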
+ }]; + + let arguments = (ins + TensorOf<[F16, F32]>:$boxes, + TensorOf<[F16, F32]>:$scores, + I32Tensor:$max_output_size, + TensorOf<[F16, F32]>:$iou_threshold, + TensorOf<[F16, F32]>:$score_threshold, + TensorOf<[F16, F32]>:$soft_nms_sigma, + + DefaultValuedAttr:$pad_to_max_output_size + ); + + let results = (outs + I32Tensor:$selected_indices, + TensorOf<[F16, F32]>:$selected_scores, + I32Tensor:$valid_outputs + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_NotEqualOp : TF_Op<"NotEqual", [Commutative, NoSideEffect]> { let summary = "Returns the truth value of (x != y) element-wise."; let description = [{ @@ -2016,8 +2353,10 @@ def TF_NotEqualOp : TF_Op<"NotEqual", [Broadcastable, Commutative, NoSideEffect] }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Str]>:$x, - TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Str]>:$y + TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint8]>:$x, + TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint8]>:$y, + + DefaultValuedAttr:$incompatible_shape_error ); let results = (outs @@ -2025,6 +2364,15 @@ def TF_NotEqualOp : TF_Op<"NotEqual", [Broadcastable, Commutative, NoSideEffect] ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + + let builders = [ + OpBuilder<"Builder* builder, OperationState* result, Value* x, " + "Value* y, BoolAttr incompatible_shape_error"> + ]; + + let verifier = [{ + return Verify(*this); + }]; } def TF_OneHotOp : TF_Op<"OneHot", [NoSideEffect]> { @@ -2121,7 +2469,7 @@ output = }]; let arguments = (ins - TensorOf<[I32, I64, I8]>:$indices, + TensorOf<[I32, I64, TF_Uint8]>:$indices, I32Tensor:$depth, TF_Tensor:$on_value, TF_Tensor:$off_value, @@ -2176,6 +2524,10 @@ This is the opposite of `unpack`. ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + + let verifier = [{ + return Verify(*this); + }]; } def TF_PadOp : TF_Op<"Pad", [NoSideEffect]> { @@ -2303,14 +2655,14 @@ retained with length 1. }]; let arguments = (ins - TF_NumberTensor:$input, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, TF_I32OrI64Tensor:$reduction_indices, DefaultValuedAttr:$keep_dims ); let results = (outs - TF_NumberTensor:$output + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -2406,7 +2758,8 @@ The above round function rounds the value based on the given round_mode. DefaultValuedAttr:$num_bits, DefaultValuedAttr:$range_given, DefaultValuedAttr, "HALF_TO_EVEN">:$round_mode, - DefaultValuedAttr:$narrow_range + DefaultValuedAttr:$narrow_range, + DefaultValuedAttr:$axis ); let results = (outs @@ -2432,7 +2785,8 @@ tensor, so its value can change during training. DefaultValuedAttr:$signed_input, DefaultValuedAttr:$range_given, - DefaultValuedAttr:$narrow_range + DefaultValuedAttr:$narrow_range, + DefaultValuedAttr:$axis ); let results = (outs @@ -2550,12 +2904,12 @@ If `x` and `y` are reals, this will return the floating-point division. 
}]; let arguments = (ins - TF_NumberTensor:$x, - TF_NumberTensor:$y + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$x, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$y ); let results = (outs - TF_NumberTensor:$z + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -2590,11 +2944,11 @@ def TF_ReluOp : TF_Op<"Relu", [NoSideEffect, SameOperandsAndResultType]> { }]; let arguments = (ins - TF_IntOrFpTensor:$features + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$features ); let results = (outs - TF_IntOrFpTensor:$activations + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$activations ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -2709,7 +3063,7 @@ Input images can be of different types but output images are always float. }]; let arguments = (ins - TF_IntOrFpTensor:$images, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Uint16, TF_Uint8]>:$images, I32Tensor:$size, DefaultValuedAttr:$align_corners, @@ -2732,7 +3086,7 @@ Resize `images` to `size` using nearest neighbor interpolation. }]; let arguments = (ins - TensorOf<[F16, F32, F64, I16, I32, I64, I8]>:$images, + TensorOf<[F16, F32, F64, I16, I32, I64, I8, TF_Uint16, TF_Uint8]>:$images, I32Tensor:$size, DefaultValuedAttr:$align_corners, @@ -2740,7 +3094,7 @@ Resize `images` to `size` using nearest neighbor interpolation. ); let results = (outs - TensorOf<[F16, F32, F64, I16, I32, I64, I8]>:$resized_images + TensorOf<[F16, F32, F64, I16, I32, I64, I8, TF_Uint16, TF_Uint8]>:$resized_images ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -2875,12 +3229,12 @@ reverse(t, dims) ==> [[[[8, 9, 10, 11], }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Str]>:$tensor, + TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Str, TF_Uint16, TF_Uint8]>:$tensor, TF_I32OrI64Tensor:$axis ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Str]>:$output + TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Str, TF_Uint16, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -3031,6 +3385,10 @@ shape(t) ==> [2, 2, 3] return Verify(*this); }]; + let builders = [ + OpBuilder<"Builder* builder, OperationState* result, Value* input, BoolAttr use32Bit"> + ]; + let hasFolder = 1; } @@ -3643,12 +4001,12 @@ def TF_SubOp : TF_Op<"Sub", [Broadcastable, NoSideEffect]>, }]; let arguments = (ins - TF_NumberTensor:$x, - TF_NumberTensor:$y + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$x, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$y ); let results = (outs - TF_NumberTensor:$z + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -3667,20 +4025,36 @@ retained with length 1. 
}]; let arguments = (ins - TF_NumberTensor:$input, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, TF_I32OrI64Tensor:$reduction_indices, DefaultValuedAttr:$keep_dims ); let results = (outs - TF_NumberTensor:$output + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; } +def TF_TPUCompilationResultOp : TF_Op<"TPUCompilationResult", [NoSideEffect]> { + let summary = "Returns the result of a TPU compilation."; + + let description = [{ +This operation returns the result of a TPU compilation as a serialized +CompilationResultProto, which holds a status and an error message if an error +occurred during compilation. + }]; + + let arguments = (ins); + + let results = (outs + TF_StrTensor:$output + ); +} + def TF_TanhOp : TF_Op<"Tanh", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes hyperbolic tangent of `x` element-wise."; @@ -3750,6 +4124,23 @@ def TF_TensorListGetItemOp : TF_Op<"TensorListGetItem", [NoSideEffect]> { TF_DerivedResultTypeAttr element_dtype = TF_DerivedResultTypeAttr<0>; } +def TF_TensorListLengthOp : TF_Op<"TensorListLength", [NoSideEffect]> { + let summary = "Returns the number of tensors in the input tensor list."; + + let description = [{ +input_handle: the input list +length: the number of tensors in the list + }]; + + let arguments = (ins + TF_VariantTensor:$input_handle + ); + + let results = (outs + I32Tensor:$length + ); +} + def TF_TensorListPushBackOp : TF_Op<"TensorListPushBack", [NoSideEffect]> { let summary = [{ Returns a list which has the passed-in `Tensor` as last element and the other elements of the given list in `input_handle`. @@ -3921,12 +4312,12 @@ Python Semantics. 
}]; let arguments = (ins - TF_NumberTensor:$x, - TF_NumberTensor:$y + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$x, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$y ); let results = (outs - TF_NumberTensor:$z + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -4068,7 +4459,7 @@ where(input) ==> [[0, 0, 0], }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$input + TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input ); let results = (outs diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td index 080e78042a7..2a3f984d3d1 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td @@ -178,7 +178,8 @@ def TF_AnyNumber : AnyTypeOf<[TF_Int, AnyFloat, TF_AnyQuantized, TF_AnyComplex], def TF_NumberTensor : TensorOf<[TF_AnyNumber]>; -def TF_NumberOrStrTensor : TensorOf<[TF_AnyNumber, TF_Str]>; +def TF_NumberOrStr : AnyTypeOf<[AnyFloat, TF_SInt, TF_AnyComplex, TF_Uint8, TF_Str]>; +def TF_NumberOrStrTensor : TensorOf<[TF_NumberOrStr]>; //===----------------------------------------------------------------------===// // TensorFlow attribute definitions diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index 39e3bf08553..8d28ec26507 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -19,13 +19,16 @@ limitations under the License. #include #include #include +#include #include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir +#include "mlir/Dialect/Traits.h" // TF:local_config_mlir #include "mlir/IR/Attributes.h" // TF:local_config_mlir #include "mlir/IR/Builders.h" // TF:local_config_mlir #include "mlir/IR/Diagnostics.h" // TF:local_config_mlir @@ -34,6 +37,7 @@ limitations under the License. #include "mlir/IR/Matchers.h" // TF:local_config_mlir #include "mlir/IR/OpImplementation.h" // TF:local_config_mlir #include "mlir/IR/PatternMatch.h" // TF:local_config_mlir +#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir #include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir #include "mlir/IR/Types.h" // TF:local_config_mlir #include "mlir/IR/Value.h" // TF:local_config_mlir @@ -71,10 +75,19 @@ static inline bool IsOfRankOrUnranked(Value *value, int64_t rank) { // Returns true if the given `value` has at least the specified rank or has // unranked type. static inline bool HasRankAtLeast(Value *value, int64_t rank) { - auto type = value->getType(); + Type type = value->getType(); if (auto ranked_type = type.dyn_cast()) return ranked_type.getRank() >= rank; - return type.isa(); + return true; +} + +// Returns true if the given `value` has at most the specified rank or has +// unranked type. 
+static inline bool HasRankAtMost(Value *value, int64_t rank) { + Type type = value->getType(); + if (auto ranked_type = type.dyn_cast()) + return ranked_type.getRank() <= rank; + return true; } // Returns true if the given pair of TensorFlow types can be cast to one @@ -95,6 +108,85 @@ static bool IsUnknownDimOrRank(int64_t dim_or_rank) { return dim_or_rank == -1; } +// Returns the tf.Equal/tf.NotEqual result type given `x` and `y` and inputs. If +// `incompatible_shape_error` is true, reports error if `x` and `y` has +// incompatible shapes. Otherwise, returns a tensor type with unknown rank. +static Type DeduceEqualCmpOpType(Builder *builder, Location loc, Value *x, + Value *y, BoolAttr incompatible_shape_error) { + auto result_type = + OpTrait::util::getBroadcastedType(x->getType(), y->getType()); + if (!result_type) { + if (incompatible_shape_error.getValue()) { + mlir::emitError(loc, "non-broadcastable operands"); + } else { + result_type = builder->getTensorType(builder->getI1Type()); + } + } + return result_type; +} + +// Verifies that the given types are cast compatible. If not, emits appropriate +// error for the given op. If mask_one_dim is set to true, then the types are +// allowed to have one mismatching dimension. Masking one of the dimensions is +// useful for ops like Concat that requires all ranked inputs to have the same +// rank and match dimension sizes for all but one of the dimensions. +static LogicalResult VerifyTypesCompatibility( + Operation::operand_type_range types, bool mask_one_dim, Operation *op) { + constexpr int64_t kUninitialized = -1; + int64_t common_rank = kUninitialized; + llvm::SmallVector common_dims; + int64_t dim_to_mask = kUninitialized; + + // Initialize common_rank with rank of the first ranked type and verify that + // following ranked types have the same rank. + // Similarly, initialize each of the dimensions with the first type that has + // the dimension size available and verify that all following types have the + // same size for the dimension. However, if mask_one_dim is true, note down + // the dimension index on the first mismatch and ignore dimension at that + // index in following types. + for (Type ty : types) { + RankedTensorType ranked_ty = ty.dyn_cast(); + if (!ranked_ty) continue; + + int64_t rank = ranked_ty.getRank(); + if (common_rank == kUninitialized) { + common_rank = rank; + common_dims.resize(common_rank, kUninitialized); + } else if (common_rank != rank) { + return op->emitError() + << "operand type " << ranked_ty + << " is not compatible with preceding operands; expected rank: " + << common_rank; + } + + for (int64_t i = 0, e = common_rank; i != e; i++) { + if (i == dim_to_mask) continue; + + int64_t dim = ranked_ty.getDimSize(i); + if (dim == kUninitialized) continue; + + int64_t &common_dim = common_dims[i]; + if (common_dim == kUninitialized) { + common_dim = dim; + } else if (common_dim != dim) { + // If mask_one_dim is true, do not emit an error if this is the only + // dimension with mismatches. Note down the dimension to mask it from + // the following types. 
+ if (mask_one_dim && dim_to_mask == kUninitialized) { + dim_to_mask = i; + continue; + } + + return op->emitError() << "operand type " << ranked_ty + << " is not compatible with preceding operands; " + "expected dimension at index " + << i << ": " << common_dim; + } + } + } + return success(); +} + namespace { #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc" } // namespace @@ -176,6 +268,36 @@ void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results, results.insert(context); } +//===----------------------------------------------------------------------===// +// ConcatOp and ConcatV2Op +//===----------------------------------------------------------------------===// + +template ::value>> +static LogicalResult Verify(OpT op) { + // TODO(hinsu): Convert variadic length attributes to derived attributes. + Operation::operand_range values = op.values(); + + auto num_values = std::distance(values.begin(), values.end()); + int64_t attr_N = op.N().getLimitedValue(); + if (num_values != attr_N) { + return op.emitOpError() + << "requires attribute 'N' to match the number of inputs; expected: " + << num_values << " Found: " << attr_N; + } + + int axis_idx = std::is_same() ? 0 : 1; + Value *axis = *op.getODSOperands(axis_idx).begin(); + if (!HasRankAtMost(axis, 1)) { + return op.emitOpError( + "requires axis to be of scalar type (or vector type for older " + "versions)"); + } + + return VerifyTypesCompatibility(values, + /*mask_one_dim=*/true, op.getOperation()); +} + //===----------------------------------------------------------------------===// // ConjOp //===----------------------------------------------------------------------===// @@ -257,6 +379,26 @@ static LogicalResult Verify(EmptyTensorListOp op) { return success(); } +//===----------------------------------------------------------------------===// +// EqualOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(EqualOp op) { + // If we allow inputs to have incompatible type, then nothing to do. + if (!op.incompatible_shape_error()) return success(); + + // Otherwise, check inputs are broadcastable. + return mlir::OpTrait::impl::verifyCompatibleOperandBroadcast( + op.getOperation()); +} + +void EqualOp::build(Builder *builder, OperationState *result, Value *x, + Value *y, BoolAttr incompatible_shape_error) { + auto result_type = DeduceEqualCmpOpType(builder, result->location, x, y, + incompatible_shape_error); + return build(builder, result, result_type, x, y, incompatible_shape_error); +} + //===----------------------------------------------------------------------===// // FakeQuantWithMinMaxArgsOp //===----------------------------------------------------------------------===// @@ -276,12 +418,6 @@ static LogicalResult Verify(FakeQuantWithMinMaxArgsOp op) { return op.emitOpError("range is invalid: [" + Twine(std::to_string(rmin)) + "," + Twine(std::to_string(rmax)) + "]"); } - // Range must straddle zero. 
- if (rmin > 0.0 || rmax < 0.0) { - return op.emitOpError("range failed to straddle zero: [" + - Twine(std::to_string(rmin)) + "," + - Twine(std::to_string(rmax)) + "]"); - } int64_t num_bits = op.num_bits().getSExtValue(); if (num_bits < 2 || num_bits > 16) { return op.emitOpError( @@ -308,6 +444,37 @@ static LogicalResult Verify(FakeQuantWithMinMaxVarsOp op) { return success(); } +//===----------------------------------------------------------------------===// +// FakeQuantWithMinMaxVarsPerChannelOp +//===----------------------------------------------------------------------===// +static LogicalResult Verify(FakeQuantWithMinMaxVarsPerChannelOp op) { + if (!isOfRankedFloatTensorType(op.min(), 1)) + return op.emitOpError("requires min to be a 1d float tensor"); + + if (!isOfRankedFloatTensorType(op.max(), 1)) + return op.emitOpError("requires max to be a 1d float tensor"); + + Value *inputs = op.inputs(); + if (!HasRankAtLeast(inputs, 1) || + inputs->getType().isa()) { + return op.emitError("requires inputs to be at least 1d float tensor"); + } + + auto inputsType = inputs->getType().cast(); + int depth = inputsType.getDimSize(inputsType.getRank() - 1); + if (op.min()->getType().cast().getDimSize(0) != depth || + op.max()->getType().cast().getDimSize(0) != depth) { + return op.emitOpError( + "requires min and max to have same size as last dimension of inputs"); + } + int64_t num_bits = op.num_bits().getSExtValue(); + if (num_bits < 2 || num_bits > 16) { + return op.emitOpError( + "requires num_bits to be between 2 and 16, inclusive"); + } + return success(); +} + //===----------------------------------------------------------------------===// // FusedBatchNormOp //===----------------------------------------------------------------------===// @@ -471,6 +638,74 @@ void NegOp::getCanonicalizationPatterns(OwningRewritePatternList &results, results.insert(context); } +//===----------------------------------------------------------------------===// +// NotEqualOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(NotEqualOp op) { + // If we allow inputs to have incompatible type, then nothing to do. + if (!op.incompatible_shape_error()) return success(); + + // Otherwise, check inputs are broadcastable. + return mlir::OpTrait::impl::verifyCompatibleOperandBroadcast( + op.getOperation()); +} + +void NotEqualOp::build(Builder *builder, OperationState *result, Value *x, + Value *y, BoolAttr incompatible_shape_error) { + auto result_type = DeduceEqualCmpOpType(builder, result->location, x, y, + incompatible_shape_error); + return build(builder, result, result_type, x, y, incompatible_shape_error); +} + +//===----------------------------------------------------------------------===// +// PackOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(PackOp op) { + // TODO(hinsu): Convert variadic length attributes to derived attributes. 
+ Operation::operand_range values = op.values(); + + auto num_values = std::distance(values.begin(), values.end()); + int64_t attr_N = op.N().getLimitedValue(); + if (num_values != attr_N) { + return op.emitOpError() + << "requires attribute 'N' to match the number of inputs; expected: " + << num_values << " Found: " << attr_N; + } + + if (failed(VerifyTypesCompatibility(values, + /*mask_one_dim=*/false, + op.getOperation()))) { + return failure(); + } + + int64_t inputs_rank = -1; + for (Value *value : values) { + if (auto ty = value->getType().dyn_cast()) { + // Exit early as input types are verified to be compatible so all ranked + // tensors have the same rank. + inputs_rank = ty.getRank(); + break; + } + } + if (inputs_rank == -1) return success(); + + // The values can be packed along any of the dimensions between 0 and + // inputs rank, inclusive. Also, as the negative axis values wrap around so + // the axis value range is [-(R+1), R+1). + int64_t range_begin = -inputs_rank - 1; // Inclusive + int64_t range_end = inputs_rank + 1; // Exclusive + int64_t axis = op.axis().getLimitedValue(); + if (axis < range_begin || axis >= range_end) { + return op.emitError() << "attribute 'axis' should be within range [" + << range_begin << ", " << range_end + << "); actual value: " << axis; + } + + return success(); +} + //===----------------------------------------------------------------------===// // ReciprocalOp //===----------------------------------------------------------------------===// @@ -731,6 +966,16 @@ OpFoldResult ShapeOp::fold(ArrayRef operands) { return b.getDenseElementsAttr(resultType, dimensions); } +void ShapeOp::build(Builder *builder, OperationState *result, Value *input, + BoolAttr use32Bit) { + auto rankedTensorType = input->getType().dyn_cast(); + int64_t rank = rankedTensorType ? rankedTensorType.getRank() : -1; + auto out_type = use32Bit.getValue() ? builder->getIntegerType(32) + : builder->getIntegerType(64); + return ShapeOp::build(builder, result, + builder->getTensorType({rank}, out_type), input); +} + //===----------------------------------------------------------------------===// // ShapeNOp //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/ops/mlir_passthrough_op.cc b/tensorflow/compiler/mlir/tensorflow/ops/mlir_passthrough_op.cc new file mode 100644 index 00000000000..fe9bfcccba7 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/ops/mlir_passthrough_op.cc @@ -0,0 +1,60 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("MlirPassthroughOp") + .Attr("mlir_module: string") + .Attr("Tinputs : list(type) >= 0") + .Input("inputs: Tinputs") + .Attr("Toutputs : list(type) >= 0") + .Output("outputs: Toutputs") + .Doc(R"doc( +Wraps an arbitrary MLIR computation expressed as a module with a main() function. + +This operation does not have an associated kernel and is not intended to be +executed in a regular TensorFlow session. Instead it is intended to be used for +testing or for special case where a user intends to pass custom MLIR computation +through a TensorFlow graph with the intent of having custom tooling processing +it downstream (when targeting a different environment, like TensorFlow lite for +example). +The MLIR module is expected to have a main() function that will be used as an +entry point. The inputs to the operations will be passed as argument to the +main() function and the returned values of the main function mapped to the +outputs. +Example usage: + +``` +import tensorflow as tf +from tensorflow.compiler.mlir.tensorflow.gen_mlir_passthrough_op import mlir_passthrough_op + +mlir_module = ''' +func @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32> { + %add = "magic.op"(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32> + return %ret : tensor<10x10xf32> +} +''' + +@tf.function +def foo(x, y): + return = mlir_passthrough_op([x, y], mlir_module, Toutputs=[tf.float32]) + +graph_def = foo.get_concrete_function(tf.TensorSpec([10], tf.float32), tf.TensorSpec([10], tf.float32)).graph.as_graph_def() +``` +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index 65feaa8b84c..a15d543825d 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -236,8 +236,8 @@ func @testLogicalNotOfEqual(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> %1 = "tf.LogicalNot"(%0) : (tensor<8x16xi1>) -> tensor<8x16xi1> return %1: tensor<8x16xi1> -// CHECK: %0 = "tf.NotEqual"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1> -// CHECK: return %0 +// CHECK: %[[NE:.*]] = "tf.NotEqual"(%arg0, %arg1) {incompatible_shape_error = true} +// CHECK: return %[[NE]] } // CHECK-LABEL: testLogicalNotOfNotEqual @@ -246,8 +246,8 @@ func @testLogicalNotOfNotEqual(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) %1 = "tf.LogicalNot"(%0) : (tensor<8x16xi1>) -> tensor<8x16xi1> return %1: tensor<8x16xi1> -// CHECK: %0 = "tf.Equal"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1> -// CHECK: return %0 +// CHECK: %[[NE:.*]] = "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = true} +// CHECK: return %[[NE]] } // CHECK-LABEL: testLogicalNotOfGreater diff --git a/tensorflow/compiler/mlir/tensorflow/tests/cluster_formation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/cluster_formation.mlir index 9e2fdcc1ee5..caf6e73b98f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/cluster_formation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/cluster_formation.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -split-input-file -tf-device-cluster-formation | FileCheck %s +// RUN: tf-opt %s -split-input-file -tf-device-cluster-formation | FileCheck %s -dump-input-on-failure // Simple case, single 
device cluster. @@ -72,11 +72,8 @@ module { // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) func @argliveinotherislands(%arg0: tensor) -> tensor { %0 = tf_executor.graph { - // CHECK: %[[OTHER_ISLAND_OUTPUT:[0-9]*]]:2 = tf_executor.island { - %1:2 = tf_executor.island { - %3 = "tf.D"(%arg0) : (tensor) -> tensor - tf_executor.yield %3 : tensor - } + // CHECK: %[[OTHER_ISLAND_OUTPUT:[0-9]*]]:2 = tf_executor.island wraps "tf.D" + %1:2 = tf_executor.island wraps "tf.D"(%arg0) : (tensor) -> tensor %2:2 = tf_executor.island { // CHECK: %[[TPU0_OUTPUT:[0-9]*]] = "tf_device.launch" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir b/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir index 30272b443a1..1d0e2b245bf 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir @@ -90,16 +90,13 @@ module { // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) func @multiplelaunches(%arg0: tensor) -> tensor { %0 = tf_executor.graph { - %1:2 = tf_executor.island { - // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf_device.launch_func"() {device = "tpu0", func = @tpu0_func} - %2 = "tf_device.launch"() ( { + %1:2 = tf_executor.island wraps + // CHECK: %[[A_OUTPUT:[0-9]*]]:2 = {{.*}} "tf_device.launch_func"() {device = "tpu0", func = @tpu0_func} + "tf_device.launch"() ( { %3 = "tf.A"() : () -> tensor "tf_device.return"(%3) : (tensor) -> () }) {device = "tpu0"} : () -> tensor - - // CHECK: tf_executor.yield %[[A_OUTPUT]] - tf_executor.yield %2 : tensor - } + // CHECK: tf_executor.fetch %[[A_OUTPUT]]#0 tf_executor.fetch %1#0 : tensor } return %0 : tensor diff --git a/tensorflow/compiler/mlir/tensorflow/tests/control_to_executor_dialect.mlir b/tensorflow/compiler/mlir/tensorflow/tests/control_to_executor_dialect.mlir index 48f4c8f77df..25adff97d48 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/control_to_executor_dialect.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/control_to_executor_dialect.mlir @@ -11,14 +11,8 @@ func @islands_with_control(tensor<*xf32>) -> tensor<*xf32> { } // CHECK-NEXT: %[[GRAPH:[0-9]*]] = tf_executor.graph { -// CHECK-NEXT: %[[IDENTITY:[0-9]*]]:2 = tf_executor.island { -// CHECK-NEXT: %{{[0-9]*}} = "tf.Identity"(%[[ARG0]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK-NEXT: tf_executor.yield %{{[0-9]*}} : tensor<*xf32> -// CHECK-NEXT: } -// CHECK-NEXT: %[[ADD:[0-9]*]]:2 = tf_executor.island(%[[IDENTITY]]#1) { -// CHECK-NEXT: %{{[0-9]*}} = "tf.Add"(%[[ARG0]], %[[ARG0]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> -// CHECK-NEXT: tf_executor.yield %{{[0-9]*}} : tensor<*xf32> -// CHECK-NEXT: } +// CHECK-NEXT: %[[IDENTITY:[0-9]*]]:2 = tf_executor.island wraps "tf.Identity"(%[[ARG0]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK-NEXT: %[[ADD:[0-9]*]]:2 = tf_executor.island(%[[IDENTITY]]#1) wraps "tf.Add"(%[[ARG0]], %[[ARG0]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> // CHECK-NEXT: tf_executor.fetch %[[ADD]]#0 : tensor<*xf32> // CHECK-NEXT: } // CHECK-NEXT: return %[[GRAPH]] : tensor<*xf32> @@ -45,40 +39,19 @@ func @LoopTest() { } // CHECK-NEXT: tf_executor.graph { -// CHECK-NEXT: %[[CONST:[0-9]*]]:2 = tf_executor.island { -// CHECK-NEXT: %{{[a-z0-9]*}} = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor} : () -> tensor -// CHECK-NEXT: tf_executor.yield %{{[a-z0-9]*}} : tensor -// CHECK-NEXT: } +// CHECK-NEXT: %[[CONST:[0-9]*]]:2 = tf_executor.island wraps "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", 
name = "Const", value = dense<1> : tensor} : () -> tensor // CHECK-NEXT: %[[ENTER:[0-9]*]]:2 = tf_executor.Enter %[[CONST]]#0 frame "while/while_context" : (tensor) -> (tensor<*xi32>, !tf_executor.control) {T = "tfdtype$DT_INT32", device = "", name = "while/Enter"} -// CHECK-NEXT: %[[NOOP:[0-9]*]] = tf_executor.island { -// CHECK-NEXT: "tf.NoOp"() {device = "", name = "cluster/pivot"} : () -> () -// CHECK-NEXT: tf_executor.yield -// CHECK-NEXT: } +// CHECK-NEXT: %[[NOOP:[0-9]*]] = tf_executor.island wraps "tf.NoOp"() {device = "", name = "cluster/pivot"} : () -> () // CHECK-NEXT: %[[NEXTIT_SRC:[0-9]*]]:3 = tf_executor.NextIteration.Source : tensor<*xi32> {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"} // CHECK-NEXT: %[[MERGE:[0-9]*]]:3 = tf_executor.Merge %[[NEXTIT_SRC]]#0, %[[ENTER]]#0 : tensor<*xi32> {N = 2 : i64, T = "tfdtype$DT_INT32", device = "", name = "while/Merge"} -// CHECK-NEXT: %[[CONST_LESS:[0-9]*]]:2 = tf_executor.island(%[[MERGE]]#2) { -// CHECK-NEXT: %{{[a-z0-9]*}} = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "while/Less/y", value = dense<2> : tensor} : () -> tensor -// CHECK-NEXT: tf_executor.yield %{{[a-z0-9]*}} : tensor -// CHECK-NEXT: } -// CHECK-NEXT: %[[LESS:[0-9]*]]:2 = tf_executor.island { -// CHECK-NEXT: %{{[a-z0-9]*}} = "tf.Less"(%[[MERGE]]#0, %[[CONST_LESS]]#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Less"} : (tensor<*xi32>, tensor) -> tensor<*xi1> -// CHECK-NEXT: tf_executor.yield %{{[a-z0-9]*}} : tensor<*xi1> -// CHECK-NEXT: } +// CHECK-NEXT: %[[CONST_LESS:[0-9]*]]:2 = tf_executor.island(%[[MERGE]]#2) wraps "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "while/Less/y", value = dense<2> : tensor} : () -> tensor +// CHECK-NEXT: %[[LESS:[0-9]*]]:2 = tf_executor.island wraps "tf.Less"(%[[MERGE]]#0, %[[CONST_LESS]]#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Less"} : (tensor<*xi32>, tensor) -> tensor<*xi1> // CHECK-NEXT: %[[COND:[0-9]*]]:2 = tf_executor.LoopCond %[[LESS:[0-9]*]]#0 : (tensor<*xi1>) -> (tensor, !tf_executor.control) {device = "", name = "while/LoopCond"} // CHECK-NEXT: %[[SWITCH:[0-9]*]]:3 = tf_executor.Switch %[[MERGE]]#0, %[[COND]]#0 : tensor<*xi32> {T = "tfdtype$DT_INT32", _class = ["loc = @while/Merge"], device = "", name = "while/Switch"} // CHECK-NEXT: %[[EXIT:[0-9]*]]:2 = tf_executor.Exit %[[SWITCH]]#0 : tensor<*xi32> {T = "tfdtype$DT_INT32", device = "", name = "while/Exit"} -// CHECK-NEXT: %[[IDENTITY:[0-9]*]]:2 = tf_executor.island { -// CHECK-NEXT: %{{[a-z0-9]*}} = "tf.Identity"(%[[SWITCH]]#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Identity"} : (tensor<*xi32>) -> tensor<*xi32> -// CHECK-NEXT: tf_executor.yield %{{[a-z0-9]*}} : tensor<*xi32> -// CHECK-NEXT: } -// CHECK-NEXT: %[[CONST_ADD:[0-9]*]]:2 = tf_executor.island(%[[IDENTITY]]#1) { -// CHECK-NEXT: %{{[a-z0-9]*}} = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "while/Add/y", value = dense<3> : tensor} : () -> tensor -// CHECK-NEXT: tf_executor.yield %{{[a-z0-9]*}} : tensor -// CHECK-NEXT: } -// CHECK-NEXT: %[[ADD:[0-9]*]]:2 = tf_executor.island { -// CHECK-NEXT: %{{[0-9]*}} = "tf.Add"(%[[IDENTITY]]#0, %[[CONST_ADD]]#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Add"} : (tensor<*xi32>, tensor) -> tensor<*xi32> -// CHECK-NEXT: tf_executor.yield %{{[0-9]*}} : tensor<*xi32> -// CHECK-NEXT: } +// CHECK-NEXT: %[[IDENTITY:[0-9]*]]:2 = tf_executor.island wraps "tf.Identity"(%[[SWITCH]]#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Identity"} : (tensor<*xi32>) 
-> tensor<*xi32> +// CHECK-NEXT: %[[CONST_ADD:[0-9]*]]:2 = tf_executor.island(%[[IDENTITY]]#1) wraps "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "while/Add/y", value = dense<3> : tensor} : () -> tensor +// CHECK-NEXT: %[[ADD:[0-9]*]]:2 = tf_executor.island wraps "tf.Add"(%[[IDENTITY]]#0, %[[CONST_ADD]]#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Add"} : (tensor<*xi32>, tensor) -> tensor<*xi32> // CHECK-NEXT: %[[CT:[0-9]*]] = tf_executor.ControlTrigger %[[NOOP]], %[[ADD]]#1, %[[EXIT]]#1 {_tpu_replicate = "cluster", device = "", name = "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/b_sync"} // CHECK-NEXT: tf_executor.NextIteration.Sink [%[[NEXTIT_SRC]]#1] %[[ADD]]#0, %[[CT]] : tensor<*xi32> {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"} // CHECK-NEXT: tf_executor.fetch diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/executor_canonicalize.mlir index 5b4e8e16cbb..19ce07db947 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/executor_canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/executor_canonicalize.mlir @@ -285,9 +285,9 @@ func @empty_island_no_operand_no_data_result() { return } -// CHECK: %[[ISLAND_0:[0-9]*]] = tf_executor.island { +// CHECK: %[[ISLAND_0:[0-9]*]] = tf_executor.island // CHECK-NEXT: "tf.opA" -// CHECK: tf_executor.island(%[[ISLAND_0]]) { +// CHECK: tf_executor.island(%[[ISLAND_0]]) // CHECK-NEXT: "tf.opB" // CHECK-NOT: tf_executor.island @@ -313,9 +313,9 @@ func @empty_island_one_operand_no_data_result() { return } -// CHECK: %[[ISLAND_1:[0-9]*]] = tf_executor.island { +// CHECK: %[[ISLAND_1:[0-9]*]] = tf_executor.island // CHECK-NEXT: "tf.opA" -// CHECK: tf_executor.island(%[[ISLAND_1]]) { +// CHECK: tf_executor.island(%[[ISLAND_1]]) // CHECK-NEXT: "tf.opB" // CHECK-NOT: tf_executor.island @@ -342,8 +342,34 @@ func @empty_island_no_operand_one_data_no_control_result(%arg0 : tensor) { return } -// CHECK: tf_executor.island { +// CHECK: tf_executor.island // CHECK-NEXT: "tf.opA"(%[[ARG_0]]) // CHECK: tf_executor.island { // CHECK-NEXT: "tf.opB"(%[[ARG_0]]) // CHECK-NOT: tf_executor.island + + +// Test empty control trigger with no operands is removed. +// Control result users should also have their respective operands removed. 
+// CHECK-LABEL: func @empty_control_trigger +func @empty_control_trigger() { + tf_executor.graph { + %0 = tf_executor.ControlTrigger {} + %1 = tf_executor.island(%0) { + %3 = "tf.opA"() : () -> tensor + tf_executor.yield + } + %2 = tf_executor.island(%0, %1) { + %4 = "tf.opB"() : () -> tensor + tf_executor.yield + } + tf_executor.fetch + } + return +} + +// CHECK: %[[ISLAND_0:[0-9]*]] = tf_executor.island +// CHECK-NEXT: "tf.opA" +// CHECK: tf_executor.island(%[[ISLAND_0]]) +// CHECK-NEXT: "tf.opB" +// CHECK-NOT: tf_executor.island diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_island_coarsening.mlir b/tensorflow/compiler/mlir/tensorflow/tests/executor_island_coarsening.mlir index a9e83dd006c..35dc4caba90 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/executor_island_coarsening.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/executor_island_coarsening.mlir @@ -89,9 +89,7 @@ func @empty_islands(%arg0 : tensor, %arg1 : tensor) -> (tensor, tens return %0#0, %0#1 : tensor, tensor } -// CHECK: %[[ISLAND:[0-9]*]]:3 = tf_executor.island { -// CHECK-NEXT: %[[OP_A:[0-9]*]]:2 = "tf.opA"(%[[ARG_1]], %[[ARG_0]]) -// CHECK-NEXT: tf_executor.yield %[[OP_A]]#0, %[[OP_A]]#1 : tensor, tensor +// CHECK: %[[ISLAND:[0-9]*]]:3 = tf_executor.island wraps "tf.opA"(%[[ARG_1]], %[[ARG_0]]) // CHECK: tf_executor.fetch %[[ISLAND]]#0, %[[ISLAND]]#1 : tensor, tensor @@ -228,9 +226,7 @@ func @islands_interleaved(%arg0 : tensor, %arg1 : tensor) -> (tensor -// CHECK: tf_executor.island { -// CHECK-NEXT: %[[OP_F:[0-9]*]] = "tf.opF"(%[[ARG_1]]) -// CHECK-NEXT: tf_executor.yield %[[OP_F]] : tensor +// CHECK: tf_executor.island wraps "tf.opF"(%[[ARG_1]]) // CHECK: tf_executor.fetch %[[ISLAND_0]]#0, %[[ISLAND_1]]#0 : tensor, tensor @@ -279,13 +275,9 @@ func @merge_islands_only() { return } -// CHECK: %[[ISLAND_0:[0-9]*]]:2 = tf_executor.island { -// CHECK-NEXT: %[[OP_A:.*]] = "tf.opA" -// CHECK-NEXT: tf_executor.yield %[[OP_A]] : tensor +// CHECK: %[[ISLAND_0:[0-9]*]]:2 = tf_executor.island wraps "tf.opA" // CHECK: %[[ENTER:[0-9]*]]:2 = tf_executor.Enter %[[ISLAND_0]]#0 -// CHECK-NEXT: %[[ISLAND_1:[0-9]*]] = tf_executor.island { -// CHECK-NEXT: "tf.opB"() -// CHECK-NEXT: tf_executor.yield +// CHECK-NEXT: %[[ISLAND_1:[0-9]*]] = tf_executor.island wraps "tf.opB"() // CHECK: %[[NEXTIT_SRC:[0-9]*]]:3 = tf_executor.NextIteration.Source // CHECK-NEXT: %[[MERGE:[0-9]*]]:3 = tf_executor.Merge %[[NEXTIT_SRC]]#0, %[[ENTER]]#0 // CHECK-NEXT: %[[ISLAND_2:[0-9]*]]:2 = tf_executor.island(%[[MERGE]]#2) { @@ -322,9 +314,7 @@ func @simple_potential_cycle() { return } -// CHECK: %[[ISLAND:[0-9]*]]:2 = tf_executor.island { -// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA" -// CHECK-NEXT: tf_executor.yield %[[OP_A]] : tensor<1xf32> +// CHECK: %[[ISLAND:[0-9]*]]:2 = tf_executor.island wraps "tf.opA" // CHECK: %[[CT:[0-9]*]] = tf_executor.ControlTrigger %[[ISLAND]]#1 // CHECK-NEXT: tf_executor.island(%[[CT]]) { // CHECK-NEXT: %[[OP_B:[0-9]*]] = "tf.opB" @@ -384,9 +374,7 @@ func @merge_into_nested_data_result() { // CHECK-NEXT: [[OP_A:[0-9*]]] = "tf.opA" // CHECK-NEXT: [[INNER_GRAPH:[0-9]*]] = tf_executor.graph { // CHECK-NEXT: [[CT:[0-9]*]] = tf_executor.ControlTrigger -// CHECK-NEXT: [[ISLAND_1:[0-9]*]]:2 = tf_executor.island(%[[CT]]) { -// CHECK-NEXT: [[OP_B:[0-9]*]] = "tf.opB"(%[[OP_A]]) -// CHECK-NEXT: tf_executor.yield %[[OP_B]] : tensor<1xf32> +// CHECK-NEXT: [[ISLAND_1:[0-9]*]]:2 = tf_executor.island(%[[CT]]) wraps "tf.opB"(%[[OP_A]]) // CHECK: tf_executor.fetch %[[ISLAND_1]]#0 : tensor<1xf32> // CHECK: 
tf_executor.yield @@ -422,18 +410,14 @@ func @merge_islands_inner_graph() { return } -// CHECK: tf_executor.island { -// CHECK-NEXT: [[OP_A:[0-9*]]] = "tf.opA" -// CHECK-NEXT: tf_executor.yield %[[OP_A]] : tensor<1xf32> -// CHECK: tf_executor.island { -// CHECK-NEXT: [[INNER_GRAPH:[0-9]*]] = tf_executor.graph { +// CHECK: tf_executor.island wraps "tf.opA" +// CHECK: tf_executor.island wraps "tf_executor.graph"() ( { // CHECK-NEXT: [[ISLAND_1:[0-9]*]]:2 = tf_executor.island { // CHECK-NEXT: "tf.opB" // CHECK-NEXT: [[OP_C:[0-9]*]] = "tf.opC" // CHECK-NEXT: [[OP_D:[0-9]*]] = "tf.opD"(%[[OP_C]]) // CHECK-NEXT: tf_executor.yield %[[OP_D]] : tensor<1xf32> // CHECK: tf_executor.fetch %[[ISLAND_1]]#0 : tensor<1xf32> -// CHECK: tf_executor.yield %[[INNER_GRAPH]] : tensor<1xf32> // Test merging islands with control island operands and island results only if @@ -454,7 +438,7 @@ func @merge_islands_closest_control() { return } -// CHECK: %[[ISLAND:[0-9]*]] = tf_executor.island { +// CHECK: %[[ISLAND:[0-9]*]] = tf_executor.island // CHECK: tf_executor.ControlTrigger %[[ISLAND]] // CHECK: %[[CT:[0-9]*]] = tf_executor.ControlTrigger -// CHECK: tf_executor.island(%[[ISLAND]], %[[CT]]) { +// CHECK: tf_executor.island(%[[ISLAND]], %[[CT]]) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_island_materialize_const.mlir b/tensorflow/compiler/mlir/tensorflow/tests/executor_island_materialize_const.mlir new file mode 100644 index 00000000000..49247dede30 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/executor_island_materialize_const.mlir @@ -0,0 +1,23 @@ +// RUN: tf-opt %s -canonicalize | FileCheck %s --dump-input=fail + +// Test that a constant stays inside an island after canonicalization + +// CHECK-LABEL: func @constant_in_island +func @constant_in_island(%arg0 : tensor) -> tensor { + %0 = tf_executor.graph { +// CHECK: tf_executor.island +// CHECK: tf.Const{{.*}}2.0 + %1:2 = tf_executor.island { + %0 = "tf.Const"() {value = dense<2.000000e+00> : tensor} : () -> tensor + tf_executor.yield %0 : tensor + } +// Uses two islands for no other reason than preventing canonicalization from +// eliminating the graph entirely. + %2:2 = tf_executor.island(%1#1) { + %4 = "tf.opB"(%1#0) : (tensor) -> tensor + tf_executor.yield %4 : tensor + } + tf_executor.fetch %2#0 : tensor + } + return %0 : tensor +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_to_control_dialect.mlir b/tensorflow/compiler/mlir/tensorflow/tests/executor_to_control_dialect.mlir index 11b9b1a564d..5ff18c3cae3 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/executor_to_control_dialect.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/executor_to_control_dialect.mlir @@ -97,3 +97,24 @@ func @switchN(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf32> { } return %fetches : tensor<*xf32> } + +// Test if tf_executor dialect ops with Ref types are mapped correctly to the ops in control dialect. 
+// CHECK-LABEL: func @ref_tf_executor_ops +func @ref_tf_executor_ops(%arg0: tensor<4x!tf.f32ref>, %arg1: tensor<4x!tf.f32ref>, %arg3: tensor, %arg4: tensor ) -> tensor<4x!tf.f32ref> { + %result = tf_executor.graph { + // CHECK: _tf.Enter + %0:2 = tf_executor.Enter %arg0 frame "while/while_context" : (tensor<4x!tf.f32ref>) -> (tensor<4x!tf.f32ref>, !tf_executor.control) + // CHECK: _tf.Exit + %1:2 = tf_executor.Exit %arg0 : tensor<4x!tf.f32ref> + // CHECK: _tf.Switch + %2:3 = tf_executor.Switch %arg0, %arg4 : (tensor<4x!tf.f32ref>, tensor) -> (tensor<4x!tf.f32ref>, tensor<4x!tf.f32ref>, !tf_executor.control) + // CHECK: _tf.Merge + %3:3 = tf_executor.Merge %arg0, %arg1 : (tensor<4x!tf.f32ref>, tensor<4x!tf.f32ref>) -> (tensor<4x!tf.f32ref>, tensor, !tf_executor.control) + // CHECK: _tf.NextIteration.source + %4:3 = tf_executor.NextIteration.Source : tensor<4x!tf.f32ref> + // CHECK: _tf.NextIteration.sink + tf_executor.NextIteration.Sink [%4#1] %4#0 : tensor<4x!tf.f32ref> + tf_executor.fetch %0#0 : tensor<4x!tf.f32ref> + } + return %result : tensor<4x!tf.f32ref> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/add.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/add.pbtxt index a2b9efff36b..15289bf47ab 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/add.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/add.pbtxt @@ -39,13 +39,10 @@ versions { # CHECK: func @main(%arg0: tensor<10xi32>, %arg1: tensor<10xi32>) -> tensor<10xi32> # CHECK: attributes {tf.entry_function = {inputs = "input0, input1", outputs = "Add"}} { -# CHECK: %[[INPUT0:[0-9]+]]:2 = tf_executor.island -# CHECK-NEXT: "tf.Placeholder.input"(%arg0) +# CHECK: %[[INPUT0:[0-9]+]]:2 = tf_executor.island wraps "tf.Placeholder.input"(%arg0) -# CHECK: %[[INPUT1:[0-9]+]]:2 = tf_executor.island -# CHECK-NEXT: "tf.Placeholder.input"(%arg1) +# CHECK: %[[INPUT1:[0-9]+]]:2 = tf_executor.island wraps "tf.Placeholder.input"(%arg1) -# CHECK: %[[add:[0-9]+]]:2 = tf_executor.island -# CHECK-NEXT: "tf.Add"(%[[INPUT0]]#0, %[[INPUT1]]#0) +# CHECK: %[[add:[0-9]+]]:2 = tf_executor.island wraps "tf.Add"(%[[INPUT0]]#0, %[[INPUT1]]#0) # CHECK: fetch %[[add]]#0 diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-control-dep.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-control-dep.pbtxt index 74adc38d87d..7b3462f37cd 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-control-dep.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-control-dep.pbtxt @@ -41,8 +41,7 @@ library { } # Drop the control dependency on arg for the node "test" # CHECK-LABEL: func @foo - # CHECK: tf_executor.island { - # CHECK-NEXT: "tf.Const"() + # CHECK: tf_executor.island wraps "tf.Const"() node_def { name: "test" op: "Const" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-as-function.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-as-function.pbtxt index 1bf5037a75f..5c4c23a67db 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-as-function.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-as-function.pbtxt @@ -6,12 +6,9 @@ # CHECK: func @main(%arg0: tensor<*x!tf.resource>, %arg1: tensor<*x!tf.resource>) -> (tensor, tensor) # CHECK: attributes {tf.entry_function = {inputs = "args_0, args_1", outputs = "rets_0_RetVal, rets_1_RetVal"}} { -# CHECK: %[[ISLAND_0:[0-9]]]:2 = tf_executor.island { -# CHECK: "tf.Const" -# CHECK: 
%[[ISLAND_1:[0-9]]]:2 = tf_executor.island { -# CHECK: "tf.Identity"(%[[ISLAND_0]]#0) -# CHECK: %[[ISLAND_2:[0-9]]]:2 = tf_executor.island { -# CHECK: "tf.StatefulPartitionedCall" +# CHECK: %[[ISLAND_0:[0-9]]]:2 = tf_executor.island wraps "tf.Const" +# CHECK: %[[ISLAND_1:[0-9]]]:2 = tf_executor.island wraps "tf.Identity"(%[[ISLAND_0]]#0) +# CHECK: %[[ISLAND_2:[0-9]]]:2 = tf_executor.island wraps "tf.StatefulPartitionedCall" # CHECK-SAME: f = @[[FUNC:[a-z0-9]*]] # CHECK: tf_executor.fetch %[[ISLAND_1]]#0, %[[ISLAND_2]]#0 : tensor, tensor # CHECK: func @[[FUNC]](%arg0: tensor<*xf32>, %arg1: tensor<*x!tf.resource>) -> tensor<*xf32> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-control-ret-diff-island.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-control-ret-diff-island.pbtxt index 9238ea92a20..fa095a19eff 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-control-ret-diff-island.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-control-ret-diff-island.pbtxt @@ -5,7 +5,7 @@ # FetchOp. # Match the island containing the "tf.Neg", capture the output -# CHECK: %[[ISLAND_0:[0-9]*]]:2 = tf_executor.island {{.*[[:space:]].*}} "tf.Neg" +# CHECK: %[[ISLAND_0:[0-9]*]]:2 = tf_executor.island wraps "tf.Neg" # Check that the tf.Neg control is passed to the fetch # CHECK: tf_executor.fetch {{.*}} %[[ISLAND_0]]#1 : tensor<*xf32>, !tf_executor.control diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-control-ret-same-island.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-control-ret-same-island.pbtxt index adad8b109b6..dbb1d14e331 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-control-ret-same-island.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-control-ret-same-island.pbtxt @@ -5,7 +5,7 @@ # FetchOp. # Match the island containing the "tf.Neg", capture the output -# CHECK: %[[ISLAND:[0-9]*]]:2 = tf_executor.island {{.*[[:space:]].*}} "tf.Neg" +# CHECK: %[[ISLAND:[0-9]*]]:2 = tf_executor.island wraps "tf.Neg" # Check that the tf.Neg data output and control are passed to the fetch # CHECK: tf_executor.fetch %[[ISLAND]]#0, %[[ISLAND]]#1 : tensor<*xf32>, !tf_executor.control diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-input-shapes.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-input-shapes.pbtxt new file mode 100644 index 00000000000..fc27e82d20e --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-input-shapes.pbtxt @@ -0,0 +1,110 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s + +# Verify that the _input_shapes attribute of the FunctionDef is respected. +# This also checks that the output type is correctly inferred based on +# that. 
+#CHECK: func @identity_function0(%arg0: tensor) -> tensor + +node { + name: "Placeholder" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_BOOL + } + } + experimental_debug_info { + } +} +node { + name: "Placeholder_1" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + experimental_debug_info { + } +} +node { + name: "If" + op: "If" + input: "Placeholder" + input: "Placeholder_1" + attr { + key: "Tcond" + value { + type: DT_BOOL + } + } + attr { + key: "Tin" + value { + list { + type: DT_INT32 + } + } + } + attr { + key: "Tout" + value { + list { + type: DT_INT32 + } + } + } + attr { + key: "else_branch" + value { + func { + name: "identity_function" + } + } + } + attr { + key: "then_branch" + value { + func { + name: "identity_function" + } + } + } + experimental_debug_info { + } +} +library { + function { + signature { + name: "identity_function" + input_arg { + name: "identity_input" + type: DT_INT32 + } + output_arg { + name: "identity_output" + type: DT_INT32 + } + } + ret { + key: "identity_output" + value: "identity_input" + } + attr { + key: "_input_shapes" + value { + list { + shape { + } + } + } + } + } +} +versions { + producer: 29 + min_consumer: 12 +} + diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-variable-shapes.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-variable-shapes.pbtxt new file mode 100644 index 00000000000..e75fe8c9d67 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-variable-shapes.pbtxt @@ -0,0 +1,177 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s + +# Verify that the _output_shapes attribute of ReadVariableOp's are used to get +# variable types. +# This also checks that the output type is correctly inferred based on +# that. 
+# CHECK: func @__inference_some_function_130(%arg0: tensor<*x!tf.resource>) -> tensor +# CHECK: tf.ReadVariableOp"(%arg0) {{.*}} : (tensor<*x!tf.resource>) -> tensor + + +node { + name : "Variable" + op : "VarHandleOp" + attr { + key : "shape" + value { + shape { + } + } + } + attr { + key : "dtype" + value { + type : DT_FLOAT + } + } + attr { + key : "shared_name" + value { + s: "Variable" + } + } + attr { + key : "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name : "StatefulPartitionedCall" + op : "StatefulPartitionedCall" + input : [ "Variable" ] + attr { + key : "f" + value { + func { + name: "__inference_some_function_13" + } + } + } + attr { + key : "config_proto" + value { + s: "\n\x07\n\x03GPU\x10\x00\n\x07\n\x03\x43PU\x10\x01\x32\x02J\x00\x38\x01" + } + } + attr { + key : "Tout" + value { + list { + type : [ DT_FLOAT ] + } + } + } + attr { + key : "_gradient_op_type" + value { + s: "PartitionedCall-29" + } + } + attr { + key : "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key : "Tin" + value { + list { + type : [ DT_RESOURCE ] + } + } + } +} +library { + function { + signature { + name: "__inference_some_function_13" + input_arg { + name : "readvariableop_resource" + type : DT_RESOURCE + } + output_arg { + name : "identity" + type : DT_FLOAT + } + is_stateful : true + control_output: [ "ReadVariableOp" ] + } + node_def { + name : "ReadVariableOp" + op : "ReadVariableOp" + input : [ "readvariableop_resource" ] + device: "/job:localhost/replica:0/task:0/device:CPU:0" + attr { + key : "dtype" + value { + type : DT_FLOAT + } + } + attr { + key : "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node_def { + name : "Identity" + op : "Identity" + input : [ "ReadVariableOp:value:0", "^ReadVariableOp" ] + attr { + key : "T" + value { + type : DT_FLOAT + } + } + attr { + key : "_output_shapes" + value { + list { + shape { + } + } + } + } + } + ret { + key : "identity" + value: "Identity:output:0" + } + attr { + key : "_input_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + control_ret { + key : "ReadVariableOp" + value: "ReadVariableOp" + } + arg_attr { + key : 0x00000000 + value { + } + } + } +} +versions { + producer : 148 + min_consumer : 12 +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-scalar-input.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-scalar-input.pbtxt index 37f7a876814..be059e0b2d2 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-scalar-input.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-scalar-input.pbtxt @@ -7,8 +7,7 @@ # CHECK: "tf.Placeholder.input"(%arg0) # CHECK: tf.Relu -# CHECK: %[[IDENTITY:[0-9]+]]:3 = tf_executor.island -# CHECK-NEXT: tf.Identity +# CHECK: %[[IDENTITY:[0-9]+]]:3 = tf_executor.island wraps "tf.IdentityN" # CHECK: fetch %[[IDENTITY]]#1, %[[IDENTITY]]#0 : tensor, tensor node { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/mlir_passthrough_op.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/mlir_passthrough_op.pbtxt new file mode 100644 index 00000000000..1df903d46ce --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/mlir_passthrough_op.pbtxt @@ -0,0 +1,101 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir %s | FileCheck %s + +# CHECK:"tf.MlirPassthroughOp" +# CHECK: mlir_module = "\0Afunc @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32> {\0A %add = 
\22tf.Add\22(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>\0A %ret = \22magic.op\22(%add, %add) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32>\0A return %ret : tensor<10x10xf32>\0A}\0A", name = "MlirPassthroughOp"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32> + +node { + name: "x" + op: "Placeholder" + attr { + key: "_user_specified_name" + value { + s: "x" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 10 + } + } + } + } +} +node { + name: "y" + op: "Placeholder" + attr { + key: "_user_specified_name" + value { + s: "y" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 10 + } + } + } + } +} +node { + name: "MlirPassthroughOp" + op: "MlirPassthroughOp" + input: "x" + input: "y" + attr { + key: "Tinputs" + value { + list { + type: DT_FLOAT + type: DT_FLOAT + } + } + } + attr { + key: "Toutputs" + value { + list { + type: DT_FLOAT + } + } + } + attr { + key: "mlir_module" + value { + s: "\nfunc @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32> {\n %add = \"tf.Add\"(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>\n %ret = \"magic.op\"(%add, %add) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32>\n return %ret : tensor<10x10xf32>\n}\n" + } + } +} +node { + name: "Identity" + op: "Identity" + input: "MlirPassthroughOp" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +versions { + producer: 148 +} + diff --git a/tensorflow/compiler/mlir/tensorflow/tests/isolate-placer.mlir b/tensorflow/compiler/mlir/tensorflow/tests/isolate-placer.mlir index 4566ffb507c..ac6838c9d58 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/isolate-placer.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/isolate-placer.mlir @@ -13,15 +13,12 @@ func @foo(%arg0: tensor) -> tensor { // The IsolatePlacerInspectionRequiredOpsPass adds Identities for each input/output of function-calling ops. // Capture the result of input to function call. -// CHECK: [[VARIABLE_REG:%[0-9]*]]:2 = tf_executor.island -// CHECK-NEXT: "tf.VarHandleOp"() +// CHECK: [[VARIABLE_REG:%[0-9]*]]:2 = tf_executor.island wraps "tf.VarHandleOp"() // Test for the presence of Identity op between input and function call. -// CHECK: [[IDENTITY_REG:%[0-9]*]]:2 = tf_executor.island -// CHECK-NEXT: "tf.Identity"([[VARIABLE_REG]]#0) +// CHECK: [[IDENTITY_REG:%[0-9]*]]:2 = tf_executor.island wraps "tf.Identity"([[VARIABLE_REG]]#0) -// CHECK: [[CALL_RESULT_REG:%[0-9]*]]:2 = tf_executor.island -// CHECK-NEXT: "tf.StatefulPartitionedCall"([[IDENTITY_REG]]#0) +// CHECK: [[CALL_RESULT_REG:%[0-9]*]]:2 = tf_executor.island wraps "tf.StatefulPartitionedCall"([[IDENTITY_REG]]#0) // CHECK-SAME: f = @[[FUNCTION:[a-zA-Z0-9_]*]] // Match the inserted Identity op for call output. 
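Note: the CHECK-line updates above (and throughout this change) switch from matching the expanded island region to the compact `wraps` printed form of tf_executor.island, which is used when an island holds exactly one operation whose results feed the yield directly. A minimal sketch of the two equivalent spellings, with the operand `%arg0` and the tensor types assumed purely for illustration:

// Expanded form, as the old CHECK lines matched it (operand/types illustrative):
%island:2 = tf_executor.island {
  %out = "tf.opA"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
  tf_executor.yield %out : tensor<1xf32>
}
// Compact form matched by the new `wraps` CHECK lines:
%island:2 = tf_executor.island wraps "tf.opA"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>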
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir new file mode 100644 index 00000000000..42721d2a406 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir @@ -0,0 +1,25 @@ +// RUN: tf-opt %s -test-tf-lower-tf | FileCheck %s --dump-input-on-failure + +// CHECK-LABEL: simple_pack +// CHECK-SAME: %[[ARG0:.*]]: tensor<3x5xf32>, %[[ARG1:.*]]: tensor<3x5xf32> +func @simple_pack(%arg0: tensor<3x5xf32>, %arg1: tensor<3x5xf32>) -> tensor<2x3x5xf32> { + // CHECK: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor} + // CHECK: %[[INP0:.*]] = "tf.ExpandDims"(%[[ARG0]], %[[AXIS]]) : (tensor<3x5xf32>, tensor) -> tensor<1x3x5xf32> + // CHECK: %[[INP1:.*]] = "tf.ExpandDims"(%[[ARG1]], %[[AXIS]]) : (tensor<3x5xf32>, tensor) -> tensor<1x3x5xf32> + // CHECK: "tf.ConcatV2"(%[[INP0]], %[[INP1]], %[[AXIS]]) {N = 2 : i64} : (tensor<1x3x5xf32>, tensor<1x3x5xf32>, tensor) -> tensor<2x3x5xf32> + + %0 = "tf.Pack"(%arg0, %arg1) {N = 2 : i64} : (tensor<3x5xf32>, tensor<3x5xf32>) -> tensor<2x3x5xf32> + return %0 : tensor<2x3x5xf32> +} + +// CHECK-LABEL: pack_with_unranked +// CHECK-SAME: %[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor<*xf32> +func @pack_with_unranked(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf32> { + // CHECK: %[[AXIS:.*]] = "tf.Const"() {value = dense<-2> : tensor} + // CHECK: %[[INP0:.*]] = "tf.ExpandDims"(%[[ARG0]], %[[AXIS]]) : (tensor, tensor) -> tensor + // CHECK: %[[INP1:.*]] = "tf.ExpandDims"(%[[ARG1]], %[[AXIS]]) : (tensor<*xf32>, tensor) -> tensor<*xf32> + // CHECK: "tf.ConcatV2"(%[[INP0]], %[[INP1]], %[[AXIS]]) {N = 2 : i64} : (tensor, tensor<*xf32>, tensor) -> tensor<*xf32> + + %0 = "tf.Pack"(%arg0, %arg1) {axis = -2 : i64, N = 2 : i64} : (tensor, tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/materialize_passthrough_op.mlir b/tensorflow/compiler/mlir/tensorflow/tests/materialize_passthrough_op.mlir new file mode 100644 index 00000000000..dd695f0b871 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/materialize_passthrough_op.mlir @@ -0,0 +1,15 @@ +// RUN: tf-opt -tf-materialize-passthrough-op %s | FileCheck %s --dump-input=fail + + +// Check that the MlirPassthroughOp is eliminated and replaced by its attached +// MLIR module. 
+ +// CHECK-LABEL: func @main +func @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32> { +// CHECK-SAME: (%[[ARG0:.*]]: tensor<10xf32>, %[[ARG1:.*]]: tensor<10xf32>) +// CHECK-NEXT: %[[ADD:.*]] = "tf.Add"(%[[ARG0]], %[[ARG1]]) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> +// CHECK-NEXT: %[[MAGIC:.*]] = "magic.op"(%[[ADD]], %[[ADD]]) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32> +// CHECK-NEXT: return %[[MAGIC]] + %0 = "tf.MlirPassthroughOp"(%arg0, %arg1) {Tinputs = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], Toutputs = ["tfdtype$DT_FLOAT"], device = "", mlir_module = "\0Afunc @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32> {\0A %add = \22tf.Add\22(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>\0A %ret = \22magic.op\22(%add, %add) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32>\0A return %ret : tensor<10x10xf32>\0A}\0A", name = "MlirPassthroughOp"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32> + return %0 : tensor<10x10xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/ref-while-loop.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/ref-while-loop.mlir new file mode 100644 index 00000000000..f4addb85967 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/ref-while-loop.mlir @@ -0,0 +1,23 @@ +// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s + +// Verify the ops generated when Ref type is used in a while loop. +func @main() { + // CHECK: op: "RefEnter" + // CHECK: op: "RefMerge" + // CHECK: op: "RefSwitch" + // CHECK: op: "RefExit" + // CHECK: op: "RefNextIteration" + %0:2 = "_tf.NextIteration.source"() {device = "", T = "tfdtype$DT_INT32"} : () -> (tensor<*x!tf.int32ref>, !_tf.control) loc("while/NextIteration") + %1:2 = "_tf.VariableV2"() {device = "", dtype = "tfdtype$DT_INT32", value = dense<0> : tensor} : () -> (tensor, !_tf.control) loc("Ref_Variable") + %2:2 = "_tf.Enter"(%1#0) {device = "", T = "tfdtype$DT_INT32", frame_name = "while/while_context", is_constant = false, parallel_iterations = 10} : (tensor) -> (tensor<*x!tf.int32ref>, !_tf.control) loc("while/Enter") + %3:3 = "_tf.Merge"(%2#0, %0#0) {device = "", N = 2, T = "tfdtype$DT_INT32"} : (tensor<*x!tf.int32ref>, tensor<*x!tf.int32ref>) -> (tensor<*x!tf.int32ref>, tensor, !_tf.control) loc("while/Merge") + %4:2 = "_tf.Const"(%3#2) {device = "", dtype = "tfdtype$DT_INT32", value = dense<10> : tensor} : (!_tf.control) -> (tensor, !_tf.control) loc("while/Less/y") + %5:2 = "_tf.Less"(%3#0, %4#0) {device = "", T = "tfdtype$DT_INT32"} : (tensor<*x!tf.int32ref>, tensor) -> (tensor<*xi1>, !_tf.control) loc("while/Less") + %6:2 = "_tf.LoopCond"(%5#0) {device = ""} : (tensor<*xi1>) -> (tensor, !_tf.control) loc("while/LoopCond") + %7:3 = "_tf.Switch"(%3#0, %6#0) {device = "", T = "tfdtype$DT_INT32", _class = ["loc:@while/Merge"]} : (tensor<*x!tf.int32ref>, tensor) -> (tensor<*x!tf.int32ref>, tensor<*x!tf.int32ref>, !_tf.control) loc("while/Switch") + %8:2 = "_tf.Exit"(%7#1) {device = "", T = "tfdtype$DT_INT32"} : (tensor<*x!tf.int32ref>) -> (tensor<*x!tf.int32ref>, !_tf.control) loc("while/Exit") + %10:2 = "_tf.Const"(%7#2) {device = "", dtype = "tfdtype$DT_INT32", value = dense<1> : tensor} : (!_tf.control) -> (tensor, !_tf.control) loc("while/Add/y") + %11:2 = "_tf.AssignAdd"(%7#0, %10#0) {device = "", T = "tfdtype$DT_INT32"} : (tensor<*x!tf.int32ref>, tensor) -> (tensor<*x!tf.int32ref>, !_tf.control) loc("while/Add") + %12 = 
"_tf.NextIteration.sink"(%11#0) {device = "", T = "tfdtype$DT_INT32"} : (tensor<*x!tf.int32ref>) -> !_tf.control loc("while/NextIteration") + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/sink_constant.mlir b/tensorflow/compiler/mlir/tensorflow/tests/sink_constant.mlir new file mode 100644 index 00000000000..10ff24a5336 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/sink_constant.mlir @@ -0,0 +1,41 @@ +// RUN: tf-opt %s -tf-device-constant-sinking | FileCheck %s --dump-input=fail + +// CHECK-LABEL: func @sink_const +func @sink_const(%arg0 : tensor<16xf32>) -> (tensor<16xf32>, tensor) { + // Verify that the constant are sunk in the tf_device.launch region using them + // and removed if no other use is left. + + // Only the 2.0 and 3.0 constants are removed, the 4.0 has a use in the return + // CHECK-NOT:"tf.Const"2.0 + // CHECK-NOT:"tf.Const"3.0 + %0 = "tf.Const"() {value = dense<2.000000e+00> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<3.000000e+00> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<4.000000e+00> : tensor} : () -> tensor + %3 = tf_executor.graph { + %res, %ctl = tf_executor.island { + %3 = "tf_device.launch"() ({ + + // In the device region, check that the 3 constants are materialized and + // remapped to the uses. + // CHECK: tf_device.launch + // CHECK-DAG: %[[CST2:.*]] = "tf.Const"{{.*}}2.0 + // CHECK-DAG: %[[CST3:.*]] = "tf.Const"{{.*}}3.0 + // CHECK-DAG: %[[CST4:.*]] = "tf.Const"{{.*}}4.0 + // CHECK-NOT:"tf.Const" + // CHECK: %[[MUL1:.*]] = "tf.Mul"(%arg0, %[[CST2]]) + // CHECK-NEXT: %[[MUL2:.*]] = "tf.Mul"(%[[MUL1]], %[[CST2]]) + // CHECK-NEXT: %[[MUL3:.*]] = "tf.Mul"(%[[MUL2]], %[[CST3]]) + // CHECK-NEXT: = "tf.Mul"(%[[MUL3]], %[[CST4]]) + %3 = "tf.Mul"(%arg0, %0) : (tensor<16xf32>, tensor) -> tensor<16xf32> + %4 = "tf.Mul"(%3, %0) : (tensor<16xf32>, tensor) -> tensor<16xf32> + %5 = "tf.Mul"(%4, %1) : (tensor<16xf32>, tensor) -> tensor<16xf32> + %6 = "tf.Mul"(%5, %2) : (tensor<16xf32>, tensor) -> tensor<16xf32> + "tf_device.return"(%6) : (tensor<16xf32>) -> () + }) {device = "tpu0"} : () -> tensor<16xf32> + tf_executor.yield %3 : tensor<16xf32> + } + tf_executor.fetch %res : tensor<16xf32> + } + return %3, %2 : tensor<16xf32>, tensor +} + diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir index dd6d77f7816..b702f5fe88c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir @@ -83,6 +83,15 @@ func @testBitcast(%arg0: tensor<3x4x!tf.uint16>) -> tensor<3x4x!tf.quint16> { // ----- +// CHECK-LABEL: func @testReverseV2 +func @testReverseV2(%arg0: tensor<2x4x3x!tf.uint8>, %arg1: tensor<1xi32>) -> tensor<2x4x3x!tf.uint8> { + // CHECK: tf.ReverseV2 + %0 = "tf.ReverseV2"(%arg0, %arg1) : (tensor<2x4x3x!tf.uint8>, tensor<1xi32>) -> tensor<2x4x3x!tf.uint8> + return %0 : tensor<2x4x3x!tf.uint8> +} + +// ----- + func @testIdentityWrongType(%arg0: tensor<4x2x!tf.string>) -> tensor<4x2x!tf.stringref> { // expected-error @+1 {{requires all operands to be either same as or ref type of results}} %0 = "tf.Identity"(%arg0) : (tensor<4x2x!tf.string>) -> tensor<4x2x!tf.stringref> @@ -459,6 +468,37 @@ func @testInvalidFakeQuantWithMinMaxVarsWrongMaxType(tensor<8x8x8x8xf32>, tensor // ----- +// Test valid tf.FakeQuantWithMinMaxVarsPerChannel +// CHECK-LABEL: func @FakeQuantWithMinMaxVarsPerChannel +func @FakeQuantWithMinMaxVarsPerChannel(tensor<1x2x3x8xf32>, tensor<8xf32>, tensor<8xf32>) -> 
tensor<1x2x3x8xf32> { +^bb0(%arg0: tensor<1x2x3x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>): + // CHECK: "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %arg1, %arg2) : (tensor<1x2x3x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<1x2x3x8xf32> + %0 = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %arg1, %arg2) : (tensor<1x2x3x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<1x2x3x8xf32> + return %0 : tensor<1x2x3x8xf32> +} + +// ----- + +// Test invalid tf.FakeQuantWithMinMaxVarsPerChannel +func @FakeQuantWithMinMaxVarsPerChannel_ranked_inputs(tensor, tensor<8xf32>, tensor<8xf32>) -> tensor { +^bb0(%arg0: tensor, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>): + // expected-error @+1 {{requires inputs to be at least 1d float tensor}} + %0 = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %arg1, %arg2) : (tensor, tensor<8xf32>, tensor<8xf32>) -> tensor + return %0 : tensor +} + +// ----- + +// Test invalid tf.FakeQuantWithMinMaxVarsPerChannel +func @FakeQuantWithMinMaxVarsPerChannel_mismatch_min_max(tensor<1x2x3x8xf32>, tensor<1xf32>, tensor<8xf32>) -> tensor<1x2x3x8xf32> { +^bb0(%arg0: tensor<1x2x3x8xf32>, %arg1: tensor<1xf32>, %arg2: tensor<8xf32>): + // expected-error @+1 {{requires min and max to have same size as last dimension of inputs}} + %0 = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %arg1, %arg2) : (tensor<1x2x3x8xf32>, tensor<1xf32>, tensor<8xf32>) -> tensor<1x2x3x8xf32> + return %0 : tensor<1x2x3x8xf32> +} + +// ----- + // Test valid tf.FusedBatchNorm // CHECK-LABEL: func @testFusedBatchNorm func @testFusedBatchNorm(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> { @@ -944,25 +984,25 @@ func @testLess(tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> { // ----- // Test valid tf.ConcatV2 -func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor<1xi32>) -> tensor { - // CHECK: %0 = "tf.ConcatV2"(%arg0, %arg0, %arg1) {N = 2 : i64} : (tensor<8x16xf32>, tensor<8x16xf32>, tensor<1xi32>) -> tensor - %0 = "tf.ConcatV2"(%arg, %arg, %axis) {N = 2: i64} : (tensor<8x16xf32>, tensor<8x16xf32>, tensor<1xi32>) -> tensor +func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor) -> tensor { + // CHECK: %0 = "tf.ConcatV2"(%arg0, %arg0, %arg1) {N = 2 : i64} : (tensor<8x16xf32>, tensor<8x16xf32>, tensor) -> tensor + %0 = "tf.ConcatV2"(%arg, %arg, %axis) {N = 2: i64} : (tensor<8x16xf32>, tensor<8x16xf32>, tensor) -> tensor return %0 : tensor } // ----- // tf.ConcatV2 with wrong 'axis' element type -func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor<1xf32>) -> tensor { +func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor) -> tensor { // expected-error @+1 {{operand #2 must be tensor of 32/64-bit integer values}} - %0 = "tf.ConcatV2"(%arg, %arg, %axis) {N = 2: i64} : (tensor<8x16xf32>, tensor<8x16xf32>, tensor<1xf32>) -> tensor + %0 = "tf.ConcatV2"(%arg, %arg, %axis) {N = 2: i64} : (tensor<8x16xf32>, tensor<8x16xf32>, tensor) -> tensor return %0 : tensor } // ----- // tf.ConcatV2 missing required 'axis' operand -func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor<1xi32>) -> tensor { +func @testConcatV2() -> tensor { // expected-error @+1 {{expected 1 or more operands}} %0 = "tf.ConcatV2"() {N = 0: i64} : () -> tensor return %0 : tensor @@ -971,9 +1011,165 @@ func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor<1xi32>) -> tensor, %axis: tensor<1xi32>) -> tensor { +func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor) -> tensor { // expected-error @+1 {{attribute 'N' failed to satisfy constraint: 64-bit integer attribute 
whose minimal value is 2}} - %0 = "tf.ConcatV2"(%arg, %axis) {N = 1: i64} : (tensor<8x16xf32>, tensor<1xi32>) -> tensor + %0 = "tf.ConcatV2"(%arg, %axis) {N = 1: i64} : (tensor<8x16xf32>, tensor) -> tensor return %0 : tensor } +// ----- + +func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor) -> tensor { + // expected-error @+1 {{requires attribute 'N' to match the number of inputs; expected: 2 Found: 3}} + %0 = "tf.ConcatV2"(%arg, %arg, %axis) {N = 3: i64} : (tensor<8x16xf32>, tensor<8x16xf32>, tensor) -> tensor + return %0 : tensor +} + +// ----- + +func @testAll(%arg0: tensor<2x2xi1>, %arg1: tensor) -> tensor { + %0 = "tf.All"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor) -> tensor + return %0 : tensor + + // CHECK-LABEL: testAll + // CHECK: %0 = "tf.All"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor) -> tensor +} + +// ----- + +func @testAll64(%arg0: tensor<2x2xi1>, %arg1: tensor) -> tensor { + %0 = "tf.All"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor) -> tensor + return %0 : tensor + + // CHECK-LABEL: testAll64 + // CHECK: %0 = "tf.All"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor) -> tensor +} + +// ----- + +func @testAllFloat(%arg0: tensor<2x2xi1>, %arg1: tensor) -> tensor { + // expected-error @+1 {{'tf.All' op operand #1 must be tensor of 32/64-bit integer values}} + %0 = "tf.All"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor) -> tensor + return %0 : tensor +} + +// ----- + +func @testAllI32(%arg0: tensor<2x2xi32>, %arg1: tensor) -> tensor { + // expected-error @+1 {{'tf.All' op operand #0 must be tensor of 1-bit integer values}} + %0 = "tf.All"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi32>, tensor) -> tensor + return %0 : tensor +} + +// ----- + +func @testEqualOpIncompatibleShapeTrue(%x: tensor<5xf32>, %y: tensor<4xf32>) -> tensor<5xi1> { + // expected-error @+1 {{operands don't have broadcast-compatible shapes}} + %0 = "tf.Equal"(%x, %y) {incompatible_shape_error = true} : (tensor<5xf32>, tensor<4xf32>) -> tensor<5xi1> + return %0 : tensor<5xi1> +} + +// ----- + +// CHECK-LABEL: testEqualOpIncompatibleShapeFalse +func @testEqualOpIncompatibleShapeFalse(%x: tensor<5xf32>, %y: tensor<4xf32>) -> tensor<*xi1> { + // CHECK: tf.Equal + %0 = "tf.Equal"(%x, %y) {incompatible_shape_error = false} : (tensor<5xf32>, tensor<4xf32>) -> tensor<*xi1> + return %0 : tensor<*xi1> +} + +// ----- + +func @testNotEqualOpIncompatibleShapeTrue(%x: tensor<5xf32>, %y: tensor<4xf32>) -> tensor<5xi1> { + // expected-error @+1 {{operands don't have broadcast-compatible shapes}} + %0 = "tf.NotEqual"(%x, %y) {incompatible_shape_error = true} : (tensor<5xf32>, tensor<4xf32>) -> tensor<5xi1> + return %0 : tensor<5xi1> +} + +// ----- + +// CHECK-LABEL: testNotEqualOpIncompatibleShapeFalse +func @testNotEqualOpIncompatibleShapeFalse(%x: tensor<5xf32>, %y: tensor<4xf32>) -> tensor<*xi1> { + // CHECK: tf.NotEqual + %0 = "tf.NotEqual"(%x, %y) {incompatible_shape_error = false} : (tensor<5xf32>, tensor<4xf32>) -> tensor<*xi1> + return %0 : tensor<*xi1> +} + +// ----- + +func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor<1x1xi32>) -> tensor<*xf32> { // expected-error @+1 {{requires axis to be of scalar type (or vector type for older versions)}} + %0 = "tf.ConcatV2"(%arg, %arg, %axis) {N = 2: i64} : (tensor<8x16xf32>, tensor<8x16xf32>, tensor<1x1xi32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor<1x1xi32>) -> tensor<*xf32> { + // expected-error @+1 
{{requires axis to be of scalar type (or vector type for older versions)}} + %0 = "tf.Concat"(%axis, %arg, %arg) {N = 2: i64} : (tensor<1x1xi32>, tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +func @testConcatV2(%arg0: tensor<8x16xf32>, %arg1: tensor<8xf32>, %axis: tensor) -> tensor<*xf32> { + // expected-error @+1 {{operand type 'tensor<8xf32>' is not compatible with preceding operands; expected rank: 2}} + %0 = "tf.ConcatV2"(%arg0, %arg1, %axis) {N = 2: i64} : (tensor<8x16xf32>, tensor<8xf32>, tensor) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +// Valid Concat operation with concat axis 1 or -1. +func @testConcatV2(%arg0: tensor<8x16xf32>, %arg1: tensor<8x8xf32>, %axis: tensor) -> tensor<*xf32> { + %0 = "tf.ConcatV2"(%arg0, %arg1, %axis) {N = 2: i64} : (tensor<8x16xf32>, tensor<8x8xf32>, tensor) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +func @testConcatV2(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>, %axis: tensor) -> tensor<*xf32> { + // expected-error @+1 {{operand type 'tensor<16x8xf32>' is not compatible with preceding operands; expected dimension at index 1: 16}} + %0 = "tf.ConcatV2"(%arg0, %arg1, %axis) {N = 2: i64} : (tensor<8x16xf32>, tensor<16x8xf32>, tensor) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +// Valid Concat operation with concat axis 1 or -1. +func @testConcatV2(%arg0: tensor<8x8xf32>, %arg1: tensor, %arg2: tensor<*xf32>, %arg3: tensor<8x?xf32>, %axis: tensor) -> tensor<*xf32> { + %0 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %arg3, %axis) {N = 4: i64} : (tensor<8x8xf32>, tensor, tensor<*xf32>, tensor<8x?xf32>, tensor) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +// Valid Pack operation. +func @testPack(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>) -> tensor<*xf32> { + %0 = "tf.Pack"(%arg0, %arg1) {axis = 1 : i64, N = 2: i64} : (tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +func @testPack(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>) -> tensor<*xf32> { + // expected-error @+1 {{requires attribute 'N' to match the number of inputs; expected: 2 Found: 1}} + %0 = "tf.Pack"(%arg0, %arg1) {axis = 1 : i64, N = 1: i64} : (tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +func @testPack(%arg0: tensor<4x8xf32>, %arg1: tensor<4x2xf32>) -> tensor<*xf32> { + // expected-error @+1 {{operand type 'tensor<4x2xf32>' is not compatible with preceding operands; expected dimension at index 1: 8}} + %0 = "tf.Pack"(%arg0, %arg1) {axis = 1 : i64, N = 2: i64} : (tensor<4x8xf32>, tensor<4x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +func @testPack(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>, %axis: tensor) -> tensor<*xf32> { + // expected-error @+1 {{attribute 'axis' should be within range [-3, 3); actual value: 3}} + %0 = "tf.Pack"(%arg0, %arg1) {axis = 3 : i64, N = 2: i64} : (tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir index 2890656c013..fca724e196a 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir @@ -271,6 +271,39 @@ func @merge_with_variant_type(%arg0: tensor, %arg1: tensor>> } +// CHECK-LABEL: func @merge_with_ref_type +func @merge_with_ref_type(%arg0: 
tensor<4x!tf.f32ref>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + %result = tf_executor.graph { + +// CHECK: tf_executor.Merge{{.*}}(tensor<4x!tf.f32ref>, tensor<4xf32>) -> (tensor<4xf32>, tensor, !tf_executor.control) + %value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<4x!tf.f32ref>, tensor<4xf32>) -> (tensor<4xf32>, tensor, !tf_executor.control) + tf_executor.fetch %value : tensor<4xf32> + } + return %result : tensor<4xf32> +} + +// CHECK-LABEL: func @merge_with_dynamic_shape +func @merge_with_dynamic_shape(%arg0: tensor<2xf32>, %arg1: tensor<3xf32>) -> tensor { + %result = tf_executor.graph { + +// CHECK: tf_executor.Merge{{.*}}(tensor<2xf32>, tensor<3xf32>) -> (tensor, tensor, !tf_executor.control) + %value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<2xf32>, tensor<3xf32>) -> (tensor, tensor, !tf_executor.control) + tf_executor.fetch %value : tensor + } + return %result : tensor +} + +// CHECK-LABEL: func @merge_with_unranked_shape +func @merge_with_unranked_shape(%arg0: tensor<2xf32>, %arg1: tensor<3xf32>) -> tensor<*xf32> { + %result = tf_executor.graph { + +// CHECK: tf_executor.Merge{{.*}}(tensor<2xf32>, tensor<3xf32>) -> (tensor<*xf32>, tensor, !tf_executor.control) + %value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<2xf32>, tensor<3xf32>) -> (tensor<*xf32>, tensor, !tf_executor.control) + tf_executor.fetch %value : tensor<*xf32> + } + return %result : tensor<*xf32> +} + // CHECK-LABEL: func @enter(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*xf32> { func @enter(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*xf32> { %result = tf_executor.graph { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir index ee3d2b91732..5803cc7b516 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir @@ -490,7 +490,7 @@ func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor<*xf32> { %true, %false, %ctlSwitch = tf_executor.Switch %arg0, %arg1 : tensor<*xf32> %value, %idx, %ctlMerge = "tf_executor.Merge"(%true, %false, %arg1) : (tensor<*xf32>, tensor<*xf32>, tensor) -> (tensor<*xf32>, tensor, !tf_executor.control) -// expected-error@-1 {{'tf_executor.Merge' op expects all operands to be broadcastable but got 'tensor<*xf32>' vs 'tensor'}} +// expected-error@-1 {{'tf_executor.Merge' op expects all operands to be broadcastable with output type but got 'tensor' vs 'tensor<*xf32>'}} tf_executor.fetch %value : tensor<*xf32> } return %result : tensor<*xf32> @@ -502,7 +502,7 @@ func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor<*xf32> { func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor<4xf32>) -> tensor<8xf32> { %result = tf_executor.graph { %value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<*xf32>, tensor<4xf32>) -> (tensor<8xf32>, tensor, !tf_executor.control) -// expected-error@-1 {{'tf_executor.Merge' op expects all operands to be broadcastable but got 'tensor<8xf32>' vs 'tensor<4xf32>'}} +// expected-error@-1 {{'tf_executor.Merge' op expects all operands to be broadcastable with output type but got 'tensor<4xf32>' vs 'tensor<8xf32>'}} tf_executor.fetch %value : tensor<8xf32> } return %result : tensor<8xf32> @@ -514,7 +514,7 @@ func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor<4xf32>) -> tensor<8xf32> func @invalid_merge(%arg0: tensor<*x!tf.variant>, %arg1: tensor<4x!tf.variant>) -> 
tensor<8x!tf.variant> { %result = tf_executor.graph { %value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<*x!tf.variant>, tensor<4x!tf.variant>) -> (tensor<8x!tf.variant>, tensor, !tf_executor.control) -// expected-error@-1 {{'tf_executor.Merge' op expects all operands to be broadcastable but got 'tensor<8x!tf.variant>' vs 'tensor<4x!tf.variant>'}} +// expected-error@-1 {{'tf_executor.Merge' op expects all operands to be broadcastable with output type but got 'tensor<4x!tf.variant>' vs 'tensor<8x!tf.variant>'}} tf_executor.fetch %value : tensor<8x!tf.variant> } return %result : tensor<8x!tf.variant> @@ -522,6 +522,18 @@ func @invalid_merge(%arg0: tensor<*x!tf.variant>, %arg1: tensor<4x!tf.variant>) // ----- +// Check that if result is a ref type, all operands need to be ref too. +func @inavlid_merge(%arg0: tensor<4x!tf.f32ref>, %arg1: tensor<4xf32>) -> tensor<4x!tf.f32ref> { + %result = tf_executor.graph { + %value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<4x!tf.f32ref>, tensor<4xf32>) -> (tensor<4x!tf.f32ref>, tensor, !tf_executor.control) + // expected-error@-1 {{'tf_executor.Merge' op expects same operand and output element type but got 'tensor<4xf32>' vs 'tensor<4x!tf.f32ref>'}} + tf_executor.fetch %value : tensor<4x!tf.f32ref> + } + return %result : tensor<4x!tf.f32ref> +} + +// ----- + // Check that merge data inputs can't appear after control input. func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor<*xf32> { %result = tf_executor.graph { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir index dc2f60b6441..e91e772d47f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir @@ -11,9 +11,9 @@ module { %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "tpu0", func = @tpu0_func} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) - // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[A_SHAPE_OUTPUT]]) - // CHECK-SAME: _tpu_replicate = "cluster0" - // CHECK-SAME: module + // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) + // CHECK-SAME: NumDynamicShapes = 1 + // CHECK-SAME: mlir_module // CHECK-SAME: func @main // CHECK-SAME: tf.B // CHECK-NOT: func = @tpu0_func @@ -68,9 +68,8 @@ module { %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "tpu0", func = @tpu0_func} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) - // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[A_SHAPE_OUTPUT]]) - // CHECK-SAME: _tpu_replicate = "cluster0" - // CHECK-SAME: module + // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) + // CHECK-SAME: mlir_module // CHECK-SAME: func @main // CHECK-SAME: tf.B // CHECK-SAME: func @nested_func @@ -112,9 +111,8 @@ module { %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "tpu0", func = @tpu0_func} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) - // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[A_SHAPE_OUTPUT]]) - // CHECK-SAME: _tpu_replicate = "cluster0" - // CHECK-SAME: module + // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) + // CHECK-SAME: mlir_module // CHECK-SAME: func @main // CHECK-SAME: tf.B // CHECK-SAME: func @referenced_func @@ 
-155,9 +153,8 @@ module { %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "tpu0", func = @tpu0_func} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) - // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[A_SHAPE_OUTPUT]]) - // CHECK-SAME: _tpu_replicate = "cluster0" - // CHECK-SAME: module + // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) + // CHECK-SAME: mlir_module // CHECK-SAME: func @main // CHECK-SAME: tf.B // CHECK-SAME: @referenced_func1 @@ -206,9 +203,8 @@ module { %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "tpu0", func = @tpu0_func} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) - // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[A_SHAPE_OUTPUT]]) - // CHECK-SAME: _tpu_replicate = "cluster0" - // CHECK-SAME: module + // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) + // CHECK-SAME: mlir_module // CHECK-SAME: func @main // CHECK-SAME: tf.B // CHECK-COUNT-2: call @referenced_func @@ -251,9 +247,8 @@ module { %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "tpu0", func = @tpu0_func0} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) - // CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[A_SHAPE_OUTPUT]]) - // CHECK-SAME: _tpu_replicate = "cluster0" - // CHECK-SAME: module + // CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) + // CHECK-SAME: mlir_module // CHECK-SAME: func @main // CHECK-SAME: tf.B // CHECK-NOT: func = @tpu0_func0 @@ -263,9 +258,8 @@ module { %2 = "tf_device.launch_func"(%1) {_tpu_replicate = "cluster1", device = "tpu0", func = @tpu0_func1} : (tensor) -> tensor // CHECK: %[[EXECUTE0_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[EXECUTE0_OUTPUT]]) - // CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[EXECUTE0_SHAPE_OUTPUT]]) - // CHECK-SAME: _tpu_replicate = "cluster1" - // CHECK-SAME: module + // CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[EXECUTE0_SHAPE_OUTPUT]]) + // CHECK-SAME: mlir_module // CHECK-SAME: func @main // CHECK-SAME: tf.D // CHECK-NOT: func = @tpu0_func1 @@ -303,9 +297,8 @@ module { %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "tpu0", func = @tpu0_func} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) - // CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[A_SHAPE_OUTPUT]]) - // CHECK-SAME: _tpu_replicate = "cluster0" - // CHECK-SAME: module + // CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) + // CHECK-SAME: mlir_module // CHECK-SAME: func @main // CHECK-SAME: tf.B // CHECK-NOT: func = @tpu0_func @@ -315,9 +308,8 @@ module { %2 = "tf_device.launch_func"(%1) {_tpu_replicate = "cluster1", device = "tpu0", func = @tpu0_func} : (tensor) -> tensor // CHECK: %[[EXECUTE0_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[EXECUTE0_OUTPUT]]) - // CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[EXECUTE0_SHAPE_OUTPUT]]) - // CHECK-SAME: _tpu_replicate = "cluster1" - // CHECK-SAME: module + // CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[EXECUTE0_SHAPE_OUTPUT]]) + // CHECK-SAME: mlir_module // CHECK-SAME: func @main // CHECK-SAME: tf.B // CHECK-NOT: func = @tpu0_func @@ -351,9 +343,8 @@ module { %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "tpu0", func 
= @tpu0_func} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) - // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[A_SHAPE_OUTPUT]]) - // CHECK-SAME: _tpu_replicate = "cluster0" - // CHECK-SAME: module + // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) + // CHECK-SAME: mlir_module // CHECK-SAME: func @main // CHECK-SAME: tf.B // CHECK-SAME: func @referenced_func2 @@ -404,3 +395,44 @@ module { } +// ----- + + +// Tests that TPUCompilationResult operations are properly rewritten + +// CHECK-LABEL: func @tpu_compilation_result +func @tpu_compilation_result(%arg0: tensor) -> (tensor, tensor, tensor) { + + // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir" + // CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute" + %1 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "tpu0", func = @tpu0_func} : (tensor) -> tensor + + %compile_result = "tf.TPUCompilationResult"() {_tpu_replicate = "cluster0"} : () -> tensor + %compile_result2 = "tf.TPUCompilationResult"() {_tpu_replicate = "cluster0"} : () -> tensor + + // CHECK: return %[[EXECUTE_OUTPUT]], %[[COMPILE_OUTPUT]]#0, %[[COMPILE_OUTPUT]]#0 + return %1, %compile_result, %compile_result2 : tensor, tensor, tensor +} + +func @tpu0_func(%arg0: tensor) -> tensor { + %0 = "tf.B"(%arg0) : (tensor) -> tensor + return %0 : tensor +} + + +// ----- + +// Tests that TPUReplicatedInput and TPUReplicatedOutput operations are properly rewritten + +func @main(%arg0 : tensor<0xf32>, %arg1 : tensor<0xf32>) -> tensor<0xf32> { + // CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%arg0, %arg1 + %0 = "tf.TPUReplicatedInput"(%arg0) {N = 1 : i64} : (tensor<0xf32>) -> tensor<0xf32> + %1 = "tf.TPUReplicatedInput"(%arg1) {N = 1 : i64} : (tensor<0xf32>) -> tensor<0xf32> + %2 = "tf_device.launch_func"(%0, %1) {device = "", _tpu_replicate = "cluster", func = @_func} : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32> + %3 = "tf.TPUReplicatedOutput"(%2) {num_replicas = 1 : i64} : (tensor<0xf32>) -> tensor<0xf32> + return %3 : tensor<0xf32> +} +func @_func(%arg0: tensor<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> { + %0 = "tf.Const"() {value = dense<3.000000e+00> : tensor<0xf32>} : () -> tensor<0xf32> + return %0 : tensor<0xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc index a186837bb79..4655f0f8e41 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc @@ -23,10 +23,11 @@ limitations under the License. 
namespace mlir { namespace TFTPU { -void createTPUBridge(PassManager &bridge) { +void createTPUBridge(OpPassManager &bridge) { bridge.addPass(tf_executor::CreateTFExecutorIslandCoarseningPass()); bridge.addPass(createCanonicalizerPass()); bridge.addPass(CreateTPUClusterFormationPass()); + bridge.addPass(tf_executor::CreateTFExecutorConstantSinkingPass()); bridge.addPass(TFDevice::CreateClusterOutliningPass()); bridge.addPass(CreateTPURewritePass()); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td index 0653c1d109e..ebdc11b8fbf 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td @@ -93,11 +93,13 @@ def LogOfSoftmax : Pat<(TF_LogOp (TF_SoftmaxOp $arg)), (TF_LogSoftmaxOp $arg)>; def LogicalNotNested : Pat<(TF_LogicalNotOp (TF_LogicalNotOp $arg)), (replaceWithValue $arg)>; -def LogicalNotOfEqual : Pat<(TF_LogicalNotOp (TF_EqualOp $arg0, $arg1)), - (TF_NotEqualOp $arg0, $arg1)>; +def LogicalNotOfEqual : Pat< + (TF_LogicalNotOp (TF_EqualOp $arg0, $arg1, $shape_error)), + (TF_NotEqualOp $arg0, $arg1, $shape_error)>; -def LogicalNotOfNotEqual : Pat<(TF_LogicalNotOp (TF_NotEqualOp $arg0, $arg1)), - (TF_EqualOp $arg0, $arg1)>; +def LogicalNotOfNotEqual : Pat< + (TF_LogicalNotOp (TF_NotEqualOp $arg0, $arg1, $shape_error)), + (TF_EqualOp $arg0, $arg1, $shape_error)>; def LogicalNotOfGreater : Pat<(TF_LogicalNotOp (TF_GreaterOp $arg0, $arg1)), (TF_LessEqualOp $arg0, $arg1)>; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc new file mode 100644 index 00000000000..c65544ed5e1 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc @@ -0,0 +1,97 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" + +#include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir +#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace TF { +namespace { + +// Infers ExpandDims op output type for the given input type `ty` and dimension +// to expand at the given `axis`. +Type InferExpandDimsType(Type ty, int64_t axis, Builder *builder) { + auto ranked_ty = ty.dyn_cast(); + + // Unranked type. + if (!ranked_ty) return ty; + + auto shape = llvm::to_vector<4>(ranked_ty.getShape()); + if (axis < 0) axis += ranked_ty.getRank() + 1; + + shape.insert(shape.begin() + axis, 1); + return builder->getTensorType(shape, ranked_ty.getElementType()); +} + +// Lowers Pack op to ConcatV2 op after changing shape of the inputs with +// ExpandDims op. 
+// +// Sample result with 2 inputs to pack: +// +// %axis = "tf.Const"() {value = dense<1> : tensor} +// %inp0 = "tf.ExpandDims"(%operand0, %axis): tensor<2xf32> -> tensor<2x1xf32> +// %inp1 = "tf.ExpandDims"(%operand1, %axis): tensor<2xf32> -> tensor<2x1xf32> +// %result = "tf.ConcatV2"(%operand0, %operand1, %axis) { N = 2 : i64 }: +// +class LowerPackOp : public OpRewritePattern { + public: + explicit LowerPackOp(MLIRContext *context) + : OpRewritePattern(context) {} + + PatternMatchResult matchAndRewrite(TF::PackOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto axis_value = rewriter.create( + loc, DenseElementsAttr::get( + rewriter.getTensorType({}, rewriter.getIntegerType(64)), + op.axis())); + int64_t axis = op.axis().getLimitedValue(); + + Type prev_input_ty, inferred_ty; + SmallVector expanded_inputs; + expanded_inputs.reserve(op.N().getLimitedValue()); + for (Value *input : op.values()) { + // If input type is different than the previous input type, infer the + // output type. Otherwise, use the already inferred output type from the + // previous iteration. + Type input_ty = input->getType(); + if (input_ty != prev_input_ty) { + inferred_ty = InferExpandDimsType(input_ty, axis, &rewriter); + prev_input_ty = input_ty; + } + expanded_inputs.push_back(rewriter.create( + loc, inferred_ty, input, axis_value)); + } + + rewriter.replaceOpWithNewOp( + op, op.getType(), expanded_inputs, axis_value, + op.getAttrOfType("N")); + return matchSuccess(); + } +}; + +} // namespace + +void PopulateLoweringTFPatterns(MLIRContext *context, + OwningRewritePatternList *patterns) { + patterns->insert(context); +} + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h new file mode 100644 index 00000000000..4b85ac3b46a --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h @@ -0,0 +1,33 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LOWER_TF_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LOWER_TF_H_ + +#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir +#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir + +namespace mlir { +namespace TF { + +// Populates TensorFlow lowering patterns to lower some of the TensorFlow +// operations that can be represented using other TensorFlow operations. 
+void PopulateLoweringTFPatterns(MLIRContext *context, + OwningRewritePatternList *patterns); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LOWER_TF_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf_pass.cc new file mode 100644 index 00000000000..309d0147bc0 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf_pass.cc @@ -0,0 +1,42 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir +#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" + +namespace mlir { +namespace TF { +namespace { + +// Lowers some of the TensorFlow operations that can be represented using other +// TensorFlow operations. +struct LowerTF : public FunctionPass { + void runOnFunction() override { + // Add lowering patterns to the list. + OwningRewritePatternList patterns; + mlir::TF::PopulateLoweringTFPatterns(&getContext(), &patterns); + + applyPatternsGreedily(getFunction(), patterns); + } +}; + +} // namespace +} // namespace TF +} // namespace mlir + +static mlir::PassRegistration pass( + "test-tf-lower-tf", + "Lowers some of the TensorFlow ops to other TensorFlow ops"); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/materialize_mlir_passthrough_op.cc b/tensorflow/compiler/mlir/tensorflow/transforms/materialize_mlir_passthrough_op.cc new file mode 100644 index 00000000000..0f74fda2336 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/materialize_mlir_passthrough_op.cc @@ -0,0 +1,109 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/Block.h" // TF:local_config_mlir +#include "mlir/IR/Diagnostics.h" // TF:local_config_mlir +#include "mlir/IR/Module.h" // TF:local_config_mlir +#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir +#include "mlir/IR/Operation.h" // TF:local_config_mlir +#include "mlir/IR/Types.h" // TF:local_config_mlir +#include "mlir/IR/Value.h" // TF:local_config_mlir +#include "mlir/Parser.h" // TF:local_config_mlir +#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +#define DEBUG_TYPE "tf-materialize-passthrough-op" + +namespace mlir { +namespace { + +class MaterializePassthroughOpPass + : public FunctionPass { + public: + void runOnFunction() override; +}; + +void MaterializePassthroughOpPass::runOnFunction() { + getFunction().walk([](Operation *op) { + auto passthrough_op = dyn_cast(op); + if (!passthrough_op) return; + std::string module_string = passthrough_op.mlir_module(); + // Parse the module. + auto nested_module = parseSourceString(module_string, op->getContext()); + if (!nested_module) { + op->emitError() << "could not parse attached MLIR module"; + return; + } + FuncOp main = dyn_cast(nested_module->lookupSymbol("main")); + if (!main) { + op->emitError() << "MLIR Opaque Op expects a main() entry point\n"; + return; + } + if (main.getNumArguments() != op->getNumOperands()) { + op->emitError() << "mismatch between MLIR Opaque Op number of operands (" + << op->getNumOperands() + << ") and main() entry point in the module (" + << main.getNumArguments() << " args)\n"; + return; + } + if (main.getType().getNumResults() != op->getNumResults()) { + op->emitError() << "mismatch between MLIR Opaque Op number of results (" + << op->getNumResults() + << ") and main() entry point in the module (" + << main.getType().getNumResults() << " results)\n"; + return; + } + Region &body = main.getBody(); + if (body.getBlocks().size() != 1) { + op->emitError() << "MLIR Opaque Op expects a main() entry point with a " + "single block\n"; + return; + } + Block &block = body.front(); + for (const auto &arg_mapping : + llvm::zip(block.getArguments(), op->getOperands())) { + std::get<0>(arg_mapping)->replaceAllUsesWith(std::get<1>(arg_mapping)); + } + op->getBlock()->getOperations().splice(op->getIterator(), + block.getOperations(), block.begin(), + std::prev(block.end())); + Operation &return_op = block.front(); + for (auto ret_mapping : + llvm::zip(op->getResults(), return_op.getOperands())) { + std::get<0>(ret_mapping)->replaceAllUsesWith(std::get<1>(ret_mapping)); + } + op->erase(); + }); +} + +} // namespace + +namespace TF { +std::unique_ptr CreateMaterializePassthroughOpPass() { + return std::make_unique(); +} +} // namespace TF + +static PassRegistration pass( + "tf-materialize-passthrough-op", + "Materialize the MlirPassthroughOp by replacing it with the MLIR module " + "attached as an attribute"); + +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index 9501f49475f..15ddebdffe8 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -26,6 +26,10 @@ namespace TF { // dialect to MLIR Control Flow Graph (CFG) form. 
std::unique_ptr CreateTFFunctionalControlFlowToCFG(); +// Materialize the MlirPassthroughOp by replacing it with the MLIR module +// attached as an attribute. +std::unique_ptr CreateMaterializePassthroughOpPass(); + // Optimizes Tensorflow graph. std::unique_ptr CreateTFOptimizePass(); @@ -53,6 +57,11 @@ std::unique_ptr CreateTFExecutorGraphPruningPass(); // Prune a tf_executor.graph operation from dead nodes. void prune_graph(GraphOp graph); +// Sink `tf.Const` operations in the LaunchOp region using them. This is +// performed in order to limit the number of values implicitly captured in this +// region before outlining. +std::unique_ptr CreateTFExecutorConstantSinkingPass(); + } // namespace tf_executor namespace TFDevice { @@ -75,7 +84,7 @@ std::unique_ptr CreateTPURewritePass(); // Populates the supplied passmanager with the passes required to run the // bridge. NOLINTNEXTLINE - MLIR contract is pass by mutable reference. -void createTPUBridge(PassManager& bridge); +void createTPUBridge(OpPassManager& bridge); } // namespace TFTPU diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc b/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc new file mode 100644 index 00000000000..86344e5fa3e --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc @@ -0,0 +1,98 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Debug.h" +#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/Pass/PassManager.h" // TF:local_config_mlir +#include "mlir/Support/LLVM.h" // TF:local_config_mlir +#include "mlir/Transforms/Passes.h" // TF:local_config_mlir +#include "mlir/Transforms/RegionUtils.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" + +#define DEBUG_TYPE "tf-executor-sink-constant" + +namespace mlir { +namespace tf_executor { + +namespace { +using ::mlir::TF::ConstOp; + +class ExecutorConstantSinking + : public mlir::FunctionPass { + void runOnFunction() override { + getFunction().walk([](tf_device::LaunchOp launch) { + LLVM_DEBUG(llvm::dbgs() << "Visit " << *launch.getOperation() << "\n"); + // For each launch op, we find the values used that come from a constant + // defined above and sink these constants in the region body. + // The sunk_constant map keeps a mapping from a ConstOp defined above to + // a sunk clone of it. This allows for reusing a sunk constant with + // multiple uses in the region. 
+ llvm::DenseMap sunk_constant; + Region &body = launch.body(); + visitUsedValuesDefinedAbove(body, [&](OpOperand *use) { + Value *constant = use->get(); + auto const_op = + dyn_cast_or_null(constant->getDefiningOp()); + if (!const_op) return; + + // We found a constant, try to insert it in the map and re-use its + // cloned value if any. + auto map_entry = sunk_constant.try_emplace(constant, nullptr); + if (!map_entry.second) { + // This constant has already been cloned into the region, reuse it. + use->set(map_entry.first->getSecond().getResult()); + LLVM_DEBUG(llvm::dbgs() << "Re-use sunk constant " << *use->get() + << "\n in " << *use->get() << "\n"); + if (constant->use_empty()) const_op.erase(); + return; + } + if (constant->hasOneUse()) { + LLVM_DEBUG(llvm::dbgs() << "Moved constant " << *constant << "\n"); + const_op.getOperation()->moveBefore(&body.begin()->front()); + return; + } + map_entry.first->getSecond() = const_op.clone(); + body.begin()->getOperations().insert(body.begin()->begin(), + map_entry.first->getSecond()); + use->set(map_entry.first->getSecond().getResult()); + LLVM_DEBUG(llvm::dbgs() << "Sunk cloned constant " << *use->get() + << "\n in " << *use->get() << "\n"); + }); + }); + } +}; + +static mlir::PassRegistration pass( + "tf-device-constant-sinking", + "Sink constants implicitly captured in a tf_device.launch region. This " + "reduces the number of arguments when outlining later."); + +} // anonymous namespace + +std::unique_ptr CreateTFExecutorConstantSinkingPass() { + return std::make_unique(); +} + +} // namespace tf_executor +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc index 05de214f992..91fc073e1f3 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/raw_ostream.h" #include "mlir/IR/Attributes.h" // TF:local_config_mlir @@ -22,6 +23,7 @@ limitations under the License. #include "mlir/Pass/Pass.h" // TF:local_config_mlir #include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" @@ -137,7 +139,7 @@ std::string EncapsulateFuncAndSerialize(FuncOp entry_func) { Operation* BuildCompileOp(tf_device::LaunchFuncOp launch_func, OpBuilder* builder) { // TODO(b/139377366): Use tf_tpu.compile build method when it is defined. - OperationState compile_op_state(launch_func.getLoc(), "tf.MLIRCompileToTPU"); + OperationState compile_op_state(launch_func.getLoc(), "tf._TPUCompileMlir"); // Build a shape op for each input to launch_func. 
// TODO(b/139377366): When shape inference is ready, we can use compile time @@ -153,6 +155,9 @@ Operation* BuildCompileOp(tf_device::LaunchFuncOp launch_func, compile_op_operands.emplace_back(shape_op.getResult()); } compile_op_state.addOperands(compile_op_operands); + compile_op_state.addAttribute( + "NumDynamicShapes", + builder->getI64IntegerAttr(compile_op_operands.size())); SymbolRefAttr func_attr = launch_func.getAttrOfType("func"); if (!func_attr) { @@ -163,13 +168,8 @@ Operation* BuildCompileOp(tf_device::LaunchFuncOp launch_func, func_attr.getValue()); std::string txt_module = EncapsulateFuncAndSerialize(func); - compile_op_state.addAttribute("module", builder->getStringAttr(txt_module)); - - // Copy all launch_func attributes other than `func`. - for (auto attr : launch_func.getAttrs()) { - if (attr.first == "func") continue; - compile_op_state.attributes.emplace_back(attr); - } + compile_op_state.addAttribute("mlir_module", + builder->getStringAttr(txt_module)); // Result #0 is a string indicating whether compilation is successful or not. compile_op_state.addTypes( @@ -239,8 +239,21 @@ void BuildTPUCompileSucceededAssertOp(Operation* compile_op, // Operations that jit-compiles and executes function in `tf_device.launch_func` // on TPU. void Rewrite(tf_device::LaunchFuncOp launch_func, OpBuilder* builder) { + // Skip non-tpu device launch_func. + auto replicate_attr = launch_func.getAttrOfType("_tpu_replicate"); + if (!replicate_attr) return; + builder->setInsertionPoint(launch_func); Operation* compile_op = BuildCompileOp(launch_func, builder); + + // After rewrite, find if there is a TPUCompilationResultOp in the block with + // the same _tpu_replicate attribute and replace it with the result of the + // compile op. This op is used as a placeholder to hook during graph creation + // the other ops that are intended to consume the compile result. + Block* block = launch_func.getOperation()->getBlock(); + for (auto compile_result_op : block->getOps()) + compile_result_op.output()->replaceAllUsesWith(compile_op->getResult(0)); + BuildTPUCompileSucceededAssertOp(compile_op, builder); // TODO(ycao): Right now we only support single-core case. The right thing to // do is to read from launch_func attributes to determine how many execute @@ -253,11 +266,20 @@ void Rewrite(tf_device::LaunchFuncOp launch_func, OpBuilder* builder) { void TPURewritePass::runOnModule() { OpBuilder builder(&getContext()); getModule().walk([&](tf_device::LaunchFuncOp op) { - // Skip non-tpu device launch_func. - if (!op.getAttrOfType("_tpu_replicate")) return; Rewrite(op, &builder); }); + // Eliminate TPUReplicatedInput and TPUReplicatedOutput now that the rewrite + // is complete. + getModule().walk([&](Operation* op) { + auto op_name = op->getName().getStringRef(); + if (op_name != "tf.TPUReplicatedInput" && + op_name != "tf.TPUReplicatedOutput") + return; + op->getResult(0)->replaceAllUsesWith(op->getOperand(0)); + op->erase(); + }); + // TODO(b/139377366): Remove functions that are no longer needed. 
} diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc index 9c1ffd4466e..e9aaf56462c 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc @@ -67,8 +67,11 @@ namespace tensorflow { using llvm::cast; using llvm::dyn_cast; using llvm::isa; +using mlir::BlockArgument; using mlir::Dialect; using mlir::Operation; +using mlir::OperationState; +using mlir::Value; using stream_executor::port::StatusOr; namespace { @@ -125,35 +128,34 @@ class Exporter { explicit Exporter(Graph* graph, const Dialect* tf_dialect) : graph_(graph), tf_dialect_(tf_dialect) {} - Status AddArgumentNode(mlir::BlockArgument* arg, unsigned index); - Status AddInstructionNode(mlir::Operation* inst); - Status AddNextIterationNode(mlir::Operation* inst); - Status AddEdge(mlir::Operation* inst); + Status AddArgumentNode(BlockArgument* arg, unsigned index); + Status AddInstructionNode(Operation* inst); + Status AddNextIterationNode(Operation* inst); + Status AddEdge(Operation* inst); - StatusOr> GetArgumentNode(mlir::BlockArgument* arg, + StatusOr> GetArgumentNode(BlockArgument* arg, unsigned index); - StatusOr> GetReturnNode(mlir::Operation* inst, + StatusOr> GetReturnNode(Operation* inst, unsigned index); // Adds one edge between src_node and dst_node. If it is not a control edge, // an index is used to find out the right operand of the dst_node. - Status AddEdgeBetweenNodes(mlir::Value* src, Node* dst_node, - unsigned dst_index); + Status AddEdgeBetweenNodes(Value* src, Node* dst_node, unsigned dst_index); // Returns a unique name for `op`. - std::string UniqueName(mlir::Operation* op); + std::string UniqueName(Operation* op); // Returns a unique name starting with a given prefix. std::string UniqueName(llvm::StringRef prefix); Graph* graph_; - absl::flat_hash_map op_to_name_; + absl::flat_hash_map op_to_name_; absl::flat_hash_map name_to_count_; - absl::flat_hash_map nodes_; - absl::flat_hash_map args_; + absl::flat_hash_map nodes_; + absl::flat_hash_map args_; // One single return operation can return multiple results, and each of them // will be converted to one node in the graph. typedef absl::InlinedVector NodeVector; - absl::flat_hash_map returns_; + absl::flat_hash_map returns_; // Each NextIteration node in the original graph is converted to a pair of // source and sink operations in the MLIR, and we use the following two maps @@ -163,8 +165,8 @@ class Exporter { // are inserted to the name_to_inst_ first, and the other "sink" operation // can be paired by checking this map and both are inserted to the // source_to_sink_ map. 
- absl::flat_hash_map name_to_inst_; - absl::flat_hash_map source_to_sink_; + absl::flat_hash_map name_to_inst_; + absl::flat_hash_map source_to_sink_; const mlir::Dialect* tf_dialect_; }; @@ -183,15 +185,15 @@ std::string Exporter::UniqueName(llvm::StringRef prefix) { return name; } -std::string Exporter::UniqueName(mlir::Operation* op) { +std::string Exporter::UniqueName(Operation* op) { auto& name = op_to_name_[op]; if (!name.empty()) return name; name = UniqueName(GetName(op)); return name; } -StatusOr> Exporter::GetArgumentNode( - mlir::BlockArgument* arg, unsigned index) { +StatusOr> Exporter::GetArgumentNode(BlockArgument* arg, + unsigned index) { auto node_def = absl::make_unique(); node_def->set_name(UniqueName( arg->getParentRegion()->getParentOfType().getName().str())); @@ -208,8 +210,8 @@ StatusOr> Exporter::GetArgumentNode( return node_def; } -StatusOr> Exporter::GetReturnNode( - mlir::Operation* inst, unsigned index) { +StatusOr> Exporter::GetReturnNode(Operation* inst, + unsigned index) { auto node_def = absl::make_unique(); auto* inst_op = inst->getOperand(index); node_def->set_name( @@ -227,7 +229,7 @@ StatusOr> Exporter::GetReturnNode( return node_def; } -Status Exporter::AddEdgeBetweenNodes(mlir::Value* src, Node* dst_node, +Status Exporter::AddEdgeBetweenNodes(Value* src, Node* dst_node, unsigned dst_index) { if (auto* input_result = dyn_cast(src)) { auto* input_inst = input_result->getOwner(); @@ -236,25 +238,28 @@ Status Exporter::AddEdgeBetweenNodes(mlir::Value* src, Node* dst_node, if (it != source_to_sink_.end()) { input_inst = source_to_sink_[input_inst]; } - TF_RET_CHECK(nodes_.find(input_inst) != nodes_.end()) + auto node_it = nodes_.find(input_inst); + TF_RET_CHECK(node_it != nodes_.end()) << "Use of OpResult encountered before def!"; if (input_result->getType().isa()) { - graph_->AddControlEdge(nodes_[input_inst], dst_node); + graph_->AddControlEdge(node_it->second, dst_node); } else { - graph_->AddEdge(nodes_[input_inst], input_result->getResultNumber(), + graph_->AddEdge(node_it->second, input_result->getResultNumber(), dst_node, dst_index); } - } else if (auto* input_arg = dyn_cast(src)) { - TF_RET_CHECK(args_.find(input_arg) != args_.end()) - << "Use of BlockArgument encounted before def!"; - auto* input_node = args_[input_arg]; - // For argument, there is only one result output, so the index is always 0. - graph_->AddEdge(input_node, 0, dst_node, dst_index); + return Status::OK(); } + + auto* input_arg = cast(src); + auto input_node_it = args_.find(input_arg); + TF_RET_CHECK(input_node_it != args_.end()) + << "Use of BlockArgument encounted before def!"; + // For argument, there is only one result output, so the index is always 0. + graph_->AddEdge(input_node_it->second, 0, dst_node, dst_index); return Status::OK(); } -Status Exporter::AddEdge(mlir::Operation* inst) { +Status Exporter::AddEdge(Operation* inst) { auto* dst_node = nodes_[inst]; bool is_return_op = isa(inst); for (int index = 0, e = inst->getNumOperands(); index < e; index++) { @@ -273,79 +278,86 @@ Status Exporter::AddEdge(mlir::Operation* inst) { return Status::OK(); } -Status Exporter::AddInstructionNode(mlir::Operation* inst) { +Status Exporter::AddInstructionNode(Operation* inst) { Status status; - if (!inst->isKnownTerminator()) { - std::unique_ptr node_def; - auto name = UniqueName(inst); - // Convert registered TF ops to NodeDef. Only registered ops are handled to - // ensure that PopulateDerivedAttrs adds the correct attributes. 
- TF_ASSIGN_OR_RETURN(node_def, - ConvertTFDialectOpToNodeDef( - inst, name, /*ignore_unregistered_attrs=*/false)); - Node* node = graph_->AddNode(*node_def, &status); - TF_RETURN_IF_ERROR(status); - nodes_[inst] = node; - } else if (isa(inst)) { - for (int index = 0, end = inst->getNumOperands(); index != end; index++) { + // If the op is a ReturnOp then create a return node per operand. + if (isa(inst)) { + auto& return_nodes = returns_[inst]; + for (int index : llvm::seq(0, inst->getNumOperands())) { TF_ASSIGN_OR_RETURN(auto node_def, GetReturnNode(inst, index)); Node* node = graph_->AddNode(*node_def, &status); TF_RETURN_IF_ERROR(status); - if (returns_.find(inst) == returns_.end()) { - returns_[inst] = NodeVector(); - } - returns_[inst].push_back(node); + return_nodes.push_back(node); } - } else { - return errors::InvalidArgument("Operation input was not an Value!"); + return Status::OK(); } + + if (inst->isKnownTerminator()) + return errors::InvalidArgument("std.return is only allowed terminator"); + + std::unique_ptr node_def; + auto name = UniqueName(inst); + // Convert registered TF ops to NodeDef. Only registered ops are handled to + // ensure that PopulateDerivedAttrs adds the correct attributes. + TF_ASSIGN_OR_RETURN(node_def, + ConvertTFDialectOpToNodeDef( + inst, name, /*ignore_unregistered_attrs=*/false)); + + Node* node = graph_->AddNode(*node_def, &status); + TF_RETURN_IF_ERROR(status); + nodes_[inst] = node; return Status::OK(); } -Status Exporter::AddArgumentNode(mlir::BlockArgument* arg, unsigned index) { - // If it is an argument from the "main" function, it has only one user, which - // is an input node. We recover the original input node and skip adding the - // argument node. The new input node will be handled as normal in the - // following steps. - if (arg->getParentRegion()->getParentOfType().getName() == - "main") { - if (!arg->hasOneUse()) { - return errors::FailedPrecondition( - "Arg in 'main' should only have one user."); - } - auto* input = *arg->user_begin(); - auto input_name = input->getName().getStringRef(); - input_name.consume_back(".input"); - mlir::OpBuilder builder(arg->getOwner()); - auto loc = mlir::NameLoc::get(builder.getIdentifier(UniqueName(input)), - builder.getContext()); - mlir::OperationState state(loc, input_name.str()); - state.attributes.append(input->getAttrs().begin(), input->getAttrs().end()); - for (auto* op : input->getOperands()) { - // Skip the argument in the new operation. - if (llvm::isa(op)) continue; - state.operands.push_back(op); - } - for (auto* r : input->getResults()) state.types.push_back(r->getType()); - auto* inst = builder.createOperation(state); - // If it is one of the specified input names, then the new - // instruction should have the same name. 
- op_to_name_[inst].assign(op_to_name_[input]); - for (int index = 0, e = input->getNumResults(); index != e; ++index) { - input->getResult(index)->replaceAllUsesWith(inst->getResult(index)); - } - input->dropAllReferences(); - input->erase(); - return Status::OK(); - } else { +bool IsEntryFunctionArg(BlockArgument* arg) { + return arg->getParentRegion()->getParentOfType().getName() == + "main"; +} + +Status Exporter::AddArgumentNode(BlockArgument* arg, unsigned index) { + if (!IsEntryFunctionArg(arg)) { TF_ASSIGN_OR_RETURN(auto node_def, GetArgumentNode(arg, index)); Status status; Node* node = graph_->AddNode(*node_def, &status); TF_RETURN_IF_ERROR(status); args_[arg] = node; - return Status::OK(); + return status; } + + // If it is an argument from the "main" function, it has only one user, which + // is an input node. We recover the original input node and skip adding the + // argument node. The new input node will be handled as normal in the + // following steps. + if (!arg->hasOneUse()) { + return errors::FailedPrecondition( + "Arg in 'main' should only have one user."); + } + auto* input = *arg->user_begin(); + auto input_name = input->getName().getStringRef(); + input_name.consume_back(".input"); + mlir::OpBuilder builder(arg->getOwner()); + auto loc = mlir::NameLoc::get(builder.getIdentifier(UniqueName(input)), + builder.getContext()); + OperationState state(loc, input_name.str()); + state.attributes.append(input->getAttrs().begin(), input->getAttrs().end()); + for (auto* op : input->getOperands()) { + // Skip the argument in the new operation. + if (llvm::isa(op)) continue; + state.operands.push_back(op); + } + state.types.append(input->getResultTypes().begin(), + input->getResultTypes().end()); + auto* inst = builder.createOperation(state); + // If it is one of the specified input names, then the new + // instruction should have the same name. + op_to_name_[inst].assign(op_to_name_[input]); + for (int index : llvm::seq(0, input->getNumResults())) { + input->getResult(index)->replaceAllUsesWith(inst->getResult(index)); + } + input->dropAllReferences(); + input->erase(); + return Status::OK(); } // Handles an NextIteration node specially: @@ -353,7 +365,7 @@ Status Exporter::AddArgumentNode(mlir::BlockArgument* arg, unsigned index) { // map by using its name attribute; // - NextIteration "sink" is paired with the "source" with the name attribute. // It is added to the graph like the other operations. 
-Status Exporter::AddNextIterationNode(mlir::Operation* inst) { +Status Exporter::AddNextIterationNode(Operation* inst) { auto name = GetName(inst); if (inst->getName().getStringRef().endswith(".source")) { name_to_inst_[name] = inst; @@ -363,10 +375,9 @@ Status Exporter::AddNextIterationNode(mlir::Operation* inst) { return AddInstructionNode(inst); } -StatusOr> Exporter::Convert(const ExporterConfigs& confs, - const Dialect* tf_dialect, - mlir::FuncOp function, - FunctionDefLibrary* flib) { +StatusOr> Exporter::Convert( + const ExporterConfigs& configs, const Dialect* tf_dialect, + mlir::FuncOp function, FunctionDefLibrary* flib) { if (function.getBlocks().size() != 1) { return errors::FailedPrecondition( "Input FuncOp must have only one basic block!"); @@ -420,10 +431,10 @@ StatusOr> Exporter::Convert(const ExporterConfigs& confs, TF_RET_CHECK(output_names.size() == term->getNumOperands()) << "output names (" << output_names.size() << ") != terminator operands (" << term->getNumOperands() << ")"; - int i = 0; - for (auto it : term->getOperands()) { - exporter.name_to_count_[output_names[i].str()] = 1; - exporter.op_to_name_[it->getDefiningOp()] = output_names[i++]; + for (auto it : llvm::enumerate(term->getOperands())) { + exporter.name_to_count_[output_names[it.index()].str()] = 1; + exporter.op_to_name_[it.value()->getDefiningOp()] = + output_names[it.index()]; } } if (!input_names.empty()) { @@ -435,8 +446,9 @@ StatusOr> Exporter::Convert(const ExporterConfigs& confs, } // Adds nodes for basic block (function) arguments. - for (int index = 0, e = block.getNumArguments(); index != e; index++) { - auto* arg = block.getArgument(index); + for (auto it : llvm::enumerate(block.getArguments())) { + int index = it.index(); + auto* arg = it.value(); mlir::Type type = arg->getType(); if (!type.isa()) { return errors::InvalidArgument( @@ -447,7 +459,7 @@ StatusOr> Exporter::Convert(const ExporterConfigs& confs, TF_RETURN_IF_ERROR(exporter.AddArgumentNode(arg, index)); } // Adds nodes for operations. - for (mlir::Operation& inst : block) { + for (Operation& inst : block) { auto op_name = GetTensorFlowOpName(inst.getName().getStringRef()); if (op_name.ok()) { // If it is TF Control dialect specific op, look up custom operation @@ -459,13 +471,12 @@ StatusOr> Exporter::Convert(const ExporterConfigs& confs, function.getParentOfType().lookupSymbol( op_name.ValueOrDie()); if (func != nullptr) { - TF_RETURN_IF_ERROR(ConvertLibFunction(confs, tf_dialect, func, flib)); + TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, func, flib)); TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(*flib)); } } - for (auto* result : inst.getResults()) { - mlir::Type type = result->getType(); + for (auto type : inst.getResultTypes()) { if (!type.isa() && !type.isa()) { return errors::InvalidArgument( @@ -481,7 +492,7 @@ StatusOr> Exporter::Convert(const ExporterConfigs& confs, } } // Adds edges between the argument, operation and return nodes. 
- for (mlir::Operation& inst : block) { + for (Operation& inst : block) { TF_RETURN_IF_ERROR(exporter.AddEdge(&inst)); } // Fixes the edges between the inserted nodes and special "_SOURCE" and @@ -584,7 +595,7 @@ Status Exporter::Convert(mlir::ModuleOp module, const ExporterConfigs& configs, } } // namespace -Status ConvertMlirToGraph(mlir::ModuleOp module, const ExporterConfigs& confs, +Status ConvertMlirToGraph(mlir::ModuleOp module, const ExporterConfigs& configs, std::unique_ptr* graph, FunctionLibraryDefinition* flib_def) { mlir::PassManager pass_manager(module.getContext()); @@ -593,24 +604,24 @@ Status ConvertMlirToGraph(mlir::ModuleOp module, const ExporterConfigs& confs, return errors::FailedPrecondition( "Failed to convert TFExecutor Dialect to Control Dialect."); } - return Exporter::Convert(module, confs, graph, flib_def); + return Exporter::Convert(module, configs, graph, flib_def); } StatusOr> ConvertMlirToGraphdef( - mlir::ModuleOp module, const ExporterConfigs& confs) { + mlir::ModuleOp module, const ExporterConfigs& configs) { FunctionLibraryDefinition flib_def(OpRegistry::Global(), FunctionDefLibrary()); auto graph = absl::make_unique(flib_def); - TF_RETURN_IF_ERROR(ConvertMlirToGraph(module, confs, &graph, &flib_def)); + TF_RETURN_IF_ERROR(ConvertMlirToGraph(module, configs, &graph, &flib_def)); auto graphdef = absl::make_unique(); graph->ToGraphDef(graphdef.get()); - if (!confs.export_library) graphdef->clear_library(); - if (!confs.export_shapes) { + if (!configs.export_library) graphdef->clear_library(); + if (!configs.export_shapes) { for (auto& node_def : *graphdef->mutable_node()) { node_def.mutable_attr()->erase("shape"); } } - if (!confs.export_debug_info) { + if (!configs.export_debug_info) { for (auto& node_def : *graphdef->mutable_node()) { node_def.clear_experimental_debug_info(); } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 34cdc609164..c06bd3ec5c3 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/strings/strip.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" @@ -476,7 +477,8 @@ Status ImporterBase::AddNodesToShapeRefiner() { auto it = specs_.inputs.find(node->name()); if (it != specs_.inputs.end()) { auto node_name = node->op_def().name(); - if (node_name != "Placeholder" && node_name != "LegacyFedInput") { + if (node_name != "Placeholder" && node_name != "LegacyFedInput" && + node_name != "_Arg") { // We do not handle the case where the input node has multple outputs if (node->num_outputs() > 1) { return errors::FailedPrecondition(absl::StrCat( @@ -496,6 +498,35 @@ Status ImporterBase::AddNodesToShapeRefiner() { TF_RETURN_WITH_CONTEXT_IF_ERROR(shape_refiner_->AddNode(node), GetLocationStr(*node)); + // We currently have no other way to get shapes from ReadVariableOp's. + // Some graphs seem to have _output_shapes attributes on them, so use that + // if possible. + // TODO(silvasean): Ideally, we would do this in a separate shape inference + // pass to avoid adding complexity to the importer. But right now, we don't + // have an MLIR-native shape inference pass, so we need to do this while we + // still have the Graph around, i.e. here, in the importer. 
+ if (node->op_def().name() == "ReadVariableOp") { + // TODO(silvasean): In some graphs, this seems to be annotated on every + // node. Why and by whom? + // TODO(b/140588338): We should ideally incorporate that information for + // all nodes, but right now, this can result in e.g. an Identity node with + // signature such as + // `(tensor) -> tensor` which fails the verifier + // (which checks for exact type equality; _output_shapes results in + // us shoehorning in the more-precise type on the output). + if (const AttrValue* attr = node->attrs().Find("_output_shapes")) { + auto& list = attr->list(); + for (auto shape : llvm::enumerate(list.shape())) { + auto* node_context = shape_refiner_->GetContext(node); + shape_inference::ShapeHandle handle; + TF_RETURN_WITH_CONTEXT_IF_ERROR( + node_context->MakeShapeFromShapeProto(shape.value(), &handle), + GetLocationStr(*node)); + node_context->set_output(shape.index(), handle); + } + } + } + // If it is the argument node, the shape handle is set explicitly, so it // can be propagated to the body nodes of the function. if (StringPiece(node->type_string()) == FunctionLibraryDefinition::kArgOp) { @@ -845,10 +876,28 @@ Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) { attributes.push_back(builder_.getNamedAttr(grad_string, gradient_attr)); } - // Converts the graph to a MLIR function and adds it to the module. Uses the - // default node spec without any inputs or outputs as the function graph has - // special '_Arg' and '_Retval' ops for argument and return values. + // Converts the graph to a MLIR function and adds it to the module. + // We populate the NodeSpec so that all the _Arg ops get their shape + // added correctly. NodeSpecs specs; + for (const auto& name_and_value : func_def->attr()) { + if (name_and_value.first == "_input_shapes") { + auto& list = name_and_value.second.list(); + auto& signature = func_def->signature(); + for (int i = 0; i < list.shape_size(); i++) { + auto& input_arg = signature.input_arg(i); + auto& array_info = specs.inputs[input_arg.name()]; + array_info.imported_dtype = input_arg.type(); + array_info.shape = list.shape(i); + // TODO(b/140464702): These fields should not be exposed here. + // Seems like a layering violation. Initialize them anyway. + array_info.final_dtype = input_arg.type(); + array_info.min_value = 0.0; + array_info.max_value = 0.0; + } + } + } + ImporterBase child_importer(graph_flib_, debug_info_, specs, module_, tf_name_to_mlir_name_); TF_RETURN_IF_ERROR(child_importer.PrepareConvert(*fbody->graph)); @@ -1090,9 +1139,10 @@ mlir::Location ImporterBase::GetLocation(const NodeDef& node_def) { for (int i = 0, e = original_nodes.size(); i != e; ++i) { auto node_name = original_nodes[i]; auto func_name = (i < original_funcs.size()) ? original_funcs[i] : ""; - // Use the catenation of function and node names as the lookup key. This - // is to match the utility of generating the GraphDebugInfo. - node_call_sites.push_back(node_name_to_call_site(func_name + node_name)); + // Use the catenation of function and node names as the lookup key. + // This matches the way that the key is formed on the python side. 
+ std::string key = node_name + "@" + func_name; + node_call_sites.push_back(node_name_to_call_site(key)); } return mlir::FusedLoc::get(node_call_sites, context_); } @@ -1399,16 +1449,22 @@ StatusOr ImporterBase::InferLibFunctionType( const FunctionBody& fbody) { mlir::Builder builder(context_); + // The FunctionBody contains a graph with a single-output _Arg node for each + // function argument and a single-input _Retval node for each function return + // value. + // + // We already populated the ShapeRefiner with all the information about the + // shapes of these graph edges, so we just query it to build the corresponding + // MLIR function type signature. + llvm::SmallVector arg_types; arg_types.reserve(fbody.arg_types.size()); - for (auto dataType : fbody.arg_types) { - mlir::Type element_type; - TF_RETURN_IF_ERROR( - ::tensorflow::ConvertDataType(dataType, builder, &element_type)); - // TODO(hinsu): Derive shape of function arguments based on shapes available - // at call sites of this function. That way it is possible to have a - // partially known shape in some cases instead of unranked tensor types. - arg_types.push_back(builder.getTensorType(element_type)); + for (auto arg : fbody.arg_nodes) { + // Find node in the graph using the node id instead of using `arg` directly + // because the graph has been cloned. + auto* node = graph_->FindNodeId(arg->id()); + TF_ASSIGN_OR_RETURN(auto type, InferOutputType(*node, /*idx=*/0, builder)); + arg_types.push_back(type); } llvm::SmallVector ret_types; @@ -1417,9 +1473,6 @@ StatusOr ImporterBase::InferLibFunctionType( // Find node in the graph using the node id instead of using `ret` directly // because the graph has been cloned. auto* node = graph_->FindNodeId(ret->id()); - - // Return type of the function is type of the only input of the respective - // return node in the function. TF_ASSIGN_OR_RETURN(auto type, InferInputType(*node, /*idx=*/0, builder)); ret_types.push_back(type); } @@ -1721,4 +1774,13 @@ StatusOr ConvertSavedModelToMlir( add_default_attributes, context); } +std::string MlirModuleToString(mlir::ModuleOp module) { + std::string txt_module; + { + llvm::raw_string_ostream os{txt_module}; + module.print(os); + } + return txt_module; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h index 98bb607fa6a..6ca4c0098d7 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_ +#include + #include "mlir/IR/MLIRContext.h" // TF:local_config_mlir #include "mlir/IR/Module.h" // TF:local_config_mlir #include "tensorflow/cc/saved_model/loader.h" @@ -48,6 +50,9 @@ stream_executor::port::StatusOr ConvertSavedModelToMlir( const SavedModelBundle& saved_model, const GraphDebugInfo& debug_info, mlir::MLIRContext* context, bool add_default_attributes = true); +// Serialize a MLIR module to a string. 
+std::string MlirModuleToString(mlir::ModuleOp m); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc new file mode 100644 index 00000000000..648aeef36da --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -0,0 +1,212 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" + +#include "absl/types/span.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Function.h" // TF:local_config_mlir +#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir +#include "mlir/IR/Module.h" // TF:local_config_mlir +#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir +#include "mlir/Parser.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h" +#include "tensorflow/compiler/mlir/xla/transforms/passes.h" +#include "tensorflow/compiler/mlir/xla/type_to_shape.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" + +namespace tensorflow { +namespace { + +// Parses the MLIR module from the mlir_module_string. +Status ParseMlirModule(llvm::StringRef mlir_module_string, + mlir::MLIRContext* mlir_context, + mlir::OwningModuleRef* mlir_module) { + TF_RET_CHECK(!mlir_module_string.empty()) + << "unexpected empty serialized MLIR module string"; + TF_RET_CHECK(mlir_module) << "unexpected null MLIR module pointer"; + + // Parse the module. + *mlir_module = mlir::parseSourceString(mlir_module_string, mlir_context); + if (!*mlir_module) { + return errors::InvalidArgument("could not parse MLIR module"); + } + + return Status::OK(); +} + +// Converts arg_shapes to xla::Shape's and store into xla_input_shapes. 
+Status GetXlaInputShapes( + mlir::ModuleOp module, absl::Span arg_shapes, + const xla::CustomShapeRepresentationFn shape_representation_fn, + std::vector* xla_input_shapes) { + xla_input_shapes->clear(); + + mlir::FuncOp main_func = module.lookupSymbol("main"); + mlir::FunctionType func_type = main_func.getType(); + + int num_args = func_type.getNumInputs(); + xla_input_shapes->reserve(num_args); + + std::vector individual_arg_shapes; + individual_arg_shapes.reserve(num_args); + for (int i = 0; i < num_args; ++i) { + individual_arg_shapes.emplace_back(); + xla::Shape& xla_shape = individual_arg_shapes.back(); + + DataType dtype; + TF_RETURN_IF_ERROR(ConvertToDataType(func_type.getInput(i), &dtype)); + TF_ASSIGN_OR_RETURN(xla_shape, + shape_representation_fn(arg_shapes[i], dtype)); + } + xla_input_shapes->push_back( + xla::ShapeUtil::MakeTupleShape(individual_arg_shapes)); + return Status::OK(); +} + +// Calculates computation output shape and build OutputDescription for each +// output based on static shapes in MLIR module +Status GetOutputInfo( + mlir::ModuleOp module, + const xla::CustomShapeRepresentationFn shape_representation_fn, + xla::Shape* xla_output_shape, + std::vector* outputs) { + mlir::FuncOp main_func = module.lookupSymbol("main"); + mlir::FunctionType func_type = main_func.getType(); + + outputs->clear(); + outputs->reserve(func_type.getNumResults()); + + std::vector shapes; + shapes.reserve(func_type.getNumResults()); + + for (mlir::Type type : func_type.getResults()) { + TF_ASSIGN_OR_RETURN(xla::Shape shape, + TypeToShape(type, shape_representation_fn)); + auto tensor_type = type.dyn_cast(); + shapes.push_back(shape); + + // Construct OutputDescription for result. + outputs->emplace_back(); + XlaCompiler::OutputDescription& out_desc = outputs->back(); + TF_RETURN_IF_ERROR(ConvertToDataType(tensor_type, &out_desc.type)); + // TODO(ycao): Support constant output. + out_desc.is_constant = false; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(shape, &out_desc.shape)); + // Input_index is only meaningful for resource output. Since MLIR-based + // TF-Compiler bridge doesn't support resource output yet. Setting it to + // meaningless value -1. + // TODO(ycao): Support resource-type output. + out_desc.input_index = -1; + // MLIR-based TF-Compiler bridge doesn't support tensorlist output yet. + // TODO(ycao): Support tensorlist-type output. + out_desc.is_tensor_list = false; + } + + // XLA computation always uses Tuple shape. + *xla_output_shape = xla::ShapeUtil::MakeTupleShape(shapes); + return Status::OK(); +} + +// Gets information about how computation updates Tensorflow resources. +// TODO(ycao): Implement logic to compute resource updates when we need to +// support graphs with resource updates in MLIR-based TF compiler bridge. +void GetResourceUpdatesForMlir( + std::vector* resource_updates) { + resource_updates->clear(); +} + +// Creates a vector that maps from the parameters of the XLA computation to +// their original argument positions. +// MLIR-based TF-Compiler bridge doesn't have constant analysis yet, thus no +// inputs are known constants. Therefore, the input mapping between input to +// computation arguments is a trivial in-order 1-1 mapping. +// TODO(ycao): Support computation with compile-time constant, which requires +// non-trivial input mapping as implemented now. 
+void GetInputMappingForMlir(int num_inputs, std::vector* input_mapping) { + input_mapping->resize(num_inputs, 0); + std::iota(input_mapping->begin(), input_mapping->end(), 0); +} + +// Lowers MLIR module to XLA HLO inside an XlaComputation. +Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op, + xla::XlaComputation* xla_computation) { + { + // Make sure we catch any error reported by MLIR and forward it to the TF + // error reporting system. Report a generic error if pass manager failed + // without emitting a diagnostic. + mlir::StatusScopedDiagnosticHandler error_handler(module_op.getContext()); + mlir::xla_hlo::legalizeTF(module_op); + if (!error_handler.ok()) { + return error_handler.Combine( + errors::Internal("MLIR TF to XLA legalization failed")); + } + } + + xla::HloProto hlo_proto; + TF_RETURN_IF_ERROR(mlir::ConvertMlirHloToHlo(module_op, &hlo_proto, + /*use_tuple_args=*/true, + /*always_return_tuple=*/true)); + *xla_computation = xla::XlaComputation(hlo_proto.hlo_module()); + return Status::OK(); +} + +} // namespace + +Status CompileSerializedMlirToXlaHlo( + llvm::StringRef mlir_module_string, absl::Span arg_shapes, + const XlaCompiler::ShapeRepresentationFn shape_representation_fn, + XlaCompiler::CompilationResult* compilation_result) { + mlir::MLIRContext mlir_context; + mlir::OwningModuleRef mlir_module; + + TF_RETURN_IF_ERROR( + ParseMlirModule(mlir_module_string, &mlir_context, &mlir_module)); + auto module_op = mlir_module.get(); + + // Convert MLIR module to XLA HLO proto contained in XlaComputation. + compilation_result->computation = std::make_shared(); + TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation( + module_op, compilation_result->computation.get())); + + // Construct mapping from XlaComputation's arg to input edges of execute + // node. + GetInputMappingForMlir(arg_shapes.size(), &compilation_result->input_mapping); + + auto shape_representation_fn_no_fast_memory = + [shape_representation_fn](const TensorShape& shape, DataType dtype) { + return shape_representation_fn(shape, dtype, /*use_fast_memory=*/false); + }; + + // Compute all input shapes. + TF_RETURN_IF_ERROR(GetXlaInputShapes(module_op, arg_shapes, + shape_representation_fn_no_fast_memory, + &compilation_result->xla_input_shapes)); + + // Compute all output descriptions. + TF_RETURN_IF_ERROR(GetOutputInfo( + module_op, shape_representation_fn_no_fast_memory, + &compilation_result->xla_output_shape, &compilation_result->outputs)); + + // Compute what resource variables need to be updated after XlaComputation's + // execution. + GetResourceUpdatesForMlir(&compilation_result->resource_updates); + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h new file mode 100644 index 00000000000..e7bfd264675 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h @@ -0,0 +1,35 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_
+#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_
+
+#include "absl/types/span.h"
+#include "llvm/ADT/StringRef.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+
+namespace tensorflow {
+
+// Compiles a serialized MLIR module into XLA HLO, generates all accompanying
+// metadata and stores them in CompilationResult.
+Status CompileSerializedMlirToXlaHlo(
+    llvm::StringRef mlir_module_string, absl::Span<TensorShape> arg_shapes,
+    const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+    XlaCompiler::CompilationResult* compilation_result);
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc
new file mode 100644
index 00000000000..eee531a2550
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc
@@ -0,0 +1,117 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+
+namespace tensorflow {
+namespace {
+
+// A dummy shape representation function that simply converts given shape into
+// an xla::Shape without assigning any layouts.
+xla::StatusOr TestShapeRepresentation(const TensorShape& shape, + DataType type, + bool use_fast_memory) { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape)); + return xla_shape; +} + +TEST(CompileSerializedMlirToXlaHloTest, InvalidSerliazedMlirModule) { + string invalid_mlir_module = "totally @invalid MLIR module {here} <-"; + std::vector arg_shapes; + XlaCompiler::CompilationResult compilation_result; + + Status s = CompileSerializedMlirToXlaHlo( + invalid_mlir_module, absl::Span(arg_shapes), + TestShapeRepresentation, &compilation_result); + EXPECT_EQ(s.code(), tensorflow::errors::Code::INVALID_ARGUMENT); +} + +TEST(CompileSerializedMlirToXlaHloTest, Success) { + string mlir_module = R"( + module { + func @main(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "tf.AddV2"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", name = "add"} : (tensor, tensor) -> tensor + return %0 : tensor + } + } + )"; + + std::vector arg_shapes(2, TensorShape()); + XlaCompiler::CompilationResult compilation_result; + + Status s = CompileSerializedMlirToXlaHlo( + mlir_module, absl::Span(arg_shapes), TestShapeRepresentation, + &compilation_result); + ASSERT_TRUE(s.ok()); + + const xla::HloModuleConfig module_config( + compilation_result.computation->GetProgramShape().ValueOrDie()); + auto status_or_hlo_module = xla::HloModule::CreateFromProto( + compilation_result.computation->proto(), module_config); + ASSERT_TRUE(status_or_hlo_module.ok()); + string expected_hlo_module_string = R"(HloModule main.6 + +ENTRY %main.6 (arg_tuple.1: (f32[], f32[])) -> (f32[]) { + %arg_tuple.1 = (f32[], f32[]) parameter(0) + %get-tuple-element.2 = f32[] get-tuple-element((f32[], f32[]) %arg_tuple.1), index=0 + %get-tuple-element.3 = f32[] get-tuple-element((f32[], f32[]) %arg_tuple.1), index=1 + %add.4 = f32[] add(f32[] %get-tuple-element.2, f32[] %get-tuple-element.3) + ROOT %tuple.5 = (f32[]) tuple(f32[] %add.4) +} + +)"; + EXPECT_EQ(status_or_hlo_module.ValueOrDie()->ToString(), + expected_hlo_module_string); + + // Expect an iota like input mapping. + EXPECT_EQ(compilation_result.input_mapping, std::vector({0, 1})); + + // Expect a single tuple-shape, containing two F32 scalars. + EXPECT_EQ(compilation_result.xla_input_shapes.size(), 1); + xla::Shape expected_input_shape = + xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {}), + xla::ShapeUtil::MakeShape(xla::F32, {})}); + EXPECT_EQ(compilation_result.xla_input_shapes.front(), expected_input_shape); + + // Expect output shape is a tuple shape containing a single F32 Scalar type. + const xla::Shape output_shape = + xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {}); + const xla::Shape tuple_output_shape = + xla::ShapeUtil::MakeTupleShape({output_shape}); + EXPECT_EQ(compilation_result.xla_output_shape, tuple_output_shape); + + // Expect exactly 1 OutputDescrpition. + EXPECT_EQ(compilation_result.outputs.size(), 1); + const XlaCompiler::OutputDescription& output_desc = + compilation_result.outputs.front(); + EXPECT_EQ(output_desc.type, DataType::DT_FLOAT); + EXPECT_EQ(output_desc.shape, TensorShape()); + EXPECT_FALSE(output_desc.is_constant); + EXPECT_FALSE(output_desc.is_tensor_list); + + // Expect no resource updates from computation. 
+ EXPECT_TRUE(compilation_result.resource_updates.empty()); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc index dbb3cf08717..804b1372ffc 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc @@ -31,7 +31,9 @@ limitations under the License. #include "mlir/IR/Module.h" // TF:local_config_mlir #include "mlir/IR/Operation.h" // TF:local_config_mlir #include "mlir/IR/OperationSupport.h" // TF:local_config_mlir +#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir #include "mlir/Support/DebugStringHelper.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" @@ -184,6 +186,27 @@ void UpdateCompositeWhileOp(NodeDef* node_def) { } } +// Returns true if the control dialect op should map to Ref node in TensorFlow +// Graph. For NextIteration it uses the 1st operand type. For all others +// (Enter/Exit/Merge/Switch), if the output type is ref, +// they correspond to the Ref equivalent op in TF Graph. +static bool IsRefTypeControlOp(mlir::Operation* op) { + auto op_name_or_status = GetTensorFlowOpName(op->getName().getStringRef()); + if (!op_name_or_status.ok()) return false; + + auto op_name = op_name_or_status.ConsumeValueOrDie(); + if (op_name.equals("NextIteration")) + return mlir::getElementTypeOrSelf(op->getOperand(0)->getType()) + .isa(); + + if (op_name.equals("Enter") || op_name.equals("Exit") || + op_name.equals("Switch") || op_name.equals("Merge")) { + return getElementTypeOrSelf(op->getResult(0)->getType()) + .isa(); + } + return false; +} + } // anonymous namespace StatusOr GetTensorFlowOpName(llvm::StringRef op_name) { @@ -208,9 +231,21 @@ StatusOr> GetOperationNodeDef( auto node_def = absl::make_unique(); // Note: we do not use NodeBuilder or NodeDefBuilder as that would require // mapping back from the inputs to the input arguments. - TF_ASSIGN_OR_RETURN(auto op_name, + + // Some control flow ops in TensorFlow Graph have their respective "Ref" ops + // as well. For example there is Enter and RefEnter op. RefEnter forwards + // the input ref buffer to output. However both Enter and RefEnter are + // mapped to tf_executor::EnterOp during import and then to _tf.Enter op in + // control dialect. Check if it is a Ref op to correctly map to the TensorFlow + // Graph op. + llvm::SmallString<64> op_name; + if (IsRefTypeControlOp(inst)) op_name = "Ref"; + + TF_ASSIGN_OR_RETURN(auto tf_name, GetTensorFlowOpName(inst->getName().getStringRef())); - node_def->set_op(op_name); + op_name.append(tf_name); + + node_def->set_op(op_name.str()); node_def->set_name(name); // Add inputs to the NodeDef based on the number of operands. 
This is required
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/import_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/import_utils.cc
index 52fb7cac5b7..5be0ebd6894 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/import_utils.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/import_utils.cc
@@ -37,6 +37,30 @@ inline llvm::StringRef StringViewToRef(absl::string_view view) {
 
 namespace tensorflow {
 
+Status LoadProtoFromBuffer(absl::string_view input,
+                           tensorflow::protobuf::Message* proto) {
+  tensorflow::protobuf::TextFormat::Parser parser;
+  // Don't produce errors when attempting to parse text format as it would fail
+  // when the input is actually a binary file.
+  NoOpErrorCollector collector;
+  parser.RecordErrorsTo(&collector);
+  // Attempt to parse as text.
+  tensorflow::protobuf::io::ArrayInputStream input_stream(input.data(),
+                                                          input.size());
+  if (parser.Parse(&input_stream, proto)) {
+    return Status::OK();
+  }
+  // Else attempt to parse as binary.
+  proto->Clear();
+  tensorflow::protobuf::io::ArrayInputStream binary_stream(input.data(),
+                                                           input.size());
+  if (proto->ParseFromZeroCopyStream(&binary_stream)) {
+    return Status::OK();
+  }
+  LOG(ERROR) << "Error parsing Protobuf";
+  return errors::InvalidArgument("Could not parse input proto");
+}
+
 Status LoadProtoFromFile(absl::string_view input_filename,
                          tensorflow::protobuf::Message* proto) {
   auto file_or_err =
@@ -45,26 +69,10 @@ Status LoadProtoFromFile(absl::string_view input_filename,
     return errors::InvalidArgument("Could not open input file");
 
   auto& input_file = *file_or_err;
-  std::string content(input_file->getBufferStart(),
-                      input_file->getBufferSize());
+  absl::string_view content(input_file->getBufferStart(),
+                            input_file->getBufferSize());
 
-  tensorflow::protobuf::TextFormat::Parser parser;
-  // Don't produce errors when attempting to parse text format as it would fail
-  // when the input is actually a binary file.
-  NoOpErrorCollector collector;
-  parser.RecordErrorsTo(&collector);
-  // Attempt to parse as text.
-  if (parser.ParseFromString(content, proto)) {
-    return Status::OK();
-  }
-  // Else attempt to parse as binary.
-  proto->Clear();
-  std::istringstream istream(content);
-  if (proto->ParseFromIstream(&istream)) {
-    return Status::OK();
-  }
-  LOG(ERROR) << "Error parsing Protobuf: " << input_filename;
-  return errors::InvalidArgument("Could not parse input file");
+  return LoadProtoFromBuffer(content, proto);
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/import_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/import_utils.h
index 1158b9a6173..a7d00cf890e 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/import_utils.h
+++ b/tensorflow/compiler/mlir/tensorflow/utils/import_utils.h
@@ -22,6 +22,11 @@ limitations under the License.
 
 namespace tensorflow {
 
+// Reads text (.pbtext) or binary (.pb) format of a proto message from the
+// given buffer. Returns an error status if the input cannot be parsed as a proto.
+Status LoadProtoFromBuffer(absl::string_view input,
+                           tensorflow::protobuf::Message* proto);
+
 // Reads text (.pbtext) or binary (.pb) format of a proto message from the given
 // file path. Returns error status of the file is not found or malformed proto.
Status LoadProtoFromFile(absl::string_view input_filename, diff --git a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc index 4a6f9c837aa..f70868e217f 100644 --- a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc @@ -55,8 +55,6 @@ static llvm::cl::opt verify_passes( llvm::cl::desc("Run the verifier after each transformation pass"), llvm::cl::init(true)); -static std::vector *pass_list; - int main(int argc, char **argv) { tensorflow::InitMlir y(&argc, &argv); @@ -64,9 +62,8 @@ int main(int argc, char **argv) { mlir::registerPassManagerCLOptions(); // Parse pass names in main to ensure static initialization completed. - llvm::cl::list - pass_list("", llvm::cl::desc("Compiler passes to run")); - ::pass_list = &pass_list; + mlir::PassPipelineCLParser pass_pipeline("", "Compiler passes to run"); + llvm::cl::ParseCommandLineOptions(argc, argv, "TF MLIR modular optimizer driver\n"); @@ -78,7 +75,7 @@ int main(int argc, char **argv) { auto output = mlir::openOutputFile(output_filename, &error_message); QCHECK(output) << error_message; - if (failed(mlir::MlirOptMain(output->os(), std::move(file), pass_list, + if (failed(mlir::MlirOptMain(output->os(), std::move(file), pass_pipeline, split_input_file, verify_diagnostics, verify_passes))) return 1; diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 39cd431165e..bcede30ea73 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -12,6 +12,7 @@ package_group( packages = [ "//babelfish/device/...", "//learning/brain/experimental/mlir/...", + "//learning/brain/google/xla/kernels/...", "//tensorflow/compiler/mlir/...", "//tensorflow/compiler/xla/...", "//third_party/mlir_edge/...", @@ -33,6 +34,8 @@ gentbl( tbl_outs = [ ("-gen-op-decls", "ir/hlo_ops.h.inc"), ("-gen-op-defs", "ir/hlo_ops.cc.inc"), + ("-gen-struct-attr-decls", "ir/hlo_structs.h.inc"), + ("-gen-struct-attr-defs", "ir/hlo_structs.cc.inc"), ], tblgen = "@local_config_mlir//:mlir-tblgen", td_file = "ir/hlo_ops.td", @@ -84,6 +87,7 @@ cc_library( deps = [ ":hlo", "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib", "@llvm//:support", "@local_config_mlir//:Analysis", "@local_config_mlir//:IR", @@ -224,9 +228,15 @@ cc_library( srcs = ["type_to_shape.cc"], hdrs = ["type_to_shape.h"], deps = [ + "//tensorflow/compiler/mlir/tensorflow:convert_tensor", + "//tensorflow/compiler/mlir/tensorflow:convert_type", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", + "//tensorflow/core:framework", + "//tensorflow/core/platform:logging", + "//tensorflow/core/platform:types", + "@llvm//:support", "@local_config_mlir//:IR", "@local_config_mlir//:Support", ], @@ -241,6 +251,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_proto_cc", "//tensorflow/core:test_main", "@local_config_mlir//:IR", ], @@ -258,6 +269,7 @@ cc_library( ":type_to_shape", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/xla:comparison_util", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", @@ -338,6 +350,7 @@ tf_native_cc_binary( deps = [ "@llvm//:support", "@llvm//:tablegen", + 
"@local_config_mlir//:Support", "@local_config_mlir//:TableGen", ], ) diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc index 0f71859d9a1..d8b096cd85a 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc @@ -89,6 +89,17 @@ StatusOr CreateDenseAttrFromLiteral(ShapedType type, } #undef DENSE_ELEMENT_ATTR_BUILDER } + +// Returns whether the instruction is a default dot operation. +bool DotIsDefault(const HloInstruction* instruction) { + auto dot_dimensions = instruction->dot_dimension_numbers(); + DotDimensionNumbers default_dimension_numbers; + default_dimension_numbers.add_lhs_contracting_dimensions( + instruction->operand(0)->shape().dimensions_size() == 1 ? 0 : 1); + default_dimension_numbers.add_rhs_contracting_dimensions(0); + return xla::protobuf_util::ProtobufEquals(dot_dimensions, + default_dimension_numbers); +} } // namespace StatusOr HloFunctionImporter::ImportFunction( @@ -230,13 +241,18 @@ StatusOr HloFunctionImporter::ImportInstruction( MakeAndReturn(BroadcastInDimOp); } case HloOpcode::kDot: { - // TODO(b/129153247) Add support for batch and contracting dimensions. - TF_RETURN_IF_ERROR(ValidateDotDimensions(instruction)); - // TODO(b/129709049) The HLO text format elides this in the all DEFAULT // case and the parser sticks it in. Maybe we should too. attributes.push_back(ConvertPrecisionConfig(instruction)); - MakeAndReturn(DotOp); + + // Consider consolidating DotOps together. + if (DotIsDefault(instruction)) { + MakeAndReturn(DotOp); + } + + attributes.push_back(builder_->getNamedAttr( + "dot_dimension_numbers", ConvertDotDimensionNumbers(instruction))); + MakeAndReturn(DotGeneralOp); } case HloOpcode::kCall: { TF_ASSIGN_OR_RETURN(FuncOp function, @@ -311,7 +327,8 @@ StatusOr HloFunctionImporter::ImportInstruction( ->create( loc, result_type, operands[0], ConvertDimensions(instruction->slice_starts()), - ConvertDimensions(instruction->slice_limits())) + ConvertDimensions(instruction->slice_limits()), + ConvertDimensions(instruction->slice_strides())) .getOperation(); } case HloOpcode::kConcatenate: { @@ -324,8 +341,12 @@ StatusOr HloFunctionImporter::ImportInstruction( .getOperation(); } case HloOpcode::kReduce: { + // Operands in the first half are reduction inputs and the remaining + // operands are corresponding initial values. 
+ size_t num_inputs = operands.size() / 2; auto reduce = func_builder->create( - loc, result_type, operands, + loc, result_type, llvm::makeArrayRef(operands).take_front(num_inputs), + llvm::makeArrayRef(operands).drop_front(num_inputs), ConvertDimensions(instruction->dimensions())); TF_RETURN_IF_ERROR( ImportComputation(instruction->to_apply(), &reduce.body())); @@ -537,37 +558,54 @@ mlir::NamedAttribute HloFunctionImporter::ConvertComparisonDirection( ComparisonDirectionToString(instruction->comparison_direction()))); } -mlir::ElementsAttr HloFunctionImporter::ConvertDimensions( +mlir::DenseIntElementsAttr HloFunctionImporter::ConvertDimensions( llvm::ArrayRef op_dimensions) { llvm::SmallVector dimensions; dimensions.reserve(op_dimensions.size()); for (auto value : op_dimensions) dimensions.emplace_back(APInt(64, value)); return DenseIntElementsAttr::get( - builder_->getTensorType(dimensions.size(), builder_->getIntegerType(64)), - dimensions); + builder_->getTensorType(dimensions.size(), + builder_->getIntegerType(64)), + dimensions) + .cast(); } -mlir::ElementsAttr HloFunctionImporter::Convert( +mlir::DenseIntElementsAttr HloFunctionImporter::Convert( llvm::ArrayRef op_dimensions) { - return builder_->getDenseIntElementsAttr( - builder_->getTensorType(op_dimensions.size(), - builder_->getIntegerType(64)), - op_dimensions); + return builder_ + ->getDenseIntElementsAttr( + builder_->getTensorType(op_dimensions.size(), + builder_->getIntegerType(64)), + op_dimensions) + .cast(); } -Status HloFunctionImporter::ValidateDotDimensions(HloInstruction* instruction) { - DotDimensionNumbers expected_dimension_numbers; - expected_dimension_numbers.add_lhs_contracting_dimensions( - instruction->operand(0)->shape().dimensions_size() == 1 ? 0 : 1); - expected_dimension_numbers.add_rhs_contracting_dimensions(0); - if (!xla::protobuf_util::ProtobufEquals(instruction->dot_dimension_numbers(), - expected_dimension_numbers)) { - return tensorflow::errors::Internal( - absl::StrCat("Dot operation has unsupported dimension numbers: ", - instruction->dot_dimension_numbers().DebugString())); - } - return Status::OK(); +mlir::xla_hlo::DotDimensionNumbers +HloFunctionImporter::ConvertDotDimensionNumbers(HloInstruction* instruction) { + auto dot_dimensions = instruction->dot_dimension_numbers(); + std::vector rhs_contracting_dimensions( + dot_dimensions.rhs_contracting_dimensions().begin(), + dot_dimensions.rhs_contracting_dimensions().end()); + std::vector lhs_contracting_dimensions( + dot_dimensions.lhs_contracting_dimensions().begin(), + dot_dimensions.lhs_contracting_dimensions().end()); + std::vector rhs_batch_dimensions( + dot_dimensions.rhs_batch_dimensions().begin(), + dot_dimensions.rhs_batch_dimensions().end()); + std::vector lhs_batch_dimensions( + dot_dimensions.lhs_batch_dimensions().begin(), + dot_dimensions.lhs_batch_dimensions().end()); + + // Push the attributes into our new DictionaryAttr. 
+ auto lhs_batch_dims_attr = Convert(lhs_batch_dimensions); + auto rhs_batch_dims_attr = Convert(rhs_batch_dimensions); + auto lhs_contracting_dims_attr = Convert(lhs_contracting_dimensions); + auto rhs_contracting_dims_attr = Convert(rhs_contracting_dimensions); + + return mlir::xla_hlo::DotDimensionNumbers::get( + lhs_batch_dims_attr, rhs_batch_dims_attr, lhs_contracting_dims_attr, + rhs_contracting_dims_attr, context_); } } // namespace xla diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.h b/tensorflow/compiler/mlir/xla/hlo_function_importer.h index b2e932bb09a..c6b61f94f5e 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.h +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.h @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/IR/Module.h" // TF:local_config_mlir #include "mlir/IR/StandardTypes.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -98,14 +99,15 @@ class HloFunctionImporter { xla::HloInstruction* instruction); // Converts the dimensions of an HLO instruction into an MLIR attribute. - mlir::ElementsAttr ConvertDimensions( + mlir::DenseIntElementsAttr ConvertDimensions( llvm::ArrayRef op_dimensions); - // Converts Array ref to an ElementsAttr. - mlir::ElementsAttr Convert(llvm::ArrayRef op_dimensions); + // Converts Array ref to an DenseIntElementsAttr. + mlir::DenseIntElementsAttr Convert(llvm::ArrayRef op_dimensions); - // Ensures dot instruction has only default contracting and batch dimensions. - Status ValidateDotDimensions(xla::HloInstruction* instruction); + // Converts the dot dimensions to attributes. + mlir::xla_hlo::DotDimensionNumbers ConvertDotDimensionNumbers( + xla::HloInstruction* instruction); mlir::MLIRContext* context_; mlir::ModuleOp module_; diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc index a5df379d90b..6c0f5179025 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc @@ -44,6 +44,10 @@ limitations under the License. 
#include "mlir/IR/Value.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h.inc" +namespace mlir { +#include "tensorflow/compiler/mlir/xla/ir/hlo_structs.cc.inc" +} // namespace mlir + using namespace mlir; using namespace mlir::xla_hlo; @@ -68,8 +72,10 @@ Operation* XlaHloDialect::materializeConstant(OpBuilder& builder, return nullptr; } -#define GET_OP_CLASSES -#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.cc.inc" +template +static LogicalResult Verify(T op) { + return success(); +} //===----------------------------------------------------------------------===// // ConstOp @@ -102,6 +108,35 @@ void ConstOp::build(Builder* builder, OperationState* result, Attribute value) { result->addAttribute("value", value); } +//===----------------------------------------------------------------------===// +// IotaOp +//===----------------------------------------------------------------------===// + +OpFoldResult IotaOp::fold(ArrayRef operands) { + const auto output_type = getResult()->getType().cast(); + const auto output_size = output_type.getNumElements(); + const auto dimension = iota_dimension().getLimitedValue(); + const auto max_dim_size = output_type.getDimSize(dimension); + int bitwidth = output_type.getElementType().getIntOrFloatBitWidth(); + + llvm::SmallVector values; + values.reserve(output_size); + + int64_t increase_stride = output_size; + for (int i = 0; i <= dimension; i++) { + increase_stride /= output_type.getDimSize(i); + } + + int64_t current_value = 0; + for (int i = 0; i < output_size; i++) { + int64_t value = (current_value / increase_stride) % max_dim_size; + values.push_back(APInt(bitwidth, value)); + ++current_value; + } + + return DenseIntElementsAttr::get(output_type, values); +} + //===----------------------------------------------------------------------===// // ConvertOp //===----------------------------------------------------------------------===// @@ -175,32 +210,211 @@ OpFoldResult ConvertOp::fold(ArrayRef operands) { } //===----------------------------------------------------------------------===// -// IotaOp +// GetTupleElementOp //===----------------------------------------------------------------------===// -OpFoldResult IotaOp::fold(ArrayRef operands) { - const auto output_type = getResult()->getType().cast(); - const auto output_size = output_type.getNumElements(); - const auto dimension = iota_dimension().getLimitedValue(); - const auto max_dim_size = output_type.getDimSize(dimension); - int bitwidth = output_type.getElementType().getIntOrFloatBitWidth(); - - llvm::SmallVector values; - values.reserve(output_size); - - int64_t increase_stride = output_size; - for (int i = 0; i <= dimension; i++) { - increase_stride /= output_type.getDimSize(i); +static LogicalResult Verify(GetTupleElementOp op) { + auto indexVal = op.index().getZExtValue(); + auto operandType = op.getOperand()->getType().cast(); + if (indexVal >= operandType.size()) { + return op.emitOpError( + llvm::formatv("index {0} is out of bounds of operand with size {1}", + indexVal, operandType.size())); } - int64_t current_value = 0; - for (int i = 0; i < output_size; i++) { - int64_t value = (current_value / increase_stride) % max_dim_size; - values.push_back(APInt(bitwidth, value)); - ++current_value; + auto expectedType = operandType.getType(indexVal); + if (op.getType() != expectedType) { + return op.emitOpError(llvm::formatv("has return type {0}, but expected {1}", + op.getType(), expectedType)); + } + return success(); +} + 
+//===----------------------------------------------------------------------===// +// TupleOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(TupleOp op) { + SmallVector operandTypes = {op.operand_type_begin(), + op.operand_type_end()}; + auto expectedType = TupleType::get(operandTypes, op.getContext()); + if (op.getType() != expectedType) { + return op.emitOpError(llvm::formatv("has return type {0}, but expected {1}", + op.getType(), expectedType)); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// BroadcastOp +//===----------------------------------------------------------------------===// + +// TODO(b/129012527) These should be expressed as type constraints. +static LogicalResult Verify(BroadcastOp op) { + auto sizes = op.broadcast_sizes(); + auto sizesType = sizes.getType(); + auto sizesRank = sizesType.getRank(); + if (sizesRank != 1) { + return op.emitOpError(llvm::formatv( + "broadcast_sizes has rank {0} instead of rank 1", sizesRank)); } - return DenseIntElementsAttr::get(output_type, values); + auto resultType = op.getResult()->getType().cast(); + auto resultRank = resultType.getRank(); + auto operandType = op.operand()->getType().cast(); + auto operandRank = operandType.getRank(); + auto sizesSize = sizesType.getNumElements(); + auto expectedRank = operandRank + sizesSize; + + if (resultRank != expectedRank) { + return op.emitOpError( + llvm::formatv("result rank ({0}) does not match operand rank " + "({2}) plus size of broadcast_sizes ({3})", + resultRank, operandRank, sizesSize)); + } + + llvm::SmallVector expectedShape(sizes.getValues()); + + auto operandShape = operandType.getShape(); + expectedShape.insert(expectedShape.end(), operandShape.begin(), + operandShape.end()); + + auto resultShape = resultType.getShape(); + if (resultShape != llvm::makeArrayRef(expectedShape)) { + return op.emitOpError(llvm::formatv( + "result has shape [{0}] instead of [{1}]", + llvm::make_range(resultShape.begin(), resultShape.end()), + llvm::make_range(expectedShape.begin(), expectedShape.end()))); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// BroadcastInDimOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(BroadcastInDimOp op) { + auto operandType = op.operand()->getType().cast(); + auto operandRank = operandType.getRank(); + if (!op.broadcast_dimensions()) { + if (operandRank == 0) { + return success(); + } + return op.emitOpError( + llvm::formatv("broadcast_dimensions is absent, but required because " + "operand has non-zero rank ({0})", + operandRank)); + } + + auto dimensions = *op.broadcast_dimensions(); + auto dimensionsType = op.broadcast_dimensions()->getType(); + auto dimensionsRank = dimensionsType.getRank(); + if (dimensionsRank != 1) { + return op.emitOpError(llvm::formatv( + "broadcast_dimensions has rank {0} instead of rank 1", dimensionsRank)); + } + + auto dimensionsSize = dimensionsType.getNumElements(); + if (dimensionsSize != operandRank) { + return op.emitOpError(llvm::formatv( + "broadcast_dimensions size ({0}) does not match operand rank ({1})", + dimensionsSize, operandRank)); + } + + auto resultType = op.getResult()->getType().cast(); + auto resultRank = resultType.getRank(); + if (resultRank < operandRank) { + return op.emitOpError( + llvm::formatv("result rank ({0}) is less than operand rank ({1})", + 
resultRank, operandRank)); + } + + for (int i = 0; i != dimensionsSize; ++i) { + auto dimIndex = dimensions.getValue(i); + if (dimIndex >= resultRank) { + return op.emitOpError( + llvm::formatv("broadcast_dimensions contains invalid value {0} for " + "result result with rank {1}", + dimIndex, resultRank)); + } + + auto dimSize = operandType.getDimSize(i); + auto resultDimSize = resultType.getDimSize(dimIndex); + if (dimSize != 1 && dimSize != resultDimSize) { + return op.emitOpError( + llvm::formatv("size of operand dimension {0} ({1}) is not equal to " + "1 or size of result dimension {2} ({3})", + i, dimSize, dimIndex, resultDimSize)); + } + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// ClampOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(ClampOp op) { + auto operandType = op.operand()->getType().cast(); + auto operandShape = operandType.getShape(); + auto minType = op.min()->getType().cast(); + + auto minShape = minType.getShape(); + if (minShape != operandShape && minType.getRank() != 0) { + return op.emitOpError(llvm::formatv( + "min shape [{0}] is not scalar and does not match operand shape [{1}]", + llvm::make_range(minShape.begin(), minShape.end()), + llvm::make_range(operandShape.begin(), operandShape.end()))); + } + + auto maxType = op.max()->getType().cast(); + auto maxShape = maxType.getShape(); + if (maxShape != operandShape && maxType.getRank() != 0) { + return op.emitOpError(llvm::formatv( + "max shape [{0}] is not scalar and does not match operand shape [{1}]", + llvm::make_range(maxShape.begin(), maxShape.end()), + llvm::make_range(operandShape.begin(), operandShape.end()))); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// ConcatenateOp +//===----------------------------------------------------------------------===// + +OpFoldResult ConcatenateOp::fold(ArrayRef operands) { + if (getNumOperands() == 1) return getOperand(0); + return {}; +} + +static LogicalResult Verify(ConcatenateOp op) { + auto firstType = op.getOperand(0)->getType().cast(); + + auto firstShape = firstType.getShape(); + int numOperands = op.getNumOperands(); + for (int i = 1; i < numOperands; i++) { + auto secondType = op.getOperand(i)->getType().cast(); + + if (firstType.getRank() != secondType.getRank()) { + return op.emitOpError( + llvm::formatv("operands (0) and ({0}) do not match rank.", i)); + } + + auto secondShape = secondType.getShape(); + for (int d = 0; d < firstType.getRank(); ++d) { + if (firstShape[d] != secondShape[d] && d != op.dimension()) { + return op.emitOpError(llvm::formatv( + "operands (0) and ({0}) non-concat dimensions do not match " + "({1}) != ({2}).", + i, llvm::make_range(firstShape.begin(), firstShape.end()), + llvm::make_range(secondShape.begin(), secondShape.end()))); + } + } + } + return success(); } //===----------------------------------------------------------------------===// @@ -225,6 +439,89 @@ OpFoldResult ReshapeOp::fold(ArrayRef operands) { return {}; } +//===----------------------------------------------------------------------===// +// SelectOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(SelectOp op) { + auto onTrueType = op.on_true()->getType().cast(); + auto onFalseType = op.on_false()->getType().cast(); + + if (onTrueType != onFalseType) { + return op.emitOpError( + llvm::formatv("on_true 
type ({0}) does not match on_false type ({1})", + onTrueType, onFalseType)); + } + + auto predType = op.pred()->getType().cast(); + auto predShape = predType.getShape(); + auto predRank = predType.getRank(); + auto selectShape = onTrueType.getShape(); + + if (predRank != 0 && predShape != selectShape) { + return op.emitOpError(llvm::formatv( + "pred shape ([{0}]) is not scalar and does not match operand shapes " + "([{1}])", + llvm::make_range(predShape.begin(), predShape.end()), + llvm::make_range(selectShape.begin(), selectShape.end()))); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// PadOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(PadOp op) { + auto input_type = op.operand()->getType().cast(); + auto pad_type = op.padding_value()->getType().cast(); + + if (pad_type.getRank() != 0) { + return op.emitOpError( + llvm::formatv("padding value type should be a rank-0 " + "tensor, is rank {0}", + pad_type.getRank())); + } + + const auto& padding_low = op.edge_padding_low(); + if (padding_low.getType().getNumElements() != input_type.getRank()) { + return op.emitOpError(llvm::formatv( + "edge_padding_low length ({0}) must match operand rank ({1}).", + padding_low.getType().getNumElements(), input_type.getRank())); + } + + const auto& padding_high = op.edge_padding_high(); + if (padding_high.getType().getNumElements() != input_type.getRank()) { + return op.emitOpError(llvm::formatv( + "edge_padding_high length ({0}) must match operand rank ({1}).", + padding_high.getType().getNumElements(), input_type.getRank())); + } + + auto input_shape = input_type.getShape(); + auto output_shape = + op.getResult()->getType().cast().getShape(); + if (input_shape.size() != output_shape.size()) { + return op.emitOpError( + llvm::formatv("Operand rank ({0}) and result rank({0}) should match", + input_shape.size(), output_shape.size())); + } + + for (int i = 0, e = input_shape.size(); i < e; i++) { + int expected_output = input_shape[i] + + padding_low.getValue(i).getInt() + + padding_high.getValue(i).getInt(); + if (expected_output != output_shape[i]) { + return op.emitOpError( + llvm::formatv("Expected output shape ({0}) and " + "output shape ({1}) should match.", + expected_output, output_shape[i])); + } + } + + return success(); +} + //===----------------------------------------------------------------------===// // TransposeOp //===----------------------------------------------------------------------===// @@ -237,3 +534,51 @@ OpFoldResult TransposeOp::fold(ArrayRef operands) { } return getOperand(); } + +static LogicalResult Verify(TransposeOp op) { + auto permutationType = op.permutation().getType(); + auto permutationRank = permutationType.getRank(); + if (permutationRank != 1) { + return op.emitOpError(llvm::formatv( + "permutation has rank {0} instead of rank 1", permutationRank)); + } + + auto operandType = op.operand()->getType().cast(); + auto operandRank = operandType.getRank(); + auto permutationSize = permutationType.getNumElements(); + if (permutationSize != operandRank) { + return op.emitOpError(llvm::formatv( + "permutation size ({0}) does not match operand rank ({1})", + permutationSize, operandRank)); + } + + auto resultType = op.getResult()->getType().cast(); + auto resultRank = resultType.getRank(); + if (resultRank != operandRank) { + return op.emitOpError( + llvm::formatv("result rank ({0}) does not match operand rank ({1})", + resultRank, 
operandRank)); + } + + auto resultShape = resultType.getShape(); + + auto expectedShape = SmallVector(operandRank); + for (int i = 0; i != operandRank; ++i) { + auto permutedDim = op.permutation().getValue(i).getInt(); + expectedShape[i] = operandType.getDimSize(permutedDim); + } + + if (resultShape != llvm::makeArrayRef(expectedShape)) { + return op.emitOpError(llvm::formatv( + "result shape is [{0}" + "] instead of [{1}" + "]", + llvm::make_range(resultShape.begin(), resultShape.end()), + llvm::make_range(expectedShape.begin(), expectedShape.end()))); + } + + return success(); +} + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.cc.inc" diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.h b/tensorflow/compiler/mlir/xla/ir/hlo_ops.h index 3260a829734..09a9cec968f 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.h +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.h @@ -32,6 +32,8 @@ limitations under the License. namespace mlir { class OpBuilder; +#include "tensorflow/compiler/mlir/xla/ir/hlo_structs.h.inc" + namespace xla_hlo { class XlaHloDialect : public Dialect { diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td index 5827225988b..e6efdc82d9d 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td @@ -38,6 +38,10 @@ class HLO_Op traits> : Op { // Whether this operation has a custom conversion to HLO or not. bit hasCustomHLOConverter = 0b0; + + // TODO(b/129012527) Much of this custom verification should be expressed as + // type constraints. + let verifier = [{ return Verify(*this); }]; } //===----------------------------------------------------------------------===// @@ -140,9 +144,9 @@ def HLO_TanhOp: HLO_UnaryElementwiseOp<"tanh", class HLO_BinaryElementwiseOp traits> : HLO_Op { let arguments = (ins - HLO_Tensor:$lhs, - HLO_Tensor:$rhs, - BroadcastDimAttr:$broadcast_dimensions + HLO_Tensor:$lhs, + HLO_Tensor:$rhs, + BroadcastDimAttr:$broadcast_dimensions ); let results = (outs HLO_Tensor); let parser = [{ return mlir::impl::parseBinaryOp(parser, result); }]; @@ -196,11 +200,13 @@ def HLO_WhileOp: HLO_Op<"while", [NoSideEffect, SameOperandsAndResultType]> { def HLO_ReduceOp: HLO_Op<"reduce", [ NoSideEffect, + SameVariadicOperandSize, SingleBlockImplicitTerminator<"ReturnOp"> ]>, BASE_HLO_ReduceOp { let arguments = (ins - Variadic:$operands_and_init, - ElementsAttr:$dimensions + Variadic:$operands, + Variadic:$init_values, + I64ElementsAttr:$dimensions ); let results = (outs Variadic); @@ -241,10 +247,10 @@ def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp { def HLO_CompareOp: HLO_Op<"compare", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_CompareOp { let arguments = (ins - HLO_Tensor:$lhs, - HLO_Tensor:$rhs, - BroadcastDimAttr:$broadcast_dimensions, - HLO_ComparisonDirectionAttr:$comparison_direction + HLO_Tensor:$lhs, + HLO_Tensor:$rhs, + BroadcastDimAttr:$broadcast_dimensions, + HLO_ComparisonDirectionAttr:$comparison_direction ); let results = (outs HLO_PredTensor); } @@ -256,11 +262,12 @@ def HLO_CompareOp: HLO_Op<"compare", def HLO_SliceOp: HLO_Op< "slice", [NoSideEffect, SameOperandsAndResultElementType, - AllTypesMatch<["start_indices", "limit_indices"]>]> { - let arguments = ( - ins HLO_Tensor:$operand, - ElementsAttr:$start_indices, - ElementsAttr:$limit_indices + AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> { + let arguments = (ins + HLO_Tensor:$operand, + I64ElementsAttr:$start_indices, + 
I64ElementsAttr:$limit_indices, + I64ElementsAttr:$strides ); let results = (outs HLO_Tensor); @@ -309,58 +316,10 @@ def HLO_BroadcastOp : HLO_Op<"broadcast", [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_BroadcastOp { let arguments = (ins HLO_Tensor:$operand, - ElementsAttr:$broadcast_sizes + I64ElementsAttr:$broadcast_sizes ); let results = (outs HLO_Tensor); - - // TODO(b/129012527) These should be expressed as type constraints. - let verifier = [{ - auto sizes = broadcast_sizes().dyn_cast(); - if (!sizes) { - return emitOpError(llvm::formatv( - "broadcast_sizes must be a DenseIntElementsAttr; got {0}", - broadcast_sizes())); - } - auto sizesType = sizes.getType().cast(); - auto sizesRank = sizesType.getRank(); - if (sizesRank != 1) { - return emitOpError(llvm::formatv( - "broadcast_sizes has rank {0} instead of rank 1", sizesRank)); - } - - auto resultType = getResult()->getType().cast(); - auto resultRank = resultType.getRank(); - auto operandType = operand()->getType().cast(); - auto operandRank = operandType.getRank(); - auto sizesSize = sizesType.getNumElements(); - auto expectedRank = operandRank + sizesSize; - - if (resultRank != expectedRank) { - return emitOpError( - llvm::formatv("result rank ({0}) does not match operand rank " - "({2}) plus size of broadcast_sizes ({3})", - resultRank, operandRank, sizesSize)); - } - - llvm::SmallVector expectedShape(sizes.getValues()); - - auto operandShape = operandType.getShape(); - expectedShape.insert(expectedShape.end(), operandShape.begin(), - operandShape.end()); - - auto resultShape = resultType.getShape(); - if (resultShape != llvm::makeArrayRef(expectedShape)) { - return emitOpError(llvm::formatv( - "result has shape [{0}" - "] instead of [{1}" - "]", - llvm::make_range(resultShape.begin(), resultShape.end()), - llvm::make_range(expectedShape.begin(), expectedShape.end()))); - } - - return success(); - }]; } def HLO_BroadcastInDimOp : HLO_Op<"broadcast_in_dim", @@ -372,72 +331,6 @@ def HLO_BroadcastInDimOp : HLO_Op<"broadcast_in_dim", let results = (outs HLO_Tensor); - // TODO(b/129012527) These should be expressed as type constraints. 
- let verifier = [{ - auto operandType = operand()->getType().cast(); - auto operandRank = operandType.getRank(); - if (!broadcast_dimensions()) { - if (operandRank == 0) { - return success(); - } - return emitOpError( - llvm::formatv("broadcast_dimensions is absent, but required because " - "operand has non-zero rank ({0})", - operandRank)); - } - - auto dimensions = broadcast_dimensions()->dyn_cast(); - if (!dimensions) { - return emitOpError(llvm::formatv( - "broadcast_sizes must be a DenseIntElementsAttr; got {0}", - broadcast_dimensions())); - } - - auto dimensionsType = broadcast_dimensions()->getType().cast(); - auto dimensionsRank = dimensionsType.getRank(); - if (dimensionsRank != 1) { - return emitOpError( - llvm::formatv("broadcast_dimensions has rank {0} instead of rank 1", - dimensionsRank)); - } - - auto dimensionsSize = dimensionsType.getNumElements(); - if (dimensionsSize != operandRank) { - return emitOpError(llvm::formatv( - "broadcast_dimensions size ({0}) does not match operand rank ({1})", - dimensionsSize, operandRank)); - } - - auto resultType = getResult()->getType().cast(); - auto resultRank = resultType.getRank(); - if (resultRank < operandRank) { - return emitOpError( - llvm::formatv("result rank ({0}) is less than operand rank ({1})", - resultRank, operandRank)); - } - - for (int i = 0; i != dimensionsSize; ++i) { - auto dimIndex = dimensions.getValue(i); - if (dimIndex >= resultRank) { - return emitOpError( - llvm::formatv("broadcast_dimensions contains invalid value {0} for " - "result result with rank {1}", - dimIndex, resultRank)); - } - - auto dimSize = operandType.getDimSize(i); - auto resultDimSize = resultType.getDimSize(dimIndex); - if (dimSize != 1 && dimSize != resultDimSize) { - return emitOpError( - llvm::formatv("size of operand dimension {0} ({1}) is not equal to " - "1 or size of result dimension {2} ({3})", - i, dimSize, dimIndex, resultDimSize)); - } - } - - return success(); - }]; - // TODO(b/130357376): One of the arguments comes from the new shape, which is // not handled by the codegen. let hasCustomHLOConverter = 1; @@ -452,74 +345,19 @@ def HLO_ClampOp : HLO_Op<"clamp", ); let results = (outs HLO_Tensor); - - // TODO(b/129012527) These should be expressed as type constraints. 
- let verifier = [{ - auto operandType = operand()->getType().cast(); - auto operandShape = operandType.getShape(); - auto minType = min()->getType().cast(); - - auto minShape = minType.getShape(); - if (minShape != operandShape && minType.getRank() != 0) { - return emitOpError(llvm::formatv( - "min shape [{0}" - "] is not scalar and does not match operand shape [{1}" - "]", - llvm::make_range(minShape.begin(), minShape.end()), - llvm::make_range(operandShape.begin(), operandShape.end()))); - } - - auto maxType = max()->getType().cast(); - auto maxShape = maxType.getShape(); - if (maxShape != operandShape && maxType.getRank() != 0) { - return emitOpError(llvm::formatv( - "max shape [{0}" - "] is not scalar and does not match operand shape [{1}" - "]", - llvm::make_range(maxShape.begin(), maxShape.end()), - llvm::make_range(operandShape.begin(), operandShape.end()))); - } - - return success(); - }]; } def HLO_ConcatenateOp : HLO_Op<"concatenate", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ConcatenateOp { + [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ConcatenateOp { - let arguments = ( - ins Variadic:$val, - I64Attr: $dimension - ); + let arguments = (ins + Variadic:$val, + I64Attr: $dimension + ); - let verifier = [{ - auto firstType = getOperand(0)->getType().cast(); + let results = (outs HLO_Tensor); - auto firstShape = firstType.getShape(); - int numOperands = getNumOperands(); - for (int i = 1; i < numOperands; i++) { - auto secondType = getOperand(i)->getType().cast(); - - if (firstType.getRank() != secondType.getRank()) { - return emitOpError( - llvm::formatv("operands (0) and ({0}) do not match rank.", i)); - } - - auto secondShape = secondType.getShape(); - for (int d = 0; d < firstType.getRank(); ++d) { - if (firstShape[d] != secondShape[d] && d != dimension()) { - return emitOpError(llvm::formatv( - "operands (0) and ({0}) non-concat dimensions do not match " - "({1}) != ({2}).", - i, llvm::make_range(firstShape.begin(), firstShape.end()), - llvm::make_range(secondShape.begin(), secondShape.end()))); - } - } - } - return success(); - }]; - - let results = (outs HLO_Tensor); + let hasFolder = 1; // TODO(b/129422361) ConcatOp has special conversion logic to HLO. 
let hasCustomHLOConverter = 1; @@ -555,22 +393,41 @@ def HLO_CopyOp: HLO_Op<"copy", [NoSideEffect, SameOperandsAndResultType]> { def HLO_DotOp: HLO_Op<"dot", [NoSideEffect]>, BASE_HLO_DotOp { let arguments = ( - ins HLO_Tensor:$lhs, - HLO_Tensor:$rhs, - HLO_PrecisionConfigAttr:$precision_config - ); + ins HLO_Tensor:$lhs, + HLO_Tensor:$rhs, + HLO_PrecisionConfigAttr:$precision_config + ); + let results = (outs HLO_Tensor); +} + +def DotDimensionNumbers : StructAttr<"DotDimensionNumbers", HLO_Dialect, [ + StructFieldAttr<"lhs_batching_dimensions", ElementsAttr>, + StructFieldAttr<"rhs_batching_dimensions", ElementsAttr>, + StructFieldAttr<"lhs_contracting_dimensions", ElementsAttr>, + StructFieldAttr<"rhs_contracting_dimensions", ElementsAttr>] > { + let description = "Structure of dimension information for dot product"; +} + +def HLO_DotGeneralOp: HLO_Op<"dot_general", [NoSideEffect]>, BASE_HLO_DotGeneralOp { + let arguments = (ins + HLO_Tensor:$lhs, + HLO_Tensor:$rhs, + DotDimensionNumbers:$dot_dimension_numbers, + HLO_PrecisionConfigAttr:$precision_config + ); + let results = (outs HLO_Tensor); } def HLO_GatherOp: HLO_Op<"gather", [NoSideEffect]>, BASE_HLO_GatherOp { - let arguments = ( - ins HLO_Tensor:$operand, - HLO_IntTensor:$start_indices, - I64Attr: $index_vector_dim, - ElementsAttr: $offset_dims, - ElementsAttr: $slice_sizes, - ElementsAttr: $collapsed_slice_dims, - ElementsAttr: $start_index_map + let arguments = (ins + HLO_Tensor:$operand, + HLO_IntTensor:$start_indices, + I64Attr:$index_vector_dim, + I64ElementsAttr:$offset_dims, + I64ElementsAttr:$slice_sizes, + I64ElementsAttr:$collapsed_slice_dims, + I64ElementsAttr:$start_index_map ); let results = (outs HLO_Tensor); @@ -602,41 +459,13 @@ def HLO_SelectOp: HLO_Op<"select", [NoSideEffect]>, BASE_HLO_SelectOp { ); let results = (outs HLO_Tensor); - - // TODO(b/129012527) These should be expressed as type constraints. - let verifier = [{ - auto onTrueType = on_true()->getType().cast(); - auto onFalseType = on_false()->getType().cast(); - - if (onTrueType != onFalseType) { - return emitOpError( - llvm::formatv("on_true type ({0}) does not match on_false type ({1})", - onTrueType, onFalseType)); - } - - auto predType = pred()->getType().cast(); - auto predShape = predType.getShape(); - auto predRank = predType.getRank(); - auto selectShape = onTrueType.getShape(); - - if (predRank != 0 && predShape != selectShape) { - return emitOpError(llvm::formatv( - "pred shape ([{0}" - "]) is not scalar and does not match operand shapes ([{1}" - "])", - llvm::make_range(predShape.begin(), predShape.end()), - llvm::make_range(selectShape.begin(), selectShape.end()))); - } - - return success(); - }]; } def HLO_ReverseOp: HLO_Op<"reverse", [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ReverseOp { let arguments = (ins HLO_Tensor:$operand, - ElementsAttr:$dimensions + I64ElementsAttr:$dimensions ); let results = (outs HLO_Tensor); @@ -650,9 +479,9 @@ def HLO_PadOp: HLO_Op<"pad", let arguments = (ins HLO_Tensor:$operand, HLO_Tensor:$padding_value, - ElementsAttr: $edge_padding_low, - ElementsAttr: $edge_padding_high, - ElementsAttr: $interior_padding + I64ElementsAttr: $edge_padding_low, + I64ElementsAttr: $edge_padding_high, + I64ElementsAttr: $interior_padding ); let results = (outs HLO_Tensor); @@ -661,51 +490,6 @@ def HLO_PadOp: HLO_Op<"pad", Pads the `operand` according to TBD. 
}]; - let verifier = [{ - auto input_type = operand()->getType().cast(); - auto pad_type = padding_value()->getType().cast(); - - if (pad_type.getRank() != 0) { - return emitOpError(llvm::formatv("padding value type should be a rank-0 " - "tensor, is rank {0}", pad_type.getRank())); - } - - const auto& padding_low = edge_padding_low(); - if (padding_low.getType().getNumElements() != input_type.getRank()) { - return emitOpError(llvm::formatv( - "edge_padding_low length ({0}) must match operand rank ({1}).", - padding_low.getType().getNumElements(), input_type.getRank())); - } - - const auto& padding_high = edge_padding_high(); - if (padding_high.getType().getNumElements() != input_type.getRank()) { - return emitOpError(llvm::formatv( - "edge_padding_high length ({0}) must match operand rank ({1}).", - padding_high.getType().getNumElements(), input_type.getRank())); - } - - auto input_shape = input_type.getShape(); - auto output_shape = getResult()->getType().cast().getShape(); - if (input_shape.size() != output_shape.size()) { - return emitOpError(llvm::formatv( - "Operand rank ({0}) and result rank({0}) should match", - input_shape.size(), output_shape.size())); - } - - for (int i = 0, e = input_shape.size(); i < e; i++) { - int expected_output = input_shape[i] - + padding_low.getValue(i).getInt() - + padding_high.getValue(i).getInt(); - if (expected_output != output_shape[i]) { - return emitOpError(llvm::formatv("Expected output shape ({0}) and " - "output shape ({1}) should match.", - expected_output, output_shape[i])); - } - } - - return success(); - }]; - // TODO(b/129422361): PadOp has a custom constructor for HLO. let hasCustomHLOConverter = 1; } @@ -714,63 +498,45 @@ def HLO_TransposeOp: HLO_Op<"transpose", [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_TransposeOp { let arguments = (ins HLO_Tensor:$operand, - ElementsAttr:$permutation + I64ElementsAttr:$permutation ); let results = (outs HLO_Tensor); let hasFolder = 1; +} - // TODO(b/129012527) These should be expressed as type constraints. - let verifier = [{ - if (!permutation().isa()) { - return emitOpError( - llvm::formatv("permutation must be a DenseIntElementsAttr; got {0}", - permutation())); - } +def HLO_ReduceWindowOp: HLO_Op<"reduce_window", [ + NoSideEffect, + SameVariadicOperandSize, + SingleBlockImplicitTerminator<"ReturnOp"> + ]>, BASE_HLO_ReduceWindowOp { - auto permutationType = permutation().getType().cast(); - auto permutationRank = permutationType.getRank(); - if (permutationRank != 1) { - return emitOpError(llvm::formatv( - "permutation has rank {0} instead of rank 1", permutationRank)); - } + // TODO(hinsu): Verify that padding attribute is 2-d and the remaining + // attributes are 1-d. Attributes' leading dimension should match rank of the + // inputs. + let arguments = (ins + Variadic:$operands, + Variadic:$init_values, + I64ElementsAttr:$window_dimensions, + // If strides or dilations attributes are missing then the default value is + // one for each of the input dimensions. Similarly, padding values are zero + // for both low and high in each of the dimensions, if not specified. 
+ OptionalAttr:$window_strides, + OptionalAttr:$base_dilations, + OptionalAttr:$window_dilations, + OptionalAttr:$padding + ); - auto operandType = operand()->getType().cast(); - auto operandRank = operandType.getRank(); - auto permutationSize = permutationType.getNumElements(); - if (permutationSize != operandRank) { - return emitOpError(llvm::formatv( - "permutation size ({0}) does not match operand rank ({1})", - permutationSize, operandRank)); - } + let results = (outs Variadic); - auto resultType = getResult()->getType().cast(); - auto resultRank = resultType.getRank(); - if (resultRank != operandRank) { - return emitOpError( - llvm::formatv("result rank ({0}) does not match operand rank ({1})", - resultRank, operandRank)); - } + // TODO(hinsu): Verify that the attached body arguments and results are + // compatible with reduce op's operands. + let regions = (region SizedRegion<1>:$body); - auto resultShape = resultType.getShape(); + // TODO(b/129422361): ReduceWindowOp has special conversion logic to HLO. + let hasCustomHLOConverter = 1; - auto expectedShape = SmallVector(operandRank); - for (int i = 0; i != operandRank; ++i) { - auto permutedDim = permutation().getValue(i).getInt(); - expectedShape[i] = operandType.getDimSize(permutedDim); - } - - if (resultShape != llvm::makeArrayRef(expectedShape)) { - return emitOpError(llvm::formatv( - "result shape is [{0}" - "] instead of [{1}" - "]", - llvm::make_range(resultShape.begin(), resultShape.end()), - llvm::make_range(expectedShape.begin(), expectedShape.end()))); - } - - return success(); - }]; + // TODO(hinsu): Implement custom printer and parser. } def HLO_ReturnOp : HLO_Op<"return", [Terminator]> { diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td index 28d6efd0aad..6623c21dcb8 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td @@ -165,7 +165,7 @@ class BASE_HLO_TanhOp { // smaller rank shape is broadcast into a larger rank shape. For example, // given a 2x3x4 cuboid and a 3x4 matrix, a broadcasting tuple (1,2) means // matching the matrix to dimensions 1 and 2 of the cuboid. -def BroadcastDimAttr : OptionalAttr; +def BroadcastDimAttr : OptionalAttr; class BASE_HLO_AddOp { string summary = "Addition operator"; @@ -244,10 +244,6 @@ class BASE_HLO_AndOp { }]; } -//===----------------------------------------------------------------------===// -// XLA control flow op definitions. -//===----------------------------------------------------------------------===// - class BASE_HLO_ReduceOp { string summary = "Reduce operator"; @@ -259,6 +255,17 @@ class BASE_HLO_ReduceOp { }]; } +class BASE_HLO_ReduceWindowOp { + string summary = "ReduceWindow operator"; + + string description = [{ + Returns the result of executing a reduction function over all elements in + each window of one or more arrays in parallel. + + See https://www.tensorflow.org/xla/operation_semantics#reducewindow. + }]; +} + //===----------------------------------------------------------------------===// // XLA tuple op definitions. //===----------------------------------------------------------------------===// @@ -454,6 +461,17 @@ class BASE_HLO_DotOp { }]; } +class BASE_HLO_DotGeneralOp { + string summary = "General Dot operator"; + string description = [{ + Performs general dot products between vectors, vector/matrix and + matrix/matrix multiplication. + + See https://www.tensorflow.org/xla/operation_semantics#dotgeneral. 
+ }]; +} + + class BASE_HLO_GatherOp{ string summary = "Gather operator"; diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td index 4ef40c4c69f..597e5b3671b 100644 --- a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td @@ -133,12 +133,13 @@ def LHLO_AndOp: LHLO_BinaryElementwiseOp<"and", []>, BASE_HLO_AndOp; // TODO(b/139813999): specify required function signature in a type-safe way. def LHLO_ReduceOp: LHLO_Op<"reduce", [SameVariadicOperandSize]>, BASE_HLO_ReduceOp { let arguments = (ins - Variadic:$operands_and_init, + Variadic:$operands, + Variadic:$init_values, Variadic:$out, // TODO(hinsu): Attach computation as a region similar to the // xla_hlo.reduce op. SymbolRefAttr:$computation, - ElementsAttr:$dimensions + I64ElementsAttr:$dimensions ); } //===----------------------------------------------------------------------===// @@ -175,12 +176,13 @@ def LHLO_CompareOp: LHLO_Op<"compare", []>, BASE_HLO_CompareOp { def LHLO_SliceOp: LHLO_Op< "slice", - [AllTypesMatch<["start_indices", "limit_indices"]>]> { + [AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> { let arguments = (ins LHLO_Buffer:$operand, LHLO_Buffer:$output, - ElementsAttr:$start_indices, - ElementsAttr:$limit_indices + I64ElementsAttr:$start_indices, + I64ElementsAttr:$limit_indices, + I64ElementsAttr:$strides ); } @@ -217,7 +219,7 @@ def LHLO_BroadcastOp : LHLO_Op<"broadcast", let arguments = (ins LHLO_Buffer:$operand, LHLO_Buffer:$output, - ElementsAttr:$broadcast_sizes + I64ElementsAttr:$broadcast_sizes ); } @@ -243,7 +245,7 @@ def LHLO_ConcatenateOp : LHLO_Op<"concatenate", []>, BASE_HLO_ConcatenateOp { let arguments = (ins Variadic:$val, LHLO_Buffer:$output, - I64Attr: $dimension + I64Attr:$dimension ); } @@ -268,11 +270,11 @@ def LHLO_GatherOp: LHLO_Op<"gather", []>, BASE_HLO_GatherOp { let arguments = (ins LHLO_Buffer:$operand, LHLO_IntBuffer:$start_indices, - I64Attr: $index_vector_dim, - ElementsAttr: $offset_dims, - ElementsAttr: $slice_sizes, - ElementsAttr: $collapsed_slice_dims, - ElementsAttr: $start_index_map, + I64Attr:$index_vector_dim, + I64ElementsAttr:$offset_dims, + I64ElementsAttr:$slice_sizes, + I64ElementsAttr:$collapsed_slice_dims, + I64ElementsAttr:$start_index_map, LHLO_Buffer:$output ); } @@ -297,7 +299,7 @@ def LHLO_SelectOp: LHLO_Op<"select", []>, BASE_HLO_SelectOp { def LHLO_ReverseOp: LHLO_Op<"reverse", []>, BASE_HLO_ReverseOp { let arguments = (ins LHLO_Buffer:$operand, - ElementsAttr:$dimensions, + I64ElementsAttr:$dimensions, LHLO_Buffer:$output ); } @@ -306,17 +308,17 @@ def LHLO_PadOp: LHLO_Op<"pad", []>, BASE_HLO_PadOp { let arguments = (ins LHLO_Buffer:$operand, LHLO_Buffer:$padding_value, - ElementsAttr: $edge_padding_low, - ElementsAttr: $edge_padding_high, - ElementsAttr: $interior_padding, - LHLO_Buffer: $output + I64ElementsAttr:$edge_padding_low, + I64ElementsAttr:$edge_padding_high, + I64ElementsAttr:$interior_padding, + LHLO_Buffer:$output ); } def LHLO_TransposeOp: LHLO_Op<"transpose", []>, BASE_HLO_TransposeOp { let arguments = (ins LHLO_Buffer:$operand, - ElementsAttr:$permutation, + I64ElementsAttr:$permutation, LHLO_Buffer:$output ); } diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index 230044d538b..c4008133b0c 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -27,42 +27,50 @@ limitations under the License. 
#include "mlir/IR/Function.h" // TF:local_config_mlir #include "mlir/IR/Location.h" // TF:local_config_mlir #include "mlir/IR/MLIRContext.h" // TF:local_config_mlir +#include "mlir/IR/Matchers.h" // TF:local_config_mlir #include "mlir/IR/Module.h" // TF:local_config_mlir #include "mlir/IR/Operation.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/type_to_shape.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/comparison_util.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -using tensorflow::int64; +using ::tensorflow::int16; +using ::tensorflow::int32; +using ::tensorflow::int64; +using ::tensorflow::int8; +using ::tensorflow::uint16; +using ::tensorflow::uint32; +using ::tensorflow::uint64; +using ::tensorflow::uint8; static std::vector ConvertDenseIntAttr(mlir::DenseIntElementsAttr attr) { auto values = attr.getValues(); return {values.begin(), values.end()}; } -// Converts the broadcast_dimensions attribute into a span of dimension numbers -// (empty if the attribute is absent). +// Converts the broadcast_dimensions attribute into a vector of dimension +// numbers (empty if the attribute is absent). static std::vector Convert_broadcast_dimensions( - llvm::Optional broadcast_dimensions) { + llvm::Optional broadcast_dimensions) { if (!broadcast_dimensions.hasValue()) return {}; - return ConvertDenseIntAttr( - broadcast_dimensions->cast()); + return ConvertDenseIntAttr(*broadcast_dimensions); } -// Converts the broadcast_sizes attribute into a span of dimension sizes. +// Converts the broadcast_sizes attribute into a vector of dimension sizes. 
static std::vector Convert_broadcast_sizes( - mlir::ElementsAttr broadcast_sizes) { - return ConvertDenseIntAttr( - broadcast_sizes.cast()); + mlir::DenseIntElementsAttr broadcast_sizes) { + return ConvertDenseIntAttr(broadcast_sizes); } -static std::vector Convert_permutation(mlir::ElementsAttr permutation) { - return ConvertDenseIntAttr(permutation.cast()); +static std::vector Convert_permutation( + mlir::DenseIntElementsAttr permutation) { + return ConvertDenseIntAttr(permutation); } // Converts the precision config array of strings attribute into the @@ -90,6 +98,45 @@ static std::unique_ptr Convert_precision_config( return precision_config; } +static xla::DotDimensionNumbers Convert_dot_dimension_numbers( + mlir::xla_hlo::DotDimensionNumbers dot_dimension_numbers_attr) { + xla::DotDimensionNumbers dot_dimension_numbers; + + auto rhs_contracting_dimensions = + dot_dimension_numbers_attr.rhs_contracting_dimensions() + .cast(); + auto lhs_contracting_dimensions = + dot_dimension_numbers_attr.lhs_contracting_dimensions() + .cast(); + auto rhs_batch_dimensions = + dot_dimension_numbers_attr.rhs_batching_dimensions() + .cast(); + auto lhs_batch_dimensions = + dot_dimension_numbers_attr.lhs_batching_dimensions() + .cast(); + + for (auto val : rhs_contracting_dimensions) { + dot_dimension_numbers.add_rhs_contracting_dimensions( + val.getLimitedValue(UINT64_MAX)); + } + for (auto val : lhs_contracting_dimensions) { + dot_dimension_numbers.add_lhs_contracting_dimensions( + val.getLimitedValue(UINT64_MAX)); + } + + for (auto val : rhs_batch_dimensions) { + dot_dimension_numbers.add_rhs_batch_dimensions( + val.getLimitedValue(UINT64_MAX)); + } + + for (auto val : lhs_batch_dimensions) { + dot_dimension_numbers.add_lhs_batch_dimensions( + val.getLimitedValue(UINT64_MAX)); + } + + return dot_dimension_numbers; +} + // Converts the comparison_direction string attribute into the XLA enum. The // string is assumed to correspond to exactly one of the allowed strings // representing the enum. This should have been checked in the op verify method. @@ -132,13 +179,47 @@ static double ConvertAPFloat(llvm::APFloat value) { namespace mlir { namespace { +StatusOr CreateLiteralFromAttr(Type type, ElementsAttr attr) { + xla::Shape shape = xla::TypeToShape(type); + +#define ELEMENTS_ATTR_TO_LITERAL(xla_type, cpp_type) \ + case xla_type: { \ + xla::Array source_data(shape.dimensions()); \ + source_data.SetValues(attr.getValues()); \ + return xla::LiteralUtil::CreateFromArray(source_data); \ + } + + switch (shape.element_type()) { + ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::PRED, bool) + ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::F32, float) + ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::F64, double) + ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S8, int8) + ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S16, int16) + ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S32, int32) + ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S64, int64) + // TODO(b/130356985): Update once MLIR supports unsigned integers. 
+ ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U8, uint8) + ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U16, uint16) + ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U32, uint32) + ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U64, uint64) + default: + return tensorflow::errors::Internal(absl::StrCat( + "Unsupported type: ", xla::PrimitiveType_Name(shape.element_type()))); + } +#undef ELEMENTS_ATTR_TO_LITERAL +} + class ConvertToHloModule { public: using ValueLoweringMap = llvm::DenseMap; using FunctionLoweringMap = llvm::DenseMap; - explicit ConvertToHloModule(mlir::ModuleOp module) - : module_(module), module_builder_("main") {} + explicit ConvertToHloModule(mlir::ModuleOp module, bool use_tuple_args, + bool always_return_tuple) + : module_(module), + module_builder_("main"), + use_tuple_args_(use_tuple_args), + always_return_tuple_(always_return_tuple) {} // Perform the lowering to XLA. This function returns failure if an error was // encountered. @@ -160,6 +241,9 @@ class ConvertToHloModule { } private: + LogicalResult Lower(mlir::Operation* inst, xla::XlaBuilder* builder, + ConvertToHloModule::ValueLoweringMap* value_lowering); + // The module being lowered. mlir::ModuleOp module_; @@ -168,26 +252,37 @@ class ConvertToHloModule { // Map between function and lowered computation. FunctionLoweringMap lowered_computation_; + + // Whether the entry function should take a single tuple as input. + bool use_tuple_args_; + + // Whether to always return a tuple. + bool always_return_tuple_; }; -LogicalResult Lower(mlir::Operation* inst, xla::XlaBuilder* builder, - ConvertToHloModule::FunctionLoweringMap* function_lowering, - ConvertToHloModule::ValueLoweringMap* value_lowering) { - if (auto xla_op = CreateXlaOperator(inst, value_lowering)) return success(); - - // TODO(riverriddle) We currently don't support lowering constant operations. - if (isa(inst)) { - inst->emitError("unable to lower 'xla_hlo.constant' operation"); - return failure(); - } +LogicalResult ConvertToHloModule::Lower( + mlir::Operation* inst, xla::XlaBuilder* builder, + ConvertToHloModule::ValueLoweringMap* value_lowering) { + if (succeeded(ExportXlaOperator(inst, value_lowering))) return success(); auto& value_map = *value_lowering; + ElementsAttr const_attr; + // TODO(jpienaar): This doesn't support layouts yet. + if (matchPattern(inst, m_Constant(&const_attr))) { + auto literal_or = + CreateLiteralFromAttr(*inst->result_type_begin(), const_attr); + if (!literal_or.ok()) return inst->emitError("unsupported elemental type"); + value_map[inst->getResult(0)] = + xla::ConstantLiteral(builder, literal_or.ValueOrDie()); + return success(); + } + if (auto ret = dyn_cast(inst)) { // Construct the return value for the function. If there are multiple // values returned, then create a tuple, else return value directly. 
xla::XlaOp return_value; unsigned num_return_values = ret.getNumOperands(); - if (num_return_values > 1) { + if (always_return_tuple_ || num_return_values > 1) { std::vector returns(num_return_values); for (unsigned i = 0, e = ret.getNumOperands(); i != e; ++i) { returns[i] = value_map[ret.getOperand(i)]; @@ -205,7 +300,7 @@ LogicalResult Lower(mlir::Operation* inst, xla::XlaBuilder* builder, return failure(); } auto f = inst->getParentOfType(); - (*function_lowering)[f] = std::move(computation_or.ValueOrDie()); + lowered_computation_[f] = std::move(computation_or.ValueOrDie()); return success(); } inst->emitError("unable to lower operation of type '" + @@ -228,28 +323,42 @@ LogicalResult ConvertToHloModule::RunOnFunction(mlir::FuncOp f) { // Mapping from the Value to lowered XlaOp. The code below lowers in // program order and will fail if an operand is unseen. This can be improved. ValueLoweringMap lowering; - for (auto& bb : f) { - int num = 0; - for (auto& arg : bb.getArguments()) { + auto& bb = f.front(); + + // If using tuples as input, then there is only one input + // parameter that is a tuple. + if (use_tuple_args_) { + std::vector arg_shapes; + arg_shapes.reserve(bb.getNumArguments()); + for (auto& arg : bb.getArguments()) + arg_shapes.push_back(xla::TypeToShape(arg->getType())); + xla::Shape input_shape = xla::ShapeUtil::MakeTupleShape(arg_shapes); + auto tuple = xla::Parameter(&builder, 0, input_shape, "arg_tuple"); + for (auto& it : llvm::enumerate(bb.getArguments())) { + lowering[it.value()] = xla::GetTupleElement(tuple, it.index()); + } + } else { + for (auto& it : llvm::enumerate(bb.getArguments())) { + auto* arg = it.value(); + auto num = it.index(); xla::Shape shape = xla::TypeToShape(arg->getType()); lowering[arg] = xla::Parameter(&builder, num, shape, absl::StrCat("Arg_", num)); - ++num; } - - for (auto& inst : bb) - if (failed(Lower(&inst, &builder, &lowered_computation_, &lowering))) - return failure(); } + for (auto& inst : bb) + if (failed(Lower(&inst, &builder, &lowering))) return failure(); + return success(); } } // namespace -Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto) { +Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto, + bool use_tuple_args, bool always_return_tuple) { mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext()); - ConvertToHloModule converter(module); + ConvertToHloModule converter(module, use_tuple_args, always_return_tuple); if (failed(converter.Run())) return diag_handler.ConsumeStatus(); auto hlo_module = converter.ConsumeMainProto(); hlo_proto->mutable_hlo_module()->Swap(&hlo_module); diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h index b16636f039c..24d20fe7017 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h @@ -23,8 +23,12 @@ limitations under the License. namespace mlir { -// Converts a MLIR module in HLO dialect into a HloModuleProto. -Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto); +// Converts a MLIR module in HLO dialect into a HloModuleProto. If +// use_tuple_args is set, then functions will have a single tuple as input. If +// always_return_tuple is set, then functions will return tuple whether or not +// there is only one result. 
+Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto, + bool use_tuple_args, bool always_return_tuple); // Creates XlaOp equivalent of a given MLIR operation using the operand info // from `value_lowering` map. diff --git a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc index 6aecf70b385..00b7cd06a1e 100644 --- a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc +++ b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc @@ -15,6 +15,7 @@ limitations under the License. #include +#include "llvm/ADT/Sequence.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" @@ -24,24 +25,19 @@ limitations under the License. #include "llvm/TableGen/Main.h" #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" +#include "mlir/Support/STLExtras.h" // TF:local_config_mlir #include "mlir/TableGen/Operator.h" // TF:local_config_mlir -using llvm::dyn_cast; -using llvm::LessRecord; using llvm::raw_ostream; -using llvm::Record; using llvm::RecordKeeper; using llvm::StringRef; +using mlir::interleaveComma; +using mlir::tblgen::NamedAttribute; +using mlir::tblgen::NamedTypeConstraint; using mlir::tblgen::Operator; -// Returns the builder function name for the given op definition. -// E.g., AddOp -> CreateAddOp -static inline std::string GetOperatorBuilderName(StringRef op_name) { - return "Create" + op_name.str(); -} - -static std::string GetConversionFunction( - mlir::tblgen::NamedAttribute named_attr) { +static std::string GetDefaultAttrExport( + const mlir::tblgen::NamedAttribute& named_attr) { auto storage_type = named_attr.attr.getStorageType(); // For some attribute types we have a general conversion, so use that. if (storage_type.endswith("IntegerAttr") || @@ -51,110 +47,50 @@ static std::string GetConversionFunction( return "Convert_" + named_attr.name.str(); } -using ArgumentName = std::string; -using ArgumentDeclaration = std::string; -using Argument = std::pair; -using ArgumentList = std::vector; - -static std::string BuildOperator(const Operator& op) { - std::stringstream os; - StringRef op_name = op.getCppClassName(); - std::string xla_op_name = op_name.drop_back(2).str(); - - // Signature. - os << "static xla::XlaOp " << GetOperatorBuilderName(op_name) - << "(mlir::xla_hlo::" << op_name.str() << " xla_op, " - << "llvm::DenseMap* " - "value_lowering) {\n"; - +static void BuildOperator(const Operator& op, raw_ostream* output) { + auto& os = *output; os << " auto& value_map = *value_lowering;\n" << " auto result = xla_op.getResult();\n"; - // Invoke the conversion function for each attribute. - for (const auto& named_attr : op.getAttributes()) { - os << " auto " << named_attr.name.str() << " = " - << GetConversionFunction(named_attr) << "(" - << "xla_op." << named_attr.name.str() << "());\n"; + // Build a conversion for each of the arguments. + int operand_number = 0; + for (int index : llvm::seq(0, op.getNumArgs())) { + auto arg = op.getArg(index); + + // Emit an argument for an operand. + if (auto* operand_cst = arg.dyn_cast()) { + // Handle a non-variadic operand. + if (!operand_cst->isVariadic()) { + os << "auto xla_arg_" << index << " = value_map[*xla_op.getODSOperands(" + << operand_number++ << ").begin()];\n"; + continue; + } + + // Otherwise, this is a varidiac operand list. 
+ os << " std::vector xla_arg_" << index << ";" + << " for (auto operand : xla_op.getODSOperands(" << operand_number++ + << "))\n xla_arg_" << index << ".push_back(value_map[operand]);\n"; + continue; + } + + // Otherwise, this is an attribute. + auto named_attr = arg.get(); + os << "auto xla_arg_" << index << " = " << GetDefaultAttrExport(*named_attr) + << "(xla_op." << op.getArgName(index) << "());\n"; } // Assumes that the client builder method names closely follow the op names // in the dialect. For e.g., AddOp -> xla::Add method. - os << " auto xla_result = xla::" << xla_op_name << "("; - - int num_operands = op.getNumOperands(); - if (num_operands == 1) { - os << "value_map[xla_op.getOperand()]"; - } else { - for (auto i = 0; i < num_operands; i++) { - os << "value_map[xla_op.getOperand(" << i << ")]"; - if (i != num_operands - 1) { - os << ", "; - } - } - } - - for (const auto& named_attr : op.getAttributes()) { - os << ", Unwrap(" << named_attr.name.str() << ")"; - } + StringRef op_name = op.getCppClassName(); + os << " auto xla_result = xla::" << op_name.drop_back(2) << "("; + // Emit each of the arguments. + interleaveComma(llvm::seq(0, op.getNumArgs()), os, + [&](int i) { os << "Unwrap(xla_arg_" << i << ')'; }); os << ");\n"; os << " value_map[result] = xla_result;\n"; - os << " return xla_result;\n"; - os << "}\n\n"; - return os.str(); -} - -// For each XLA op, emits a builder function that constructs the XLA op using -// the HLO client builder. -static void EmitOperatorBuilders(const RecordKeeper& record_keeper, - const std::vector& defs, - raw_ostream* ostream) { - raw_ostream& os = *ostream; - - for (const auto* def : defs) { - // Skip operations that have a custom converter. - if (def->getValueAsBit("hasCustomHLOConverter")) continue; - - Operator op(def); - os << BuildOperator(op); - } -} - -// Emits a builder function that returns the XlaOp object given a -// mlir::Operation. -// -// The signature of the function is: -// -// llvm::Optional -// mlir::CreateXlaOperator( -// mlir::Operation* op, -// llvm::DenseMap -// *value_lowering); -static void EmitBuilder(const std::vector& defs, - raw_ostream* ostream) { - raw_ostream& os = *ostream; - - // Signature - os << "llvm::Optional\n" - "mlir::CreateXlaOperator(mlir::Operation* op, " - "llvm::DenseMap " - "*value_lowering) {\n"; - - for (const auto* def : defs) { - // Skip operations that have a custom converter. - if (def->getValueAsBit("hasCustomHLOConverter")) continue; - - StringRef op_name = def->getName().drop_front(4); - - // Try to cast to each op and call the corresponding op builder. - os << " if (auto xla_op = llvm::dyn_cast(op))\n return " << GetOperatorBuilderName(op_name) - << "(xla_op, value_lowering);\n"; - } - - os << " return llvm::None;\n" - "}\n"; + os << " return mlir::success();\n"; } // The function below has a non-constant reference as that is required by LLVM's @@ -163,26 +99,27 @@ static void EmitBuilder(const std::vector& defs, static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) { emitSourceFileHeader("MLIR XLA Builders", os); - // Retrieve all the definitions derived from HLO_Op and sort by record name. - std::vector defs = records.getAllDerivedDefinitions("HLO_Op"); - llvm::sort(defs, LessRecord()); + // Emit a function to generate an XLA operation for the operations with + // auto-generated builders. 
+ os << "mlir::LogicalResult ExportXlaOperator(\n" + "mlir::Operation* op, llvm::DenseMap " + "*value_lowering) {\n"; - for (const auto* def : defs) { - // XLA ops in the .td file are expected to follow the naming convention: - // HLO_Op. - // The generated XLA op C++ class should be HLO::Op. - if (!def->getName().startswith("HLO_")) - PrintFatalError(def->getLoc(), - "unexpected op name format: 'HLO_' prefix missing"); - if (!def->getName().endswith("Op")) - PrintFatalError(def->getLoc(), - "unexpected op name format: 'Op' suffix missing"); + // Retrieve all the definitions derived from HLO_Op and sort by record name. + for (const auto* def : records.getAllDerivedDefinitions("HLO_Op")) { + // Skip operations that have a custom exporter. + if (def->getValueAsBit("hasCustomHLOConverter")) continue; + Operator op(def); + + // Cast to the current operation and build the exporter. + os << " if (auto xla_op = llvm::dyn_cast(op)) {\n"; + BuildOperator(op, &os); + os << "}\n"; } - EmitOperatorBuilders(records, defs, &os); - os << "\n\n"; - EmitBuilder(defs, &os); - + os << " return mlir::failure();\n" + "}\n"; return false; } diff --git a/tensorflow/compiler/mlir/xla/tests/concatenate.mlir b/tensorflow/compiler/mlir/xla/tests/concatenate.mlir new file mode 100644 index 00000000000..b0f3ceeb59e --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/concatenate.mlir @@ -0,0 +1,9 @@ +// RUN: tf-opt %s -canonicalize | FileCheck %s + +// CHECK-LABEL: func @single_operand +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @single_operand(%arg: tensor<1x2xf32>) -> tensor<1x2xf32> { + %0 = "xla_hlo.concatenate"(%arg) {dimension = 0 : i64} : (tensor<1x2xf32>) -> tensor<1x2xf32> + // CHECK-NEXT: return [[ARG]] + return %0 : tensor<1x2xf32> +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index 0328761becc..4af6726c584 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -141,6 +141,53 @@ func @broadcast_sub(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2x return %0: tensor<1x2xi32> } +//===----------------------------------------------------------------------===// +// Concat op legalizations. 
+//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @concat_v2 +func @concat_v2(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> { + // CHECK: "xla_hlo.concatenate"({{.*}}) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32> + %axis = "tf.Const"() { value = dense<0> : tensor } : () -> tensor + %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) {N = 2 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>, tensor) -> tensor<6x3xf32> + return %1 : tensor<6x3xf32> +} + +// CHECK-LABEL: func @concat_v2_neg_axis +func @concat_v2_neg_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> { + // CHECK: "xla_hlo.concatenate"({{.*}}) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32> + + %axis = "tf.Const"() { value = dense<-2> : tensor } : () -> tensor + %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) {N = 2 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>, tensor) -> tensor<6x3xf32> + return %1 : tensor<6x3xf32> +} + +// CHECK-LABEL: func @concat_v2_1d_axis +func @concat_v2_1d_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x6xf32> { + // CHECK: "xla_hlo.concatenate"({{.*}}) {dimension = 1 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x6xf32> + + %axis = "tf.Const"() { value = dense<[1]> : tensor<1xi64> } : () -> tensor<1xi64> + %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) {N = 2 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<1xi64>) -> tensor<3x6xf32> + return %1 : tensor<3x6xf32> +} + +// CHECK-LABEL: func @concat_v2_non_const_axis +func @concat_v2_non_const_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>, %axis: tensor) -> tensor<3x6xf32> { + // CHECK: "tf.ConcatV2" + + %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) {N = 2 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>, tensor) -> tensor<3x6xf32> + return %1 : tensor<3x6xf32> +} + +// CHECK-LABEL: func @concat_v2_unranked +func @concat_v2_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { + // CHECK: "tf.ConcatV2" + + %axis = "tf.Const"() { value = dense<0> : tensor } : () -> tensor + %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) {N = 2 : i64} : (tensor<*xf32>, tensor<*xf32>, tensor) -> tensor<*xf32> + return %1 : tensor<*xf32> +} + //===----------------------------------------------------------------------===// // Identity op legalizations. //===----------------------------------------------------------------------===// @@ -177,6 +224,37 @@ func @matmul_notranspose(%arg0: tensor<5x7xf32>, %arg1: tensor<7x11xf32>) -> ten return %0 : tensor<5x11xf32> } +//===----------------------------------------------------------------------===// +// MaxPool op legalizations. +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: maxpool_valid_padding +// CHECK-SAME: %[[ARG:.*]]: tensor +func @maxpool_valid_padding(%arg0: tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> { + // CHECK: %[[INIT:.*]] = constant dense<-2147483648> : tensor + // CHECK: "xla_hlo.reduce_window"(%[[ARG]], %[[INIT]]) + // CHECK: xla_hlo.max + // CHECK: xla_hlo.return + // CHECK: {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>} + + %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> + return %0 : tensor<2x3x5x7xi32> +} + +//===----------------------------------------------------------------------===// +// Pack op legalizations. 
+//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @pack +func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> { + // CHECK: "xla_hlo.reshape"({{.*}}) : (tensor<2xi32>) -> tensor<1x2xi32> + // CHECK: "xla_hlo.reshape"({{.*}}) : (tensor<2xi32>) -> tensor<1x2xi32> + // CHECK: "xla_hlo.concatenate"({{.*}}) {dimension = 0 : i64} : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32> + + %0 = "tf.Pack"(%arg0, %arg1) {N = 2 : i64} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32> + return %0 : tensor<2x2xi32> +} + //===----------------------------------------------------------------------===// // Relu op legalizations. //===----------------------------------------------------------------------===// @@ -198,6 +276,64 @@ func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> { return %0: tensor<1xi32> } +//===----------------------------------------------------------------------===// +// Softmax op legalizations. +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @simple_softmax +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3xf32>) +func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { + // CHECK: %[[NEG_INF:.*]] = constant dense<0xFF800000> : tensor + // CHECK: %[[ZERO:.*]] = constant dense<0.000000e+00> : tensor + + // Verify reduce op for max computation and its body. + // CHECK: %[[MAX:.*]] = "xla_hlo.reduce"(%[[ARG0]], %[[NEG_INF]]) + // CHECK: xla_hlo.max + // CHECK: "xla_hlo.return" + // CHECK: {dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf32>, tensor) -> tensor<2xf32> + + // CHECK: %[[SHIFTED_INP:.*]] = "xla_hlo.sub"(%[[ARG0]], %[[MAX]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: %[[EXP:.*]] = "xla_hlo.exp"(%[[SHIFTED_INP]]) + + // Verify reduce op for summation and its body. + // CHECK: %[[SUM:.*]] = "xla_hlo.reduce"(%[[EXP]], %[[ZERO]]) + // CHECK: xla_hlo.add + // CHECK: "xla_hlo.return" + // CHECK: {dimensions = dense<1> : tensor<1xi64>} + + // CHECK: %[[RESULT:.*]] = "xla_hlo.div"(%[[EXP]], %[[SUM]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // return %[[RESULT]] + + %0 = "tf.Softmax"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> + return %0: tensor<2x3xf32> +} + +// CHECK-LABEL: bf16_softmax +func @bf16_softmax(%arg0: tensor<2x3xbf16>) -> tensor<2x3xbf16> { + // Verify that conversion to f32 and then back to bf16 are introduced. + + // CHECK: "xla_hlo.convert"({{.*}}) : (tensor<2x3xbf16>) -> tensor<2x3xf32> + // CHECK: "xla_hlo.convert"({{.*}}) : (tensor<2xf32>) -> tensor<2xbf16> + + %0 = "tf.Softmax"(%arg0) : (tensor<2x3xbf16>) -> tensor<2x3xbf16> + return %0: tensor<2x3xbf16> +} + +// CHECK-LABEL: rank4_softmax +func @rank4_softmax(%arg0: tensor<2x3x4x5xf16>) -> tensor<2x3x4x5xf16> { + // Verify that reduce op dimensions and broadcast dimensions are correct. + + // CHECK: "xla_hlo.reduce" + // CHECK: dimensions = dense<3> + + // CHECK: "xla_hlo.reduce" + // CHECK: dimensions = dense<3> + + // CHECK: "xla_hlo.div"{{.*}} {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} + %0 = "tf.Softmax"(%arg0) : (tensor<2x3x4x5xf16>) -> tensor<2x3x4x5xf16> + return %0: tensor<2x3x4x5xf16> +} + //===----------------------------------------------------------------------===// // Unary op legalizations. 
//===----------------------------------------------------------------------===// @@ -229,3 +365,10 @@ func @squeeze_dynamic(%arg0: tensor) -> tensor<*xf32> { %0 = "tf.Squeeze"(%arg0) : (tensor) -> tensor<*xf32> return %0 : tensor<*xf32> } + +// CHECK-LABEL: expand_dims +func @expand_dims(%arg0: tensor<2xf32>, %axis: tensor) -> tensor<1x2xf32> { + // CHECK: "xla_hlo.reshape"{{.*}} : (tensor<2xf32>) -> tensor<1x2xf32> + %0 = "tf.ExpandDims"(%arg0, %axis) : (tensor<2xf32>, tensor) -> tensor<1x2xf32> + return %0 : tensor<1x2xf32> +} diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir index 4aee05a146c..146dc0a4b45 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir @@ -136,9 +136,8 @@ func @reduce_computation(%sum: memref<1xf32>, %element: memref<1xf32>) -> () { } // CHECK-LABEL: func @reduce_memref -func @reduce_memref(%input: memref<10xf32>, %out: memref<1xf32>) -> () { - "xla_lhlo.reduce"(%input, %out) {computation = @reduce_computation, - dimensions = dense<[0]> : tensor<1xi64>} : (memref<10xf32>, memref<1xf32>) -> () +func @reduce_memref(%input: memref<10xf32>, %init: memref, %out: memref<1xf32>) -> () { + "xla_lhlo.reduce"(%input, %init, %out) {computation = @reduce_computation, dimensions = dense<[0]> : tensor<1xi64>} : (memref<10xf32>, memref, memref<1xf32>) -> () return } @@ -156,4 +155,4 @@ func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: m "xla_lhlo.terminator"() : () -> () } ) : () -> () return -} \ No newline at end of file +} diff --git a/tensorflow/compiler/mlir/xla/tests/ops.mlir b/tensorflow/compiler/mlir/xla/tests/ops.mlir index 06c98fb39b0..4520f7615ca 100644 --- a/tensorflow/compiler/mlir/xla/tests/ops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/ops.mlir @@ -50,30 +50,6 @@ func @broadcast(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { // ----- -func @broadcast_nonint_sizes(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { - // expected-error@+1 {{broadcast_sizes must be a DenseIntElementsAttr}} - %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[1.0, 2.0]> : tensor<2xf64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> - return %0 : tensor<1x2x3xi32> -} - -// ----- - -func @broadcast_splat_sizes(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { - // expected-error@+1 {{broadcast_sizes must be a DenseIntElementsAttr}} - %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<2.0> : tensor<2xf64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> - return %0 : tensor<1x2x3xi32> -} - -// ----- - -func @broadcast_sparse_sizes(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { - // expected-error@+1 {{broadcast_sizes must be a DenseIntElementsAttr}} - %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = sparse<[[0, 0], [1, 2]], [1, 5]> : tensor<3x4xi32>} : (tensor<3xi32>) -> tensor<1x2x3xi32> - return %0 : tensor<1x2x3xi32> -} - -// ----- - func @broadcast_bad_sizes_rank(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{broadcast_sizes has rank 2 instead of rank 1}} %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[[1, 2]]> : tensor<1x2xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> @@ -122,30 +98,6 @@ func @broadcast_in_dim_zero_rank(%arg0: tensor) -> tensor<1x2x3xi32> { // ----- -func @broadcast_in_dim_bad_nonint_dimensions(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> { - // expected-error@+1 {{broadcast_sizes must be a DenseIntElementsAttr}} - %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1.0, 2.0]> : 
tensor<2xf64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> - return %0 : tensor<1x2x3xi32> -} - -// ----- - -func @broadcast_in_dim_bad_splat_dimensions(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> { - // expected-error@+1 {{broadcast_sizes must be a DenseIntElementsAttr}} - %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<2.0> : tensor<2xf64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> - return %0 : tensor<1x2x3xi32> -} - -// ----- - -func @broadcast_in_dim_bad_sparse_dimensions(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> { - // expected-error@+1 {{broadcast_sizes must be a DenseIntElementsAttr}} - %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = sparse<[[0, 0], [1, 2]], [1, 5]> : tensor<3x4xi32>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> - return %0 : tensor<1x2x3xi32> -} - -// ----- - func @broadcast_in_dim_bad_dimension_rank(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{broadcast_dimensions has rank 2 instead of rank 1}} %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[[1,1],[1,1]]> : tensor<2x2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> @@ -428,15 +380,15 @@ func @select_bad_pred_shape(%arg0: tensor<3xi1>, %arg1: tensor<2x3xi32>, %arg2: // CHECK-LABEL: func @slice func @slice(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> { - %0 = "xla_hlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32> + %0 = "xla_hlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32> return %0 : tensor<1x4xi32> } // ----- func @slice_indices_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> { - // expected-error@+1 {{failed to verify that all of {start_indices, limit_indices} have same type}} - %0 = "xla_hlo.slice"(%arg0) {start_indices = dense<[1, 2, 3]> : tensor<3xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32> + // expected-error@+1 {{failed to verify that all of {start_indices, limit_indices, strides} have same type}} + %0 = "xla_hlo.slice"(%arg0) {start_indices = dense<[1, 2, 3]> : tensor<3xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32> return %0 : tensor<1x4xi32> } @@ -444,7 +396,7 @@ func @slice_indices_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> { func @slice_operand_result_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xf32> { // expected-error@+1 {{requires the same element type for all operands and results}} - %0 = "xla_hlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xf32> + %0 = "xla_hlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xf32> return %0 : tensor<1x4xf32> } @@ -458,30 +410,6 @@ func @transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { // ----- -func @transpose_bad_permutations_float(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { - // expected-error@+1 {{permutation must be a DenseIntElementsAttr}} - %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1.0, 0.0, 3.0, 2.0]> : tensor<4xf64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> - return %0: tensor<2x1x4x3xi32> -} - -// ----- - -func 
@transpose_bad_permutations_splat(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { - // expected-error@+1 {{permutation must be a DenseIntElementsAttr}} - %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<2.0> : tensor<2xf64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> - return %0: tensor<2x1x4x3xi32> -} - -// ----- - -func @transpose_bad_permutations_sparse(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { - // expected-error@+1 {{permutation must be a DenseIntElementsAttr}} - %0 = "xla_hlo.transpose"(%arg0) {permutation = sparse<[[0, 0], [1, 2]], [1, 5]> : tensor<3x4xi32>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> - return %0: tensor<2x1x4x3xi32> -} - -// ----- - func @transpose_bad_permutations_rank(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { // expected-error@+1 {{permutation has rank 2 instead of rank 1}} %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[[1]]> : tensor<1x1xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> @@ -519,3 +447,55 @@ func @tuple(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) -> tuple %0 = "xla_hlo.tuple"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xf32>) -> tuple, tensor<1x2xf32>> return %0: tuple, tensor<1x2xf32>> } + +// ----- + +func @tuple_arg_size_mismatch(%arg0: tensor, %arg1: tensor) -> tuple, tensor, tensor> { + // expected-error@+1 {{has return type tuple, tensor, tensor>, but expected tuple, tensor>}} + %0 = "xla_hlo.tuple"(%arg0, %arg1) : (tensor, tensor) -> tuple, tensor, tensor> + return %0 : tuple, tensor, tensor> +} + +// ----- + +func @tuple_type_mismatch(%arg0: tensor, %arg1: tensor) -> tuple, tensor> { + // expected-error@+1 {{has return type tuple, tensor>, but expected tuple, tensor>}} + %0 = "xla_hlo.tuple"(%arg0, %arg1) : (tensor, tensor) -> tuple, tensor> + return %0 : tuple, tensor> +} + +// ----- + +func @get_tuple_element(%arg0: tuple, tensor>) -> tensor { + %0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple, tensor>) -> tensor + return %0 : tensor +} + +// ----- + +func @get_tuple_element_bad_type(%arg0: tuple, tensor>) -> tensor { + // expected-error@+1 {{has return type tensor, but expected tensor}} + %0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple, tensor>) -> tensor + return %0 : tensor +} + +// ----- + +func @get_tuple_element_index_out_of_bounds(%arg0: tuple, tensor>) -> tensor { + // expected-error@+1 {{index 2 is out of bounds of operand with size 2}} + %0 = "xla_hlo.get_tuple_element"(%arg0) {index = 2 : i32} : (tuple, tensor>) -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @reduce_window +func @reduce_window(%arg0: tensor<4x4xi32>) -> tensor<2x2xi32> { + %cst = constant dense<0> : tensor + %0 = "xla_hlo.reduce_window"(%arg0, %cst) ( { + ^bb0(%arg1: tensor, %arg2: tensor): // no predecessors + %6 = "xla_hlo.max"(%arg1, %arg2) : (tensor, tensor) -> tensor + "xla_hlo.return"(%6) : (tensor) -> () + }) {window_dimensions = dense<[2, 2]> : tensor<2xi64>, window_strides = dense<[2, 2]> : tensor<2xi64>, padding = dense<[2, 2]> : tensor<2xi64>} : (tensor<4x4xi32>, tensor) -> tensor<2x2xi32> + return %0 : tensor<2x2xi32> +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/add.mlir b/tensorflow/compiler/mlir/xla/tests/translate/add.mlir index a77b90ca083..a457ba59e22 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/add.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/add.mlir @@ -1,6 +1,12 @@ // RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s +// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text 
-emit-use-tuple-args %s | FileCheck %s --check-prefix=TUPLE-ARG +// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text -emit-always-return-tuple %s | FileCheck %s --check-prefix=TUPLE-RET +// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text -emit-use-tuple-args -emit-always-return-tuple %s | FileCheck %s --check-prefix=TUPLES -// CHECK-LABEL: ENTRY %main.5 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] { +// CHECK-LABEL: ENTRY %main.{{.*}} (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] +// TUPLE-ARG-LABEL: ENTRY %main.{{.*}} (arg_tuple.1: (f32[4], f32[4])) -> f32[4] +// TUPLE-RET-LABEL: ENTRY %main.{{.*}} (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> (f32[4]) +// TUPLES-LABEL: ENTRY %main.{{.*}} (arg_tuple.1: (f32[4], f32[4])) -> (f32[4]) func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %Arg_0.1 = f32[4] parameter(0) // CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1) diff --git a/tensorflow/compiler/mlir/xla/tests/translate/const.mlir b/tensorflow/compiler/mlir/xla/tests/translate/const.mlir new file mode 100644 index 00000000000..42d9c5dc963 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/const.mlir @@ -0,0 +1,30 @@ +// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s --dump-input-on-failure + +// CHECK-LABEL: ENTRY %main +func @main() -> tensor<2x2x1x1xf32> { + // CHECK: constant.{{.*}} = s64[] constant(1) + %cst = constant dense<1> : tensor + // CHECK: constant.{{.*}} = f32[2,2,1,1] + // CHECK-SAME: { { /*i0=0*/ { /*i1=0*/ {1} }, { /*i1=1*/ {2} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {4} } } } + %cst_0 = constant dense< + [[[[1.000000e+00]], [[2.000000e+00]]], [[[3.000000e+00]], [[4.000000e+00]]]] + > : tensor<2x2x1x1xf32> + + // CHECK: s32[1] constant({1}) + %cst_1 = constant dense<1> : tensor<1xi32> + + // CHECK: %[[C:.*]] = s32[] constant(1) + // CHECK: s32[10] broadcast(s32[] %[[C]]) + %cst_2 = constant dense<1> : tensor<10xi32> + + // CHECK: s32[4] constant({1, 2, 3, 4}) + %cst_3 = constant dense<[1, 2, 3, 4]> : tensor<4xi32> + + // CHECK: s32[2,2] constant({ { 1, 2 }, { 3, 4 } }) + %cst_4 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> + + // CHECK: s32[2,2] constant({ { 3, 2 }, { 1, 4 } }) + %cst_5 = constant dense<[[3, 2], [1, 4]]> : tensor<2x2xi32> + + return %cst_0 : tensor<2x2x1x1xf32> +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/dot_general.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/dot_general.hlotxt new file mode 100644 index 00000000000..25efcfd3e73 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/dot_general.hlotxt @@ -0,0 +1,25 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule main + +// CHECK-LABEL: @main +// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]] +// CHECK-SAME: [[ARG1:%[a-zA-Z0-9]+]] +ENTRY %main (Arg_0.1: f32[4, 1], Arg_1.2: f32[1, 4]) -> f32[] { + %Arg_0.1 = f32[4, 1] parameter(0) + %Arg_1.2 = f32[1, 4] parameter(1) + + // CHECK-NEXT: [[R0:%.+]] = "xla_hlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "dot.3", precision_config = ["HIGH", "HIGHEST"]} + dot.3 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1}, operand_precision={high,highest} + + // CHECK-NEXT: [[R1:%.+]] = "xla_hlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> 
: tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "dot.4", precision_config = ["HIGHEST", "DEFAULT"]} + dot.4 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1}, operand_precision={highest,default} + + // CHECK-NEXT: [[R2:%.+]] = "xla_hlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "dot.5", precision_config = ["DEFAULT", "DEFAULT"]} + %dot.5 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1}, operand_precision={default,default} + + // TODO(b/129709049) consider making this default precision config inferred. + // CHECK-NEXT: [[R3:%.+]] = "xla_hlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "dot.6", precision_config = ["DEFAULT", "DEFAULT"]} + // CHECK-NEXT: return [[R3]] + ROOT %dot.6 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1} +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/xla/tests/translate/multiple_return_tuple.mlir b/tensorflow/compiler/mlir/xla/tests/translate/multiple_return_tuple.mlir new file mode 100644 index 00000000000..87817519870 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/multiple_return_tuple.mlir @@ -0,0 +1,14 @@ +// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s +// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text -emit-use-tuple-args -emit-always-return-tuple %s | FileCheck %s --check-prefix=TUPLE + +// Test to verify that multiple result function with always emit return tuple +// does not result in nested tuples. + +// CHECK-LABEL: ENTRY %main.{{.*}} (Arg_0.1: s32[4]) -> (s32[4], s32[1,2,3,4]) +// TUPLE-LABEL: ENTRY %main.{{.*}} (arg_tuple.1: (s32[4])) -> (s32[4], s32[1,2,3,4]) +func @main(%arg0: tensor<4xi32>) -> (tensor<4xi32>, tensor<1x2x3x4xi32>) { + // CHECK-NEXT: %Arg_0.1 = s32[4] parameter(0) + // CHECK-NEXT: %broadcast.2 = s32[1,2,3,4] broadcast(s32[4] %Arg_0.1), dimensions={3} + %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>} : (tensor<4xi32>) -> tensor<1x2x3x4xi32> + return %arg0, %0 : tensor<4xi32>, tensor<1x2x3x4xi32> +} diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index 830bcde5f04..fabdde69cf6 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -15,13 +15,19 @@ limitations under the License. // This file implements logic for lowering TensorFlow dialect to XLA dialect. 
+#include #include #include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir +#include "mlir/IR/Module.h" // TF:local_config_mlir #include "mlir/IR/Operation.h" // TF:local_config_mlir #include "mlir/IR/PatternMatch.h" // TF:local_config_mlir +#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir #include "mlir/Pass/Pass.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" @@ -47,6 +53,53 @@ static size_t getFeatureDimension(StringAttr format, return isDefaultDataFormat(format.getValue()) ? inputType.getRank() - 1 : 1; } +static IntegerAttr GetHLOAxisFromTFAxis(ElementsAttr attr, int64_t rank, + Builder *b) { + SmallVector index(attr.getType().getRank(), 0); + int64_t axis = attr.getValue(index).getInt(); + if (axis < 0) { + axis += rank; + } + return b->getI64IntegerAttr(axis); +} + +// Returns minimum value for the given int or float element type. +static ConstantOp GetMinValueForType(Type ty, Location loc, + PatternRewriter *rewriter) { + RankedTensorType scalar_ty = rewriter->getTensorType({}, ty); + + DenseElementsAttr attr; + if (auto float_ty = ty.dyn_cast_or_null()) { + APFloat neg_inf = + APFloat::getInf(float_ty.getFloatSemantics(), /*negative=*/true); + attr = DenseElementsAttr::get(scalar_ty, neg_inf); + } else { + auto int_ty = ty.cast(); + APInt min_val = APInt::getSignedMinValue(int_ty.getWidth()); + attr = DenseElementsAttr::get(scalar_ty, min_val); + } + return rewriter->create(loc, attr); +} + +// Builds body for reduce op by using the using the template binary op as the +// reducer op. +template +static void BuildReduceBody(Type element_type, Region *body, + OpBuilder *builder) { + OpBuilder::InsertionGuard guard(*builder); + Block *block = builder->createBlock(body); + + // Block arguments are scalars of the given element type. + Type type = builder->getTensorType(/*shape=*/{}, element_type); + block->addArguments({type, type}); + + Location loc = body->getLoc(); + auto reducer = builder->create(loc, type, block->getArgument(0), + block->getArgument(1), + /*broadcast_dimensions=*/nullptr); + builder->create(loc, reducer.getResult()); +} + //===----------------------------------------------------------------------===// // BatchNorm op utilities. //===----------------------------------------------------------------------===// @@ -72,12 +125,15 @@ static bool hasValidBiasFeatureDimension(StringAttr format, Value *input, return biasType.getDimSize(0) == inputType.getDimSize(featureDim); } -/// Return a 1D ElementsAttr for the feature dimension of a BiasAdd. -static ElementsAttr getBiasFeatureDimension(Builder &b, StringAttr format, - Value *input) { - return b.getDenseIntElementsAttr( - b.getTensorType(1, b.getIntegerType(64)), - getFeatureDimension(format, input->getType().cast())); +/// Return a 1D DenseIntElementsAttr for the feature dimension of a BiasAdd. 
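A quick worked example of the GetHLOAxisFromTFAxis helper defined above, which normalizes TensorFlow's negative-axis convention before it reaches HLO (this is what the concat_v2_neg_axis test earlier in this patch exercises; the values below are only illustrative):

    int64_t axis = -2;           // value of the tf.Const axis operand
    int64_t rank = 2;            // rank of the first (ranked) input
    if (axis < 0) axis += rank;  // -2 + 2 == 0  ->  {dimension = 0 : i64}
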
+static DenseIntElementsAttr getBiasFeatureDimension(Builder &b, + StringAttr format, + Value *input) { + auto inputType = input->getType().cast(); + size_t featureDim = getFeatureDimension(format, inputType); + RankedTensorType type = b.getTensorType(1, b.getIntegerType(64)); + return DenseIntElementsAttr::get(type, featureDim) + .cast(); } //===----------------------------------------------------------------------===// @@ -101,7 +157,8 @@ static ElementsAttr getSplat(Builder &b, Value *val, T constant) { return DenseElementsAttr::get(valType, elementAttr); } -static ElementsAttr getBroadcastDimensionsAttr(Builder &b, Value *x, Value *y) { +static DenseIntElementsAttr getBroadcastDimensionsAttr(Builder &b, Value *x, + Value *y) { TensorType xType = x->getType().dyn_cast(); TensorType yType = y->getType().dyn_cast(); if (xType == yType || !xType || !yType) return {}; @@ -126,23 +183,208 @@ static ElementsAttr getBroadcastDimensionsAttr(Builder &b, Value *x, Value *y) { std::iota(broadcastDimensions.begin(), broadcastDimensions.end(), maxRank - minRank); - return b.getDenseIntElementsAttr( - b.getTensorType({minRank}, b.getIntegerType(64)), broadcastDimensions); + RankedTensorType type = b.getTensorType({minRank}, b.getIntegerType(64)); + return DenseIntElementsAttr::get(type, broadcastDimensions) + .cast(); } +//===----------------------------------------------------------------------===// +// Softmax op utilities. +//===----------------------------------------------------------------------===// + +// Returns a 1-d i64 elements attribute populated with numbers from start to +// end, excluding. +static DenseIntElementsAttr GetI64ElementsAttrForSeq(int start, int end, + Builder *builder) { + int size = end - start; + + SmallVector vals; + vals.resize(size); + std::iota(vals.begin(), vals.end(), start); + + TensorType ty = builder->getTensorType({size}, builder->getIntegerType(64)); + return DenseIntElementsAttr::get(ty, vals) + .cast(); +} + +// Returns the type to use for accumulating the given type. +static Type GetAccumulationType(Type ty) { + // Upcast 16 bit sum reductions to 32 bit to reduce the precision loss from + // repeated floating point additions. + return (ty.isF16() || ty.isBF16()) ? FloatType::getF32(ty.getContext()) : ty; +} + +//===----------------------------------------------------------------------===// +// Op converters. +//===----------------------------------------------------------------------===// + namespace mlir { namespace xla { namespace { + +// Converts MaxPool op to HLO ReduceWindow op by setting appropriate window +// dimensions with max as the reduction function. +// +// Sample result for VALID padding mode: +// +// %init = constant dense<...> : tensor +// %max_pool = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.max"] +// {window_dimensions = ..., window_strides = ... } +// +class ConvertMaxPoolOp : public OpRewritePattern { + public: + explicit ConvertMaxPoolOp(MLIRContext *context) + : OpRewritePattern(context, 1) {} + + PatternMatchResult matchAndRewrite(TF::MaxPoolOp op, + PatternRewriter &rewriter) const override { + // TODO(hinsu): Support 'SAME' padding mode. 
+ if (op.padding() != "VALID") return matchFailure(); + + Type element_type = + op.input()->getType().cast().getElementType(); + if (!element_type.isIntOrFloat()) return matchFailure(); + Location loc = op.getLoc(); + ConstantOp init = GetMinValueForType(element_type, loc, &rewriter); + + auto get_elements_attr = [&](ArrayAttr attr) { + RankedTensorType ty = rewriter.getTensorType( + static_cast(attr.size()), rewriter.getIntegerType(64)); + return DenseElementsAttr::get(ty, attr.getValue()) + .cast(); + }; + + auto reduce = rewriter.create( + loc, op.getType(), op.input(), init.getResult(), + get_elements_attr(op.ksize()), get_elements_attr(op.strides()), + /*base_dilations=*/DenseIntElementsAttr(), + /*window_dilations=*/DenseIntElementsAttr(), + /*paddings=*/DenseIntElementsAttr()); + BuildReduceBody(element_type, &reduce.body(), &rewriter); + + rewriter.replaceOp(op, reduce.getResult(0)); + return matchSuccess(); + } +}; + +// Converts Softmax op to HLO ops computing softmax with the following formula: +// +// softmax = div(exp(logits), sum(exp(logits))) +// +// Sample result with 2-d f16 inputs with B batches of with N elements each. +// +// // Subtract each element by their batches' max to improve numerical +// // stability. +// %neg_infinity = constant dense<0xFF800000> : tensor +// %max = "xla_hlo.reduce"(%input, %neg_infinity) ["xla_hlo.max"] +// {dimensions = 1} +// : (tensor, tensor<1xf16>) -> tensor +// %sub = "xla_hlo.sub"(%inp, %max) {broadcast_dimensions = 0} +// : (tensor, tensor) -> tensor +// +// %exp = "xla_hlo.exp"(%sub) : (tensor) -> tensor +// +// // Cast to f32 to avoid precision loss in summation. +// %exp_f32 = "xla_hlo.convert"(%exp) : (tensor) -> tensor +// %zero = constant dense<0.000000e+00> : tensor +// %sum = "xla_hlo.reduce"(%exp, %zero) ["xla_hlo.add"] {dimensions = 1} +// : (tensor, tensor<1xf32>) -> tensor +// +// %sum_f16 = "xla_hlo.convert"(%sum) : (tensor) -> tensor +// %softmax = "xla_hlo.div"(%exp, %sum_f16) {broadcast_dimensions = 0} +// : (tensor, tensor) -> tensor +// +class ConvertSoftmaxOp : public OpRewritePattern { + public: + explicit ConvertSoftmaxOp(MLIRContext *context) + : OpRewritePattern(context, 1) {} + + PatternMatchResult matchAndRewrite(TF::SoftmaxOp op, + PatternRewriter &rewriter) const override { + Value *logits = op.logits(); + + // Softmax converter requires ranked type because the XLA reduce ops used + // while lowering requires dimensions attribute to reduce along. + RankedTensorType type = logits->getType().dyn_cast(); + if (!type) return matchFailure(); + int rank = type.getRank(); + + // Note that the TensorFlow Softmax op verifies that the input rank is + // greater than or equal to one so both of the following sequences are + // valid. + auto batch_dims = GetI64ElementsAttrForSeq(0, rank - 1, &rewriter); + auto reduce_dim = GetI64ElementsAttrForSeq(rank - 1, rank, &rewriter); + Location loc = op.getLoc(); + + // Exponential of input values and then their sum can be very large here. + // Division with large denominator is numerically unstable. To improve + // numerical stability, subtract each batch with their max element so that + // the maximum input value is zero. It can be shown that softmax computed + // after adding or subtracting all inputs in a batch using a common value + // gives mathematically equivalent result. 
+ Type element_type = type.getElementType(); + ArrayRef reduce_shape = type.getShape().drop_back(); + RankedTensorType reduce_out_type = + rewriter.getTensorType(reduce_shape, element_type); + auto init = GetMinValueForType(element_type, loc, &rewriter); + auto max_logits = rewriter.create( + loc, reduce_out_type, logits, init.getResult(), reduce_dim); + BuildReduceBody(element_type, &max_logits.body(), + &rewriter); + auto shifted_logits = rewriter.create( + loc, type, logits, max_logits.getResult(0), batch_dims); + + // Exponentiate the inputs. + Value *exp = rewriter.create(loc, type, shifted_logits); + + // Cast the exponentials to the appropriate accumulation type to avoid + // precision loss during summation. + Type sum_element_type = GetAccumulationType(element_type); + Type sum_type = rewriter.getTensorType(type.getShape(), sum_element_type); + auto casted_exp = rewriter.create(loc, sum_type, exp); + + // Compute summation of the exponentials. + init = rewriter.create( + loc, DenseElementsAttr::get(rewriter.getTensorType({}, element_type), + rewriter.getZeroAttr(element_type))); + Type sum_out_type = rewriter.getTensorType(reduce_shape, sum_element_type); + auto exp_sum = rewriter.create( + loc, sum_out_type, casted_exp.getResult(), init.getResult(), + reduce_dim); + BuildReduceBody(element_type, &exp_sum.body(), &rewriter); + Value *sum = exp_sum.getResult(0); + + // Convert the summation result back to the original element type and divide + // exponentials by the summations. + sum = rewriter.create(loc, reduce_out_type, sum); + rewriter.replaceOpWithNewOp(op, op.getType(), exp, sum, + batch_dims); + return matchSuccess(); + } +}; + #include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc" } // end anonymous namespace } // end namespace xla } // end namespace mlir void mlir::xla_hlo::legalizeTF(Operation *op) { - // Add the generated patterns to the list. + // Add lowering patterns to the list. OwningRewritePatternList patterns; xla::populateWithGenerated(op->getContext(), &patterns); + // Add patterns that lower some of the high level TensorFlow ops to lower + // level TensorFlow ops. So, we don't have to target all the TensorFlow ops + // here for lowering to HLO. + // + // TODO(b/140964075): Switch to DialectConversion to avoid premature lowering + // to lower level TensorFlow ops if we actually want to target the higher + // level TensorFlow op directly. + mlir::TF::PopulateLoweringTFPatterns(op->getContext(), &patterns); + + patterns.insert(op->getContext()); + patterns.insert(op->getContext()); + // Recursively applies rewrite patterns to nested operations. applyPatternsGreedily(op, patterns); } diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index d67f7b0c5fd..fe930e6095d 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -21,7 +21,7 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td" def NullArrayAttr : NativeCodeCall<"ArrayAttr()">; -def NullElementsAttr : NativeCodeCall<"ElementsAttr()">; +def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">; //===----------------------------------------------------------------------===// // BatchNorm op patterns. 
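The ConvertSoftmaxOp pattern above implements softmax(x) = exp(x - max(x)) / sum(exp(x - max(x))), with the summation accumulated in f32 for f16/bf16 inputs. A scalar C++ sketch of the same max-shifting trick, purely to illustrate the numerics (not part of the lowering itself):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Numerically stable softmax over a single batch row.
    std::vector<float> StableSoftmax(const std::vector<float>& logits) {
      float max_logit = *std::max_element(logits.begin(), logits.end());
      std::vector<float> result(logits.size());
      float sum = 0.0f;  // the lowering accumulates this in f32 for 16-bit inputs
      for (size_t i = 0; i < logits.size(); ++i) {
        // Shifting by the row max keeps exp() from overflowing.
        result[i] = std::exp(logits[i] - max_logit);
        sum += result[i];
      }
      for (float& value : result) value /= sum;
      return result;
    }
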
@@ -87,6 +87,33 @@ foreach fromToBinPair = [[TF_AddOp, HLO_AddOp], [TF_SubOp, HLO_SubOp]] in def : DirectBinaryPat; +//===----------------------------------------------------------------------===// +// Concat op patterns. +//===----------------------------------------------------------------------===// + +def OneElementAttrPred + : CPred<"$_self.cast().getType().getNumElements() == 1">; + +def OneElementAttr + : ElementsAttrBase, + "Scalar ElementsAttr">; + +def GetHLOAxisFromTFAxis : NativeCodeCall< + "GetHLOAxisFromTFAxis(" + "$0, (*$1.begin())->getType().cast().getRank(), " + "&$_builder)">; + +def HasRankedFirstOperand + : ConstraintgetType().isa()">>; + +// Here, we convert from TensorFlow axis format to HLO axis format which +// doesn't wrap around like TensorFlow and is always positive. For this +// conversion, use the first input to get inputs rank. Other inputs need not be +// ranked. +def : Pat<(TF_ConcatV2Op $inputs, (TF_ConstOp OneElementAttr:$axis), $unused), + (HLO_ConcatenateOp $inputs, (GetHLOAxisFromTFAxis $axis, $inputs)), + [(HasRankedFirstOperand $inputs)]>; + //===----------------------------------------------------------------------===// // Identity op patterns. //===----------------------------------------------------------------------===// @@ -118,7 +145,7 @@ class ConstantSplat : NativeCodeCall< def : Pat<(TF_ReluOp AnyTensor:$input), (HLO_MaxOp (ConstantOp (ConstantSplat<"0"> $input)), $input, - (NullElementsAttr))>; + (NullDenseIntElementsAttr))>; def : Pat<(TF_Relu6Op AnyTensor:$input), (HLO_ClampOp (ConstantOp (ConstantSplat<"0"> $input)), $input, @@ -128,8 +155,7 @@ def : Pat<(TF_Relu6Op AnyTensor:$input), // Unary op patterns. //===----------------------------------------------------------------------===// -def : Pat<(TF_ReshapeOp:$res AnyStaticShapeTensor:$arg, $ignored), - (HLO_ReshapeOp $arg), [(AnyStaticShapeTensor $res)]>; - -def : Pat<(TF_SqueezeOp AnyStaticShapeTensor:$arg, $ignored_dims), - (HLO_ReshapeOp $arg)>; +foreach TfOp = [TF_ExpandDimsOp, TF_ReshapeOp, TF_SqueezeOp] in { + def : Pat<(TfOp:$res AnyStaticShapeTensor:$arg, $ignored), + (HLO_ReshapeOp $arg), [(AnyStaticShapeTensor $res)]>; +} diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc index 780e6380398..03f55f1a1cf 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc @@ -111,26 +111,23 @@ Value* GetBinaryOp(Type element_type, Location loc, Value* lhs, } template -struct BinaryOpConverter : public RewritePattern { - explicit BinaryOpConverter(MLIRContext* context) - : RewritePattern(LhloOp::getOperationName(), {}, 1, context) {} +struct BinaryOpConverter : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(Operation* op, + PatternMatchResult matchAndRewrite(LhloOp op, PatternRewriter& rewriter) const override { - auto binary_op = cast(op); - - const auto& lhs = binary_op.lhs(); - const auto& rhs = binary_op.rhs(); + const auto& lhs = op.lhs(); + const auto& rhs = op.rhs(); const auto& lhs_type = lhs->getType().template cast(); const auto& rhs_type = rhs->getType().template cast(); const auto& element_type = lhs_type.getElementType(); if (lhs_type.getShape() != rhs_type.getShape()) { - return matchFailure(); + return this->matchFailure(); } const auto& shape = lhs_type.getShape(); SmallVector induction_vars; - const auto loc = op->getLoc(); + 
const auto loc = op.getLoc(); for (int i = 0; i < shape.size(); ++i) { auto forOp = rewriter.create(loc, 0, shape[i]); induction_vars.push_back(forOp.getInductionVar()); @@ -140,23 +137,26 @@ struct BinaryOpConverter : public RewritePattern { auto r = rewriter.create(loc, rhs, induction_vars); auto result = GetBinaryOp(element_type, loc, l, r, rewriter); if (result == nullptr) { - return matchFailure(); + return this->matchFailure(); } - rewriter.create(loc, result, binary_op.out(), induction_vars); + rewriter.create(loc, result, op.out(), induction_vars); rewriter.replaceOp(op, {}); - return matchSuccess(); + return this->matchSuccess(); } }; void populateLHLOToAffineConversionPattern(MLIRContext* context, OwningRewritePatternList* patterns) { - patterns->insert>(context); - patterns->insert>(context); - patterns->insert>(context); - patterns->insert>(context); - patterns->insert>(context); - patterns->insert>(context); - patterns->insert>(context); + // clang-format off + patterns->insert< + BinaryOpConverter, + BinaryOpConverter, + BinaryOpConverter, + BinaryOpConverter, + BinaryOpConverter, + BinaryOpConverter, + BinaryOpConverter>(context); + // clang-format on } struct LhloLegalizeToAffine : public FunctionPass { diff --git a/tensorflow/compiler/mlir/xla/type_to_shape.cc b/tensorflow/compiler/mlir/xla/type_to_shape.cc index e64182889cb..06d945946f6 100644 --- a/tensorflow/compiler/mlir/xla/type_to_shape.cc +++ b/tensorflow/compiler/mlir/xla/type_to_shape.cc @@ -22,7 +22,11 @@ limitations under the License. #include "mlir/IR/Location.h" // TF:local_config_mlir #include "mlir/IR/StandardTypes.h" // TF:local_config_mlir #include "mlir/Support/DebugStringHelper.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -68,6 +72,23 @@ PrimitiveType TypeToPrimitiveType(mlir::Type type) { } } +StatusOr TypeToShape( + mlir::Type type, CustomShapeRepresentationFn shape_representation_fn) { + tensorflow::PartialTensorShape partial_tensor_shape = + tensorflow::ConvertTypeToTensorShape(type); + + tensorflow::TensorShape fully_defined_tensor_shape; + if (!partial_tensor_shape.AsTensorShape(&fully_defined_tensor_shape)) { + return tensorflow::errors::InvalidArgument( + "XLA HLO only allows fully-defined shape"); + } + + tensorflow::DataType dtype; + TF_RETURN_IF_ERROR(tensorflow::ConvertToDataType(type, &dtype)); + + return shape_representation_fn(fully_defined_tensor_shape, dtype); +} + Shape TypeToShape(mlir::Type type) { PrimitiveType ptype = TypeToPrimitiveType(type); if (ptype != PrimitiveType::PRIMITIVE_TYPE_INVALID) diff --git a/tensorflow/compiler/mlir/xla/type_to_shape.h b/tensorflow/compiler/mlir/xla/type_to_shape.h index 6bd5384f857..4bc3fac9b1c 100644 --- a/tensorflow/compiler/mlir/xla/type_to_shape.h +++ b/tensorflow/compiler/mlir/xla/type_to_shape.h @@ -16,15 +16,29 @@ limitations under the License. 
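Relating back to the BinaryOpConverter refactor earlier in this patch: the pattern still emits one affine loop per memref dimension with a load-compute-store body, so for a 2-D xla_lhlo.add it behaves like the plain loop nest below (a C++ analogy of the generated IR, not the IR itself; `shape`, `lhs`, `rhs`, and `out` stand in for the memref operands):

    for (int64_t i = 0; i < shape[0]; ++i) {
      for (int64_t j = 0; j < shape[1]; ++j) {
        // load lhs/rhs, apply the binary op, store to out
        out[i][j] = lhs[i][j] + rhs[i][j];
      }
    }
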
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TYPE_TO_SHAPE_H_ #define TENSORFLOW_COMPILER_MLIR_XLA_TYPE_TO_SHAPE_H_ +#include "llvm/ADT/STLExtras.h" #include "mlir/IR/Types.h" // TF:local_config_mlir #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" namespace xla { // Returns a XLA Shape equivalent of a MLIR Type, else returns empty shape. Shape TypeToShape(mlir::Type type); +// Type of a custom function that converts a TensorFlow type and shape into an +// XLA shape with optional layout info. +typedef llvm::function_ref( + const tensorflow::TensorShape&, tensorflow::DataType)> + CustomShapeRepresentationFn; + +// Compute an XLA shape based in given MLIR type and an +// CustomShapeRepresentationFn, which allows setting custom layout in returned +// XLA shape. +StatusOr TypeToShape( + mlir::Type type, CustomShapeRepresentationFn shape_representation_fn); + // Returns a XLA PrimitiveType equivalent of a MLIR Type that represents a // primitive type (e.g., i8, f32), else returns PRIMITIVE_TYPE_INVALID. PrimitiveType TypeToPrimitiveType(mlir::Type type); diff --git a/tensorflow/compiler/mlir/xla/type_to_shape_test.cc b/tensorflow/compiler/mlir/xla/type_to_shape_test.cc index 57922fe1532..49a4a838e30 100644 --- a/tensorflow/compiler/mlir/xla/type_to_shape_test.cc +++ b/tensorflow/compiler/mlir/xla/type_to_shape_test.cc @@ -134,5 +134,43 @@ TEST(TypeToShapeTest, ConvertTensorTypeToTypes) { EqualsProto(Shape().ToProto())); } +TEST(TypeToShapeTest, ConvertWithShapeRepresentationFn) { + tensorflow::DataType captured_dtype; + tensorflow::TensorShape captured_tensor_shape; + + // A dummy shape representation function that does nothing other than + // capturing arguments passed to it. + auto test_shape_representation_fn = [&](const tensorflow::TensorShape& shape, + tensorflow::DataType dtype) { + captured_tensor_shape = shape; + captured_dtype = dtype; + return xla::Shape(); + }; + + MLIRContext context; + Builder b(&context); + StatusOr status_or_shape; + + // Non-fully-defined shape. + status_or_shape = TypeToShape(b.getTensorType({-1, 2, 3}, b.getF32Type()), + test_shape_representation_fn); + EXPECT_EQ(status_or_shape.status().code(), + tensorflow::errors::Code::INVALID_ARGUMENT); + + // Scalar Int32 Tensor, using fast memory. + status_or_shape = + TypeToShape(b.getIntegerType(32), test_shape_representation_fn); + EXPECT_TRUE(status_or_shape.ok()); + EXPECT_EQ(captured_dtype, tensorflow::DataType::DT_INT32); + EXPECT_EQ(captured_tensor_shape, tensorflow::TensorShape()); + + // Ranked Float32 Tensor, not using fast memory. + status_or_shape = TypeToShape(b.getTensorType({1, 2, 3}, b.getF32Type()), + test_shape_representation_fn); + EXPECT_TRUE(status_or_shape.ok()); + EXPECT_EQ(captured_dtype, tensorflow::DataType::DT_FLOAT); + EXPECT_EQ(captured_tensor_shape, tensorflow::TensorShape({1, 2, 3})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc b/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc index ad7e4724d90..7fbc5e4e2bc 100644 --- a/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc +++ b/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/xla/xla_mlir_translate.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/ToolOutputFile.h" #include "mlir/IR/Module.h" // TF:local_config_mlir @@ -30,6 +31,18 @@ limitations under the License. using stream_executor::port::Status; using stream_executor::port::StatusOr; // NOLINT TODO(b/130822468) fix this +// NOLINTNEXTLINE +static llvm::cl::opt emit_use_tuple_arg( + "emit-use-tuple-args", + llvm::cl::desc("Emit HLO modules using tuples as args"), + llvm::cl::init(false)); + +// NOLINTNEXTLINE +static llvm::cl::opt emit_always_return_tuple( + "emit-always-return-tuple", + llvm::cl::desc("Emit HLO modules always return tuple"), + llvm::cl::init(false)); + namespace xla { namespace { @@ -122,7 +135,8 @@ static mlir::LogicalResult MlirHloToHloTranslateFunction( } HloProto hloProto; - Status status = mlir::ConvertMlirHloToHlo(module, &hloProto); + Status status = mlir::ConvertMlirHloToHlo( + module, &hloProto, emit_use_tuple_arg, emit_always_return_tuple); if (!status.ok()) { LOG(ERROR) << "Module conversion failed: " << status; return mlir::failure(); @@ -155,7 +169,8 @@ static mlir::LogicalResult MlirHloToHloTextTranslateFunction( } HloProto hloProto; - Status status = mlir::ConvertMlirHloToHlo(module, &hloProto); + Status status = mlir::ConvertMlirHloToHlo( + module, &hloProto, emit_use_tuple_arg, emit_always_return_tuple); if (!status.ok()) { LOG(ERROR) << "Module conversion failed: " << status; return mlir::failure(); diff --git a/tensorflow/compiler/tests/slice_ops_test.py b/tensorflow/compiler/tests/slice_ops_test.py index b7784062e82..1d511ede1c2 100644 --- a/tensorflow/compiler/tests/slice_ops_test.py +++ b/tensorflow/compiler/tests/slice_ops_test.py @@ -22,6 +22,7 @@ from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest @@ -138,6 +139,22 @@ class StridedSliceTest(xla_test.XLATestCase): self.assertAllEqual([2, 4], result) + def test1DDynamic(self): + for dtype in self.numeric_types: + with self.session(): + i = array_ops.placeholder(dtype, shape=[10]) + begin = array_ops.placeholder(dtypes.int32, shape=[1]) + with self.test_scope(): + end = math_ops.add(begin, [1]) + o = array_ops.strided_slice(i, begin, end, [1]) + params = { + i: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + begin: [0] + } + result = o.eval(feed_dict=params) + + self.assertAllEqual([0], result) + def test1DNegativeStride(self): for dtype in self.numeric_types: with self.session(): @@ -179,6 +196,22 @@ class StridedSliceTest(xla_test.XLATestCase): self.assertEqual(tensor_shape.TensorShape((0, 3)), result.shape) + def test2DFullSlice(self): + for dtype in self.numeric_types: + with self.session(): + with self.test_scope(): + i = array_ops.placeholder(dtype, shape=[2, 4]) + begin = array_ops.placeholder(dtypes.int32, shape=[2]) + end = math_ops.add(begin, [1, 1]) + o = array_ops.strided_slice(i, begin, end, [1, 1]) + params = { + i: [[0, 1, 2, 3], [4, 5, 6, 7]], + begin: [1, 1] + } + result = o.eval(feed_dict=params) + + self.assertAllEqual([[5]], result) + def test3D(self): for dtype in self.numeric_types: with self.session(): diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 43f920b9ccc..65d5f9a2ecd 100644 --- 
a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -1566,8 +1566,11 @@ void Converter::MaybeApplyQuantizationRanges() { #endif if (use_calibration()) return; +#if !IS_TRT_VERSION_GE(6, 0, 0, 0) // Attempt to find tensors that are missing ranges, and set the corresponding // layer's precision to FP16 to avoid Builder::buildCudaEngine() failing. + // This is only needed for TensorRT 5 and before because + // TensorRT6 falls to FP16 internally. // TensorRT doesn't need ranges for intermediate tensors when layers are fused // so find fused layers first. // Get all tensors from network and deduce fused ops. @@ -1696,6 +1699,7 @@ void Converter::MaybeApplyQuantizationRanges() { } } } +#endif } void Converter::PropagateQuantizationRanges() { @@ -5211,6 +5215,18 @@ Status ConvertCombinedNMS(OpConverterParams* params) { &plugin_inputs[0], static_cast(plugin_inputs.size()), *plugin); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + // Set plugin outputs + nvinfer1::ITensor* output_nmsed_boxes = layer->getOutput(1); +#if IS_TRT_VERSION_GE(6, 0, 0, 0) + // TRT6 fixes (removes) the extra last dimension in CombinedNMS outputs + nvinfer1::ITensor* output_num_detections = layer->getOutput(0); + nvinfer1::ITensor* output_nmsed_scores = layer->getOutput(2); + nvinfer1::ITensor* output_nmsed_classes = layer->getOutput(3); +#else + nvinfer1::ITensor* output_num_detections = nullptr; + nvinfer1::ITensor* output_nmsed_scores = nullptr; + nvinfer1::ITensor* output_nmsed_classes = nullptr; + auto shrink_last_dim = [params](nvinfer1::ITensor* in_tensor, nvinfer1::ITensor** out_tensor) { nvinfer1::Dims dims = in_tensor->getDimensions(); @@ -5224,18 +5240,13 @@ Status ConvertCombinedNMS(OpConverterParams* params) { /*validation_only=*/false, out_tensor)); return Status::OK(); }; - - // Set plugin outputs - nvinfer1::ITensor* output_nmsed_boxes = layer->getOutput(1); - nvinfer1::ITensor* output_nmsed_scores = nullptr; - nvinfer1::ITensor* output_nmsed_classes = nullptr; - nvinfer1::ITensor* output_num_detections = nullptr; TF_RETURN_IF_ERROR( shrink_last_dim(layer->getOutput(2), &output_nmsed_scores)); TF_RETURN_IF_ERROR( shrink_last_dim(layer->getOutput(3), &output_nmsed_classes)); TF_RETURN_IF_ERROR( shrink_last_dim(layer->getOutput(0), &output_num_detections)); +#endif // IS_TRT_VERSION_GE(6, 0, 0, 0) params->outputs->push_back(TRT_TensorOrWeights(output_nmsed_boxes)); params->outputs->push_back(TRT_TensorOrWeights(output_nmsed_scores)); diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index dedeb647023..88533727c27 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -459,6 +459,7 @@ tf_cc_test( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index a431abd26e0..34888fc0e2f 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -90,14 +90,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, break; case XlaExpression::Kind::kResource: { XlaResource* resource = expressions[i]->resource(); - - arg.initialized = 
resource->initialized(); - arg.kind = XlaCompiler::Argument::kResource; - arg.resource_kind = resource->kind(); - arg.type = resource->type(); - arg.shape = resource->shape(); - arg.max_array_size = resource->max_array_size(); - arg.name = resource->name(); + XlaCompiler::PopulateArgumentFromResource(*resource, &arg); break; } case XlaExpression::Kind::kTensorList: { diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 47574a8c202..19c09b07959 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -220,9 +220,7 @@ XLA_MAKE_BINARY(SigmoidGrad, xla::Mul(xla::Mul(rhs, lhs), xla::Sub(XlaHelpers::One(b, input_type(0)), lhs))); -XLA_MAKE_BINARY(SoftplusGrad, - xla::Div(lhs, xla::Add(xla::Exp(xla::Neg(rhs)), - XlaHelpers::One(b, input_type(1))))); +XLA_MAKE_BINARY(SoftplusGrad, xla::Mul(lhs, xla::Logistic(rhs))); // softsigngrad(gradients, features) = gradients / (1 + abs(features)) ** 2 XLA_MAKE_BINARY(SoftsignGrad, diff --git a/tensorflow/compiler/tf2xla/kernels/case_op.cc b/tensorflow/compiler/tf2xla/kernels/case_op.cc index 5ba844e10bd..9b3770cf55e 100644 --- a/tensorflow/compiler/tf2xla/kernels/case_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/case_op.cc @@ -93,20 +93,9 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) { if (type == DT_RESOURCE) { XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(i + 1, &resource)); - - arg.initialized = resource->initialized(); - arg.kind = XlaCompiler::Argument::kResource; - arg.resource_kind = resource->kind(); - - arg.type = resource->type(); - arg.shape = resource->shape(); + XlaCompiler::PopulateArgumentFromResource(*resource, &arg); OP_REQUIRES(ctx, arg.initialized, errors::Unimplemented("Uninitialized arguments: ", arg.name)); - arg.max_array_size = resource->max_array_size(); - for (const auto& gradient : resource->tensor_array_gradients()) { - arg.tensor_array_gradients.insert(gradient.first); - } - arg.name = resource->name(); VLOG(2) << "Resource " << resource->name() << " type: " << DataTypeString(arg.type) << " shape: " << arg.HumanString() @@ -235,6 +224,22 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) { branch_results[0].xla_output_shape); } + // Check that all branches have same TensorList output indices. + for (int output_index = 0; output_index < branch_results[0].outputs.size(); + output_index++) { + bool is_tensor_list_in_branch_0 = + branch_results[0].outputs[output_index].is_tensor_list; + bool is_tensor_list_in_branch_j = + branch_results[j].outputs[output_index].is_tensor_list; + OP_REQUIRES( + ctx, is_tensor_list_in_branch_0 == is_tensor_list_in_branch_j, + errors::FailedPrecondition("Output #", output_index, " is ", + (is_tensor_list_in_branch_0 ? "" : "not"), + " a TensorList in branch 0, but is ", + (is_tensor_list_in_branch_j ? "" : "not"), + " a TensorList in branch ", j)); + } + // We set return_updated_values_for_all_resources=true and we pass the same // arguments to both computations, so the resource update count must match. OP_REQUIRES(ctx, @@ -296,7 +301,12 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) { LOG(INFO) << "Shape unknown for output " << i; } } - ctx->SetOutput(i, output_handle); + // We have checked that all branches have same TensorList output indices. 
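// For example (illustrative only): if branch 0 of a Case returns
// {float32 Tensor, TensorList} while branch 1 returns
// {float32 Tensor, float32 Tensor}, output #1 has no single consistent
// representation, so the FailedPrecondition check above rejects the op;
// once the check passes, each output is forwarded with the matching
// setter below.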
+ if (branch_results[0].outputs[i].is_tensor_list) { + ctx->SetTensorListOutput(i, output_handle); + } else { + ctx->SetOutput(i, output_handle); + } } if (has_token_input_output_) { // Set token output for this "Case" op. Token output is the last output of diff --git a/tensorflow/compiler/tf2xla/kernels/einsum_op.cc b/tensorflow/compiler/tf2xla/kernels/einsum_op.cc index 6b3334dc1de..bf9313389dd 100644 --- a/tensorflow/compiler/tf2xla/kernels/einsum_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/einsum_op.cc @@ -25,7 +25,8 @@ limitations under the License. namespace tensorflow { namespace { -constexpr std::array kEinsumTypes = {{DT_BFLOAT16, DT_FLOAT}}; +constexpr std::array kEinsumTypes = { + {DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128}}; class EinsumOp : public XlaOpKernel { public: @@ -38,8 +39,6 @@ class EinsumOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::XlaOp lhs = ctx->Input(0); xla::XlaOp rhs = ctx->Input(1); - const TensorShape a_shape = ctx->InputShape(0); - const TensorShape b_shape = ctx->InputShape(1); ctx->SetOutput(0, xla::Einsum(lhs, rhs, equation_)); } @@ -49,6 +48,7 @@ class EinsumOp : public XlaOpKernel { }; REGISTER_XLA_OP(Name("XlaEinsum").TypeConstraint("T", kEinsumTypes), EinsumOp); +REGISTER_XLA_OP(Name("Einsum").TypeConstraint("T", kEinsumTypes), EinsumOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc index 5ac288d8a34..e5e4e797cc5 100644 --- a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc @@ -138,10 +138,20 @@ class RFFTOp : public GenericFftOp { explicit RFFTOp(OpKernelConstruction* ctx) : GenericFftOp(ctx, /*fft_type=*/FftType::RFFT, /*fft_rank=*/FFTRank) {} }; -REGISTER_XLA_OP(Name("RFFT").CompileTimeConstantInput("fft_length"), RFFTOp<1>); -REGISTER_XLA_OP(Name("RFFT2D").CompileTimeConstantInput("fft_length"), +REGISTER_XLA_OP(Name("RFFT") + .TypeConstraint("Treal", DT_FLOAT) + .TypeConstraint("Tcomplex", DT_COMPLEX64) + .CompileTimeConstantInput("fft_length"), + RFFTOp<1>); +REGISTER_XLA_OP(Name("RFFT2D") + .TypeConstraint("Treal", DT_FLOAT) + .TypeConstraint("Tcomplex", DT_COMPLEX64) + .CompileTimeConstantInput("fft_length"), RFFTOp<2>); -REGISTER_XLA_OP(Name("RFFT3D").CompileTimeConstantInput("fft_length"), +REGISTER_XLA_OP(Name("RFFT3D") + .TypeConstraint("Treal", DT_FLOAT) + .TypeConstraint("Tcomplex", DT_COMPLEX64) + .CompileTimeConstantInput("fft_length"), RFFTOp<3>); template @@ -150,11 +160,20 @@ class IRFFTOp : public GenericFftOp { explicit IRFFTOp(OpKernelConstruction* ctx) : GenericFftOp(ctx, /*fft_type=*/FftType::IRFFT, /*fft_rank=*/FFTRank) {} }; -REGISTER_XLA_OP(Name("IRFFT").CompileTimeConstantInput("fft_length"), +REGISTER_XLA_OP(Name("IRFFT") + .TypeConstraint("Treal", DT_FLOAT) + .TypeConstraint("Tcomplex", DT_COMPLEX64) + .CompileTimeConstantInput("fft_length"), IRFFTOp<1>); -REGISTER_XLA_OP(Name("IRFFT2D").CompileTimeConstantInput("fft_length"), +REGISTER_XLA_OP(Name("IRFFT2D") + .TypeConstraint("Treal", DT_FLOAT) + .TypeConstraint("Tcomplex", DT_COMPLEX64) + .CompileTimeConstantInput("fft_length"), IRFFTOp<2>); -REGISTER_XLA_OP(Name("IRFFT3D").CompileTimeConstantInput("fft_length"), +REGISTER_XLA_OP(Name("IRFFT3D") + .TypeConstraint("Treal", DT_FLOAT) + .TypeConstraint("Tcomplex", DT_COMPLEX64) + .CompileTimeConstantInput("fft_length"), IRFFTOp<3>); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc 
b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 489ffd3fdad..3178c04875a 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -25,8 +25,10 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" namespace tensorflow { @@ -150,6 +152,85 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, return Status::OK(); } +Status XlaGatherWithBatchDimsOpImpl(XlaOpKernelContext* context, + const xla::XlaOp input, + const TensorShape& input_shape, + int batch_dims, xla::XlaOp* gather_output) { + auto indices = context->Input(1); + auto indices_shape = context->InputShape(1); + + absl::optional axis; + if (context->num_inputs() == 3) { + const TensorShape axis_shape = context->InputShape(2); + if (!TensorShapeUtils::IsScalar(axis_shape)) { + return errors::InvalidArgument("axis must be scalar"); + } + DataType axis_type = context->input_type(2); + if (axis_type != DT_INT32 && axis_type != DT_INT64) { + return errors::InvalidArgument("axis must be int32 or int64"); + } + + int64 axis_input; + TF_RETURN_IF_ERROR(context->ConstantInputAsIntScalar(2, &axis_input)); + + const auto params_dims = input_shape.dims(); + if (-params_dims > axis_input || axis_input >= params_dims) { + return errors::InvalidArgument("Expected axis in the range [", + -params_dims, ", ", params_dims, + "), but got ", axis_input); + } + if (axis_input < 0) { + axis_input += params_dims; + } + axis = axis_input; + } + + if (batch_dims != 0) { + if (batch_dims < 0) { + batch_dims = indices_shape.dims() + batch_dims; + } + + axis = axis.value_or(batch_dims); + + if (batch_dims < -indices_shape.dims() || + batch_dims >= indices_shape.dims()) { + return errors::InvalidArgument( + "Expected batch_dims in the range [", -indices_shape.dims(), ", ", + indices_shape.dims(), "), but got ", batch_dims); + } + + if (batch_dims >= input_shape.dims()) { + return errors::InvalidArgument("batch_dims (", batch_dims, + ") must be less than rank(input) (", + input_shape.dims(), ")."); + } + + if (*axis < batch_dims) { + return errors::InvalidArgument("batch_dims (", batch_dims, + ") must be less than or equal to ", + "axis (", *axis, ")."); + } + } + + axis = axis.value_or(0); + DataType index_type = context->input_type(1); + if (index_type != DT_INT32 && index_type != DT_INT64) { + return errors::InvalidArgument("indices must be int32 or int64"); + } + + xla::XlaOp gather; + if (batch_dims > 0) { + *gather_output = xla::TorchIndexSelect(input, indices, *axis, batch_dims); + } else { + // XlaGather() manages degenerate cases, like empty-indices, which are + // error conditions and caught above if batch_dims is not 0. 
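// Shape sketch (illustrative values only): with params of shape [B, N, D]
// and indices of shape [B, K], batch_dims = 1 takes the TorchIndexSelect
// path above and yields an output of shape [B, K, D]; with batch_dims = 0,
// params of shape [N, D] and indices of shape [K] gathered along axis 0
// fall through to the XlaGather call below and yield [K, D].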
+ TF_RETURN_IF_ERROR( + XlaGather(input, input_shape, indices, indices_shape, *axis, + /*indices_are_nd=*/false, context->expected_output_dtype(0), + index_type, context->builder(), gather_output)); + } + return Status::OK(); +} class GatherOp : public XlaOpKernel { public: explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) { @@ -164,76 +245,11 @@ class GatherOp : public XlaOpKernel { void Compile(XlaOpKernelContext* context) override { auto input = context->Input(0); auto input_shape = context->InputShape(0); - auto indices = context->Input(1); - auto indices_shape = context->InputShape(1); - - absl::optional axis; - if (context->num_inputs() == 3) { - const TensorShape axis_shape = context->InputShape(2); - OP_REQUIRES(context, TensorShapeUtils::IsScalar(axis_shape), - errors::InvalidArgument("axis must be scalar")); - DataType axis_type = input_type(2); - OP_REQUIRES(context, axis_type == DT_INT32 || axis_type == DT_INT64, - errors::InvalidArgument("axis must be int32 or int64")); - - int64 axis_input; - OP_REQUIRES_OK(context, - context->ConstantInputAsIntScalar(2, &axis_input)); - - const auto params_dims = input_shape.dims(); - OP_REQUIRES(context, - -params_dims <= axis_input && axis_input < params_dims, - errors::InvalidArgument("Expected axis in the range [", - -params_dims, ", ", params_dims, - "), but got ", axis_input)); - if (axis_input < 0) { - axis_input += params_dims; - } - axis = axis_input; - } - - if (batch_dims_ != 0) { - if (batch_dims_ < 0) { - batch_dims_ = indices_shape.dims() + batch_dims_; - } - - axis = axis.value_or(batch_dims_); - - OP_REQUIRES(context, - batch_dims_ >= -indices_shape.dims() && - batch_dims_ < indices_shape.dims(), - errors::InvalidArgument("Expected batch_dims in the range [", - -indices_shape.dims(), ", ", - indices_shape.dims(), "), but got ", - batch_dims_)); - - OP_REQUIRES(context, batch_dims_ < input_shape.dims(), - errors::InvalidArgument("batch_dims (", batch_dims_, - ") must be less than rank(input) (", - input_shape.dims(), ").")); - - OP_REQUIRES(context, *axis >= batch_dims_, - errors::InvalidArgument("batch_dims (", batch_dims_, - ") must be less than or equal to ", - "axis (", *axis, ").")); - } - - axis = axis.value_or(0); - DataType index_type = input_type(1); - OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64, - errors::InvalidArgument("indices must be int32 or int64")); xla::XlaOp gather; - if (batch_dims_ > 0) { - gather = xla::TorchIndexSelect(input, indices, *axis, batch_dims_); - } else { - // XlaGather() manages degenerate cases, like empty-indices, which are - // error conditions and caught above if batch_dims is not 0. - OP_REQUIRES_OK( - context, XlaGather(input, input_shape, indices, indices_shape, *axis, - /*indices_are_nd=*/false, input_type(0), - index_type, context->builder(), &gather)); - } + OP_REQUIRES_OK(context, + XlaGatherWithBatchDimsOpImpl(context, input, input_shape, + batch_dims_, &gather)); context->SetOutput(0, gather); } diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h index 92346283c31..7bd25230d46 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h +++ b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h @@ -39,6 +39,13 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, DataType index_type, xla::XlaBuilder* builder, xla::XlaOp* gather_output); +// The implementation of Gather and ResourceGather through XLA. 
Uses `input` as +// the input instead of context->input(0) in order to allow ResourceGather to +// handle obtaining the data from the ResourceVariable. +Status XlaGatherWithBatchDimsOpImpl(XlaOpKernelContext* context, + const xla::XlaOp input, + const TensorShape& input_shape, + int batch_dims, xla::XlaOp* gather_output); } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_GATHER_OP_HELPERS_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index 4422af7d15f..a7dd1bb0079 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -93,24 +93,13 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(i + 1, &resource)); - arg.initialized = resource->initialized(); - arg.kind = XlaCompiler::Argument::kResource; - arg.resource_kind = resource->kind(); - - arg.type = resource->type(); - arg.shape = resource->shape(); + XlaCompiler::PopulateArgumentFromResource(*resource, &arg); OP_REQUIRES(ctx, arg.initialized, errors::Unimplemented("Uninitialized arguments: ", arg.name)); - arg.max_array_size = resource->max_array_size(); - for (const auto& gradient : resource->tensor_array_gradients()) { - arg.tensor_array_gradients.insert(gradient.first); - } - arg.name = resource->name(); VLOG(2) << "Resource " << resource->name() << " type: " << DataTypeString(arg.type) << " shape: " << arg.HumanString() << " initialized: " << arg.initialized; - num_resource_args++; } else { arg.kind = XlaCompiler::Argument::kParameter; @@ -220,6 +209,22 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { xla::ShapeUtil::HumanString(then_result.xla_output_shape), " vs. ", xla::ShapeUtil::HumanString(else_result.xla_output_shape))); + // Check that both branches have same TensorList output indices. + for (int output_index = 0; output_index < then_result.outputs.size(); + output_index++) { + bool is_tensor_list_in_then_branch = + then_result.outputs[output_index].is_tensor_list; + bool is_tensor_list_in_else_branch = + else_result.outputs[output_index].is_tensor_list; + OP_REQUIRES( + ctx, is_tensor_list_in_then_branch == is_tensor_list_in_else_branch, + errors::FailedPrecondition("Output #", output_index, " is ", + (is_tensor_list_in_then_branch ? "" : "not"), + " a TensorList in then branch, but is ", + (is_tensor_list_in_else_branch ? "" : "not"), + " a TensorList in else branch")); + } + VLOG(2) << "Input shape: " << xla::ShapeUtil::HumanString(then_input_shape); VLOG(2) << "Output shape: " << xla::ShapeUtil::HumanString(then_result.xla_output_shape); @@ -282,7 +287,12 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { LOG(INFO) << "Shape unknown for output " << i; } } - ctx->SetOutput(i, output_handle); + // We have checked that both branches have same TensorList output indices. + if (then_result.outputs[i].is_tensor_list) { + ctx->SetTensorListOutput(i, output_handle); + } else { + ctx->SetOutput(i, output_handle); + } } if (has_token_input_output_) { // Set token output for this "If" op. 
Token output is the last output of diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index 273ac3dd5ae..bf9a9150ea6 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -106,13 +106,34 @@ class ReshapeOp : public XlaOpKernel { " values, but the requested shape has ", shape.num_elements())); - VLOG(1) << "Reshape from " << input_shape.DebugString() << " to " + VLOG(2) << "Reshape from " << input_shape.DebugString() << " to " << shape.DebugString() << ", unknown_index=" << unknown_index; + shape_input.clear(); + // Run get input again, this time with dynamic dimension represented as + // "-1" + ctx->set_dynamic_dimension_is_minus_one(true); + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &shape_input)); + + int dynamic_dimension = -1; + + for (int d = 0; d < num_dims; ++d) { + const int32 size = shape_input[d]; + if (size == -1) { + if (dynamic_dimension == -1) { + dynamic_dimension = d; + } else { + if (unknown_index != d) { + dynamic_dimension = d; + } + } + } + } + // Pass unknown_index to Xla::Reshape as a hint for dynamic shape inference // in XLA to know which output dimension is dynamic. ctx->SetOutput(0, xla::ReshapeWithInferredDimension( - ctx->Input(0), shape.dim_sizes(), unknown_index)); + ctx->Input(0), shape.dim_sizes(), dynamic_dimension)); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc index 93dfc189fd6..ce4a46b45c8 100644 --- a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc @@ -20,7 +20,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -108,18 +110,26 @@ class ScatterNdOp : public XlaOpKernel { buffer_shape.dim_sizes()); auto indices = context->Input(0); auto updates = context->Input(1); + auto combine = + context->input_xla_type(1) == xla::PRED ? 
CombineBool : CombineNum; auto result = XlaScatter(buffer, updates, indices, - /*indices_are_vectors=*/true, /*combiner=*/Combine, builder); + /*indices_are_vectors=*/true, /*combiner=*/combine, builder); OP_REQUIRES_OK(context, result.status()); context->SetOutput(0, result.ValueOrDie()); } private: - static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, - xla::XlaBuilder* builder) { + static xla::XlaOp CombineNum(const xla::XlaOp x, const xla::XlaOp y, + xla::XlaBuilder* builder) { + (void)builder; return xla::Add(x, y); } + static xla::XlaOp CombineBool(const xla::XlaOp x, const xla::XlaOp y, + xla::XlaBuilder* builder) { + (void)builder; + return xla::Or(x, y); + } }; REGISTER_XLA_OP(Name("ScatterNd").CompileTimeConstantInput("shape"), diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 88af12dacee..06095631434 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -119,7 +119,9 @@ class SizeOp : public XlaOpKernel { xla::XlaBuilder* builder = ctx->builder(); auto size = xla::One(builder, xla::U32); for (int64 i = 0; i < rank; ++i) { - size = xla::Mul(size, xla::GetDimensionSize(ctx->Input(0), i)); + size = xla::Mul( + size, xla::ConvertElementType(xla::GetDimensionSize(ctx->Input(0), i), + xla::U32)); } size = xla::ConvertElementType(size, ctx->output_xla_type(0)); ctx->SetOutput(0, size); diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 9da1504bff1..546b0e7f9e1 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -14,12 +14,14 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/core/util/strided_slice_op.h" + #include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/ops_util.h" @@ -44,60 +46,124 @@ class StridedSliceOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const TensorShape input_shape = ctx->InputShape(0); + const TensorShape begin_shape = ctx->InputShape("begin"); + + OP_REQUIRES( + ctx, begin_shape.dims() == 1, + errors::InvalidArgument("'begin' input has to be a rank 1 vector")); - TensorShape final_shape; absl::InlinedVector begin; absl::InlinedVector end; absl::InlinedVector strides; xla::Literal begin_literal, end_literal, strides_literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal)); - OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &end_literal)); + bool begin_is_constant = ctx->ConstantInput(1, &begin_literal).ok(); + bool end_is_constant = ctx->ConstantInput(2, &end_literal).ok(); + OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal)); Tensor begin_tensor, end_tensor, strides_tensor; - OP_REQUIRES_OK( - ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor)); - OP_REQUIRES_OK(ctx, - LiteralToHostTensor(end_literal, index_type_, &end_tensor)); + if (begin_is_constant) { + OP_REQUIRES_OK( + ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor)); + } + if (end_is_constant) { + OP_REQUIRES_OK( + ctx, LiteralToHostTensor(end_literal, index_type_, &end_tensor)); + } OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_, &strides_tensor)); - TensorShape dummy_processing_shape; + TensorShape final_shape; + PartialTensorShape dummy_processing_shape, partial_final_shape; bool dummy = false; - OP_REQUIRES_OK(ctx, - ValidateStridedSliceOp( - &begin_tensor, &end_tensor, strides_tensor, input_shape, - begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_, - shrink_axis_mask_, &dummy_processing_shape, &final_shape, - &dummy, &dummy, &dummy, &begin, &end, &strides)); + OP_REQUIRES_OK(ctx, ValidateStridedSliceOp( + begin_is_constant ? &begin_tensor : nullptr, + end_is_constant ? &end_tensor : nullptr, + strides_tensor, input_shape, begin_mask_, end_mask_, + ellipsis_mask_, new_axis_mask_, shrink_axis_mask_, + &dummy_processing_shape, &partial_final_shape, + &dummy, &dummy, &dummy, &begin, &end, &strides)); - absl::InlinedVector dimensions_to_reverse; - absl::InlinedVector slice_begin, slice_end, slice_strides; - - for (int i = 0; i < begin.size(); ++i) { - if (strides[i] > 0) { - slice_begin.push_back(begin[i]); - slice_end.push_back(std::max(end[i], begin[i])); - slice_strides.push_back(strides[i]); - } else { - // Negative stride: swap begin and end, add 1 because the interval - // is semi-open, and mark the dimension to be reversed. 
- slice_begin.push_back(input_shape.dim_size(i) - begin[i] - 1); - slice_end.push_back(std::max(input_shape.dim_size(i) - end[i] - 1, - input_shape.dim_size(i) - begin[i] - 1)); - slice_strides.push_back(-strides[i]); - dimensions_to_reverse.push_back(i); - } - } + OP_REQUIRES(ctx, partial_final_shape.AsTensorShape(&final_shape), + errors::InvalidArgument( + "XLA can't deduce compile time constant output " + "shape for strided slice: ", + partial_final_shape.DebugString(), + ", output shape must be a compile-time constant")); xla::XlaOp slice = ctx->Input(0); - if (!dimensions_to_reverse.empty()) { - slice = xla::Rev(slice, dimensions_to_reverse); + if (begin_is_constant && end_is_constant) { + absl::InlinedVector dimensions_to_reverse; + absl::InlinedVector slice_begin, slice_end, slice_strides; + for (int i = 0; i < begin.size(); ++i) { + if (strides[i] > 0) { + slice_begin.push_back(begin[i]); + slice_end.push_back(std::max(end[i], begin[i])); + slice_strides.push_back(strides[i]); + } else { + // Negative stride: swap begin and end, add 1 because the interval + // is semi-open, and mark the dimension to be reversed. + slice_begin.push_back(input_shape.dim_size(i) - begin[i] - 1); + slice_end.push_back(std::max(input_shape.dim_size(i) - end[i] - 1, + input_shape.dim_size(i) - begin[i] - 1)); + slice_strides.push_back(-strides[i]); + dimensions_to_reverse.push_back(i); + } + } + if (!dimensions_to_reverse.empty()) { + slice = xla::Rev(slice, dimensions_to_reverse); + } + slice = xla::Slice(slice, slice_begin, slice_end, slice_strides); + } else { + // When output shape is fully defined, it must be a size one slice: + // + // 1. The number of output elements has to equal to number of input + // elements that are sliced. + // 2. The stride of the slice dimensions must be exact one. + int64 output_elements = final_shape.num_elements(); + + int64 input_elements_sliced = 1; + int64 slicing_dim_size = begin_shape.dim_size(0); + // We only support slicing major dimensions, so minor dimensions after + for (int64 d = slicing_dim_size; d < input_shape.dims(); ++d) { + input_elements_sliced *= input_shape.dim_size(d); + } + + OP_REQUIRES( + ctx, output_elements == input_elements_sliced, + errors::InvalidArgument( + "The number of output elements ", output_elements, + " has to equal to number of input elements that are sliced ", + input_elements_sliced, " when input indices are not constant.")); + + for (int64 i = 0; i < ctx->InputShape("begin").dims(); ++i) { + OP_REQUIRES( + ctx, strides[i] == 1, + errors::InvalidArgument( + "Strides have to be one when inputs are not constant.")); + } + + // When inputs are not compile time constants, shape inference can only + // inference size 1 slice. + std::vector slice_sizes(slicing_dim_size, 1); + std::vector start_indices; + for (int64 d = 0; d < slicing_dim_size; ++d) { + auto index = xla::Slice(ctx->Input("begin"), {d}, {d + 1}, {1}); + // Convert index to scalar. + start_indices.push_back(xla::Reshape(index, {})); + } + + for (int64 d = slicing_dim_size; d < input_shape.dims(); ++d) { + // For non-slice dims, naturally we get the full slice starting from 0. 
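// Worked example (illustrative values only): an input of shape [5, 8] with
// a non-constant `begin` of shape [1] and stride 1 slices only dimension 0,
// so slice_sizes starts as {1}; dimension 1 is a non-slice dimension and
// keeps its full size 8 with a start index of 0, and the op lowers to
// xla::DynamicSlice(input, start_indices, {1, 8}).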
+ slice_sizes.push_back(input_shape.dim_size(d)); + start_indices.push_back( + xla::Zero(ctx->builder(), ctx->InputXlaType("begin"))); + } + + std::vector output_shape_dim_sizes; + slice = xla::DynamicSlice(slice, start_indices, slice_sizes); } - - slice = xla::Slice(slice, slice_begin, slice_end, slice_strides); - slice = xla::Reshape(slice, final_shape.dim_sizes()); ctx->SetOutput(0, slice); } diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index 7b4125ab76e..60424f85840 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" @@ -122,27 +123,24 @@ REGISTER_XLA_OP( class ResourceGatherOp : public XlaOpKernel { public: - explicit ResourceGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + explicit ResourceGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("batch_dims", &batch_dims_)); + } void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* builder = ctx->builder(); - DataType type = ctx->expected_output_dtype(0); - TensorShape resource_shape; - xla::XlaOp resource_handle; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &resource_shape, - &resource_handle)); + TensorShape input_shape; + xla::XlaOp input; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &input_shape, &input)); - auto indices = ctx->Input(1); - auto indices_shape = ctx->InputShape(1); - DataType index_type = ctx->input_type(1); xla::XlaOp gather; - OP_REQUIRES_OK( - ctx, XlaGather(resource_handle, resource_shape, indices, indices_shape, - /*axis=*/0, /*indices_are_nd=*/false, type, index_type, - builder, &gather)); + OP_REQUIRES_OK(ctx, XlaGatherWithBatchDimsOpImpl(ctx, input, input_shape, + batch_dims_, &gather)); ctx->SetOutput(0, gather); } + + private: + int32 batch_dims_; }; REGISTER_XLA_OP(Name("ResourceGather"), ResourceGatherOp); diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index e82519f68ca..36c35f3c83b 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -81,29 +81,17 @@ Status MakeXlaCompilerArgumentsFromInputs( if (type == DT_RESOURCE) { XlaResource* resource; TF_RETURN_IF_ERROR(ctx->GetResourceInput(i, &resource)); - - arg.initialized = resource->initialized(); - arg.kind = XlaCompiler::Argument::kResource; - arg.resource_kind = resource->kind(); + XlaCompiler::PopulateArgumentFromResource(*resource, &arg); if (arg.resource_kind == XlaResource::kTensorArray) { *has_tensor_arrays = true; } - - arg.type = resource->type(); - arg.shape = resource->shape(); if (!arg.initialized) { *has_uninitialized_vars = true; } - arg.max_array_size = resource->max_array_size(); - for (const auto& gradient : resource->tensor_array_gradients()) { - arg.tensor_array_gradients.insert(gradient.first); - } - arg.name = resource->name(); VLOG(2) << " resource " << resource->name() << " type: " << DataTypeString(arg.type) << " shape: " << arg.ShapeHumanString() << " initialized: " << 
arg.initialized; - } else { arg.kind = XlaCompiler::Argument::kParameter; arg.type = type; diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index f6fc0c526d9..43d9e5d0e10 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -624,7 +624,7 @@ REGISTER_OP("XlaEinsum") .Input("b: T") .Output("product: T") .Attr("equation: string") - .Attr("T: {bfloat16, float}") + .Attr("T: {complex64, bfloat16, float}") .SetShapeFn([](shape_inference::InferenceContext* context) { shape_inference::ShapeHandle input_a = context->input(0); shape_inference::ShapeHandle input_b = context->input(1); diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index f1bf88b5418..54beb0aebfe 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -161,9 +161,10 @@ Status BuildComputation( const std::vector>& resources, std::unique_ptr token_output, const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, - bool return_updated_values_for_all_resources, bool always_return_tuple, - xla::XlaBuilder* builder, xla::XlaComputation* computation, - int* num_computation_outputs, int* num_nonconst_outputs, + bool is_entry_computation, bool return_updated_values_for_all_resources, + bool always_return_tuple, xla::XlaBuilder* builder, + xla::XlaComputation* computation, int* num_computation_outputs, + int* num_nonconst_outputs, std::vector* outputs, std::vector* resource_updates, xla::Shape* output_shape) { @@ -173,6 +174,7 @@ Status BuildComputation( xla::OpMetadata retval_metadata; retval_metadata.set_op_name("XLA_Retvals"); builder->SetOpMetadata(retval_metadata); + VLOG(1) << "Building new computation"; auto cleanup = gtl::MakeCleanup([builder]() { builder->ClearOpMetadata(); }); // Builds a no-op XLA computation. We need to set the sharding of outputs, but @@ -189,6 +191,10 @@ Status BuildComputation( // a descending layout is used. The first element is the output index, second // element is the new layout. std::vector> retval_index_and_layout; + // Keeps track of sharding of each retval. If a retval is not in this list, + // replicate sharding is used. The first element is the output index, second + // element is the sharding. + std::unordered_map retval_index_and_sharding; for (int i = 0; i < retvals.size(); ++i) { XlaCompiler::OutputDescription& output = (*outputs)[i]; const XlaExpression& retval = retvals[i]; @@ -216,6 +222,9 @@ Status BuildComputation( builder, it == retval_shardings.end() ? absl::optional() : it->second); + if (it != retval_shardings.end()) { + retval_index_and_sharding[elems.size()] = it->second; + } if (shape_representation_fn) { // If there is a shape representation function, reshape the output // tensor to the shape given by the representation shape function. @@ -290,6 +299,9 @@ Status BuildComputation( xla::XlaScopedShardingAssignment assign_sharding( builder, it == arg_shardings.end() ? absl::optional() : it->second); + if (it != arg_shardings.end()) { + retval_index_and_sharding[elems.size()] = it->second; + } xla::XlaOp handle; TF_RETURN_IF_ERROR(resource->Pack(&handle, builder)); @@ -334,7 +346,44 @@ Status BuildComputation( // Builds the XLA computation. We *always* form a tuple here to ensure that // the output value is the last thing added into the XLA computation, even // if there is only one output value. 
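// Sketch of the sharding handling added below (illustrative values only):
// for a single retval of shape f32[2] that carries a tile sharding over two
// devices, the per-element shardings are collected into
//   xla::HloSharding::Tuple(ShapeUtil::MakeTupleShape({f32[2]}), {tile})
// and assigned to the root tuple through XlaScopedShardingAssignment; the
// SetShardingForReturnedTuple test added in xla_compiler_test.cc verifies
// this proto on the root instruction.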
- auto tuple = xla::Tuple(builder, elems); + xla::XlaOp tuple; + if (retval_index_and_sharding.empty() || !is_entry_computation) { + tuple = xla::Tuple(builder, elems); + } else { + std::vector elem_shapes; + for (const auto& elem : elems) { + TF_ASSIGN_OR_RETURN(xla::Shape elem_shape, + elem.builder()->GetShape(elem)); + elem_shapes.push_back(elem_shape); + } + xla::Shape shape = xla::ShapeUtil::MakeTupleShape(elem_shapes); + // Copy specified sharding from retval_index_and_sharding. + std::vector sharding_elems; + for (int i = 0; i < elems.size(); i++) { + const auto& iter = retval_index_and_sharding.find(i); + TF_RET_CHECK(iter != retval_index_and_sharding.end()); + const xla::OpSharding& sub_op_sharding = iter->second; + TF_ASSIGN_OR_RETURN(xla::HloSharding sub_sharding, + xla::HloSharding::FromProto(sub_op_sharding)); + if (elem_shapes[i].IsTuple()) { + const std::vector sub_sharding_elems = + sub_sharding.tuple_elements(); + TF_RET_CHECK(sub_sharding_elems.size() == + xla::ShapeUtil::GetLeafCount(elem_shapes[i])); + for (const auto& sub_sharding_elem : sub_sharding_elems) { + sharding_elems.push_back(sub_sharding_elem); + } + } else { + sharding_elems.push_back(sub_sharding); + } + } + xla::HloSharding modified_sharding = + xla::HloSharding::Tuple(shape, sharding_elems); + xla::OpSharding op_sharding = modified_sharding.ToProto(); + // Assign proper sharding to the tuple instruction. + xla::XlaScopedShardingAssignment assign_sharding(builder, op_sharding); + tuple = xla::Tuple(builder, elems); + } if (!always_return_tuple && elems.size() == 1) { xla::GetTupleElement(tuple, 0); } @@ -793,6 +842,22 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, } } +/* static */ +void XlaCompiler::PopulateArgumentFromResource(const XlaResource& resource, + Argument* arg) { + arg->initialized = resource.initialized(); + arg->kind = XlaCompiler::Argument::kResource; + arg->resource_kind = resource.kind(); + + arg->type = resource.type(); + arg->shape = resource.shape(); + arg->max_array_size = resource.max_array_size(); + for (const auto& gradient : resource.tensor_array_gradients()) { + arg->tensor_array_gradients.insert(gradient.first); + } + arg->name = resource.name(); +} + // Builds XLA computations for each of the arguments to the computation. // `args` are the arguments to the computation. Status XlaCompiler::BuildArguments( @@ -915,6 +980,9 @@ Status XlaCompiler::BuildArguments( const XlaCompiler::Argument& arg = args[input_to_args->at(i)]; for (const auto& dim_and_arg_num : arg.dynamic_dim_to_arg_num_map) { int dynamic_size_param_index = arg_to_inputs.at(dim_and_arg_num.second); + VLOG(1) << "Setting dynamic binding " << i << " -> " + << dynamic_size_param_index; + TF_RETURN_IF_ERROR(builder->SetDynamicBinding( /*dynamic_size_param_num=*/0, {dynamic_size_param_index}, /*target_param_num=*/0, /*target_param_index=*/{i}, @@ -1170,7 +1238,7 @@ Status XlaCompiler::CompileGraph( std::unique_ptr graph, absl::Span args, absl::Span user_aliases, CompilationResult* result) { - VLOG(1) << "Executing graph symbolically to populate XlaBuilder."; + VLOG(1) << "Executing graph symbolically to populate XlaBuilder.: " << name; TF_RETURN_IF_ERROR(PropagateConstIntoFunctionalNodes( graph.get(), options_.flib_def, local_flib_def_.get())); @@ -1291,6 +1359,7 @@ Status XlaCompiler::CompileGraph( std::move(token_output), options.is_entry_computation ? 
options_.shape_representation_fn : ShapeRepresentationFn{}, + options.is_entry_computation, options.return_updated_values_for_all_resources, options.always_return_tuple, &builder, result->computation.get(), &num_computation_outputs, &num_nonconst_outputs, &result->outputs, diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 6ab8bde542d..4b4ee02aad9 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -351,6 +351,10 @@ class XlaCompiler { ~XlaCompiler(); + // Helper function to populate an XlaCompiler::Argument from XlaResource. + static void PopulateArgumentFromResource(const XlaResource& resource, + Argument* arg); + Status CompileFunction(const CompileOptions& options, const NameAttrList& fn_name_attrs, absl::Span args, diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 4413625dc3c..324c31e8bf9 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -1738,5 +1739,56 @@ TEST_F(XlaCompilerTest, WhileWithResources) { EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } +TEST_F(XlaCompilerTest, SetShardingForReturnedTuple) { + // Builds a graph that returns its only argument. + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto b = ops::_Retval(scope.WithOpName("B"), a, 0); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Sets _XlaSharding attribute for the _Retval node. + auto node_name_index = graph->BuildNodeNameIndex(); + Node* ret_node = node_name_index["B"]; + ASSERT_NE(ret_node, nullptr); + xla::Array tile_assignment({2}); + tile_assignment.FillIota(0); + xla::HloSharding sharding = xla::HloSharding::Tile(tile_assignment); + ret_node->AddAttr("_XlaSharding", sharding.ToProto().SerializeAsString()); + + // Builds a description of the arguments. + std::vector args(1); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2}); + + // Compiles the graph. + XlaCompiler compiler(DefaultOptions()); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "test", + std::move(graph), args, + /*user_aliases=*/{}, &result)); + + // Tests that we set sharding on the root TUPLE instruction. 
+ const auto& hlo_module_proto = result.computation->proto(); + ASSERT_EQ(hlo_module_proto.computations_size(), 1); + const auto& hlo_computation_proto = hlo_module_proto.computations(0); + absl::optional root_instruction_proto; + for (const auto& inst : hlo_computation_proto.instructions()) { + if (inst.id() == hlo_computation_proto.root_id()) { + root_instruction_proto = inst; + break; + } + } + ASSERT_TRUE(root_instruction_proto); + xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::S32, {2})}); + xla::HloSharding tuple_sharding = xla::HloSharding::Tuple( + tuple_shape, std::vector{sharding}); + EXPECT_EQ(root_instruction_proto->sharding().SerializeAsString(), + tuple_sharding.ToProto().SerializeAsString()); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_expression.cc b/tensorflow/compiler/tf2xla/xla_expression.cc index 3d228c92adc..0aa139ce4f0 100644 --- a/tensorflow/compiler/tf2xla/xla_expression.cc +++ b/tensorflow/compiler/tf2xla/xla_expression.cc @@ -102,7 +102,7 @@ xla::XlaOp XlaExpression::AsXlaOp(xla::XlaBuilder* builder) const { } xla::StatusOr> XlaExpression::ResolveConstant( - xla::Client* client) const { + xla::Client* client, bool dynamic_dimension_is_minus_one) const { switch (kind()) { case Kind::kConstant: return {constant_value()}; @@ -122,7 +122,8 @@ xla::StatusOr> XlaExpression::ResolveConstant( if (!is_constant) return {absl::nullopt}; TF_ASSIGN_OR_RETURN(xla::XlaComputation constant_graph, - handle().builder()->BuildConstantSubGraph(handle())); + handle().builder()->BuildConstantSubGraph( + handle(), dynamic_dimension_is_minus_one)); TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape()); diff --git a/tensorflow/compiler/tf2xla/xla_expression.h b/tensorflow/compiler/tf2xla/xla_expression.h index ac0232d8924..5d0bb35b182 100644 --- a/tensorflow/compiler/tf2xla/xla_expression.h +++ b/tensorflow/compiler/tf2xla/xla_expression.h @@ -97,7 +97,7 @@ class XlaExpression { // optional if it cannot be resolved. Returns an error if passed a resource // expression. xla::StatusOr> ResolveConstant( - xla::Client* client) const; + xla::Client* client, bool dynamic_dimension_is_minus_one = false) const; // Returns the shape of the tensor. // The shape of a resource is the shape of a resource handle (i.e., a scalar), diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index c95cd4e5475..a1941cc5fdf 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -31,7 +31,7 @@ limitations under the License. 
namespace tensorflow { XlaOpKernelContext::XlaOpKernelContext(OpKernelContext* context) - : context_(context) {} + : context_(context), dynamic_dimension_is_minus_one_(false) {} bool XlaOpKernelContext::ValidateInputsAreSameShape(OpKernel* op) { return context_->ValidateInputsAreSameShape(op); @@ -166,7 +166,7 @@ Status XlaOpKernelContext::ConstantInputReshaped( xla::Literal* constant_literal) { XlaExpression e = InputExpression(index); xla::StatusOr> constant_or_status = - e.ResolveConstant(compiler()->client()); + e.ResolveConstant(compiler()->client(), dynamic_dimension_is_minus_one_); if (!constant_or_status.ok()) { Status status = constant_or_status.status(); errors::AppendToMessage(&status, "while evaluating input ", index, " of ", diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 7794786f905..3e75cf7fa58 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -202,6 +202,17 @@ class XlaOpKernelContext { Status GetVariableTypeAndShape(int index, DataType* type, TensorShape* shape) const; + // When dynamic_dimension_is_minus_one is set, querying a dynamic dimension + // returns "-1", this is useful when the underlying ops expect explicit + // dynamic index like reshape. + void set_dynamic_dimension_is_minus_one(bool value) { + dynamic_dimension_is_minus_one_ = value; + } + + bool dynamic_dimension_is_minus_one() const { + return dynamic_dimension_is_minus_one_; + } + // Reads the current value of the resouce variable referred to by input // `index`. If `shape` is not nullptr, sets `*shape` to the shape of the // variable. Returns an error if the variable has not been initialized, or if @@ -280,6 +291,7 @@ class XlaOpKernelContext { xla::Literal* constant_literal); OpKernelContext* const context_; + bool dynamic_dimension_is_minus_one_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 24105f2f162..b4752813c8c 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -58,6 +58,16 @@ xla_proto_library( ], ) +tf_proto_library_py( + name = "xla_proto", # bzl adds a _py suffix + srcs = ["xla.proto"], + visibility = ["//visibility:public"], + deps = [ + ":xla_data_proto_py", + "//tensorflow/compiler/xla/service:hlo_proto_py", + ], +) + cc_library( name = "bit_cast", hdrs = ["bit_cast.h"], diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index b46d04dc328..38a34dfb563 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -310,7 +310,9 @@ xla_test( srcs = ["slicing_test.cc"], deps = [ ":slicing", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/client:xla_builder", diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index 0940a873fa4..de573429fdc 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -27,9 +27,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { -namespace { - -using XlaOpGenerator = XlaOp (*)(XlaBuilder*, const XlaOp&, const XlaOp&); XlaComputation CreateScalarComputation(const string& name, PrimitiveType type, XlaBuilder* builder, @@ -45,69 +42,50 @@ XlaComputation CreateScalarComputation(const string& name, PrimitiveType type, const Shape scalar = ShapeUtil::MakeShape(type, {}); auto lhs = Parameter(b.get(), 0, scalar, "lhs"); auto rhs = Parameter(b.get(), 1, scalar, "rhs"); - generator(b.get(), lhs, rhs); + generator(lhs, rhs); return b->BuildAndNoteError(); } -} // namespace - XlaComputation CreateScalarAddComputation(PrimitiveType type, XlaBuilder* builder) { return CreateScalarComputation( - "add", type, builder, - [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { - return Add(lhs, rhs); - }); + "add", type, builder, [](XlaOp lhs, XlaOp rhs) { return Add(lhs, rhs); }); } XlaComputation CreateScalarMultiplyComputation(PrimitiveType type, XlaBuilder* builder) { return CreateScalarComputation( - "mul", type, builder, - [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { - return Mul(lhs, rhs); - }); + "mul", type, builder, [](XlaOp lhs, XlaOp rhs) { return Mul(lhs, rhs); }); } XlaComputation CreateScalarGeComputation(PrimitiveType type, XlaBuilder* builder) { - return CreateScalarComputation("ge", type, builder, - [](XlaBuilder* b, const XlaOp& lhs, - const XlaOp& rhs) { return Ge(lhs, rhs); }); + return CreateScalarComputation( + "ge", type, builder, [](XlaOp lhs, XlaOp rhs) { return Ge(lhs, rhs); }); } XlaComputation CreateScalarMaxComputation(PrimitiveType type, XlaBuilder* builder) { return CreateScalarComputation( - "max", type, builder, - [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { - return Max(lhs, rhs); - }); + "max", type, builder, [](XlaOp lhs, XlaOp rhs) { return Max(lhs, rhs); }); } XlaComputation CreateScalarMinComputation(PrimitiveType type, XlaBuilder* builder) { return CreateScalarComputation( - "min", type, builder, - [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { - return Min(lhs, rhs); - }); + "min", type, builder, [](XlaOp lhs, XlaOp rhs) { return Min(lhs, rhs); }); } XlaComputation CreateScalarAndComputation(PrimitiveType type, XlaBuilder* builder) { return CreateScalarComputation( - "and", type, builder, - [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { - return And(lhs, rhs); - }); + "and", type, builder, [](XlaOp lhs, XlaOp rhs) { return And(lhs, rhs); }); } XlaComputation CreateScalarOrComputation(PrimitiveType type, XlaBuilder* builder) { - return CreateScalarComputation("or", type, builder, - [](XlaBuilder* b, const XlaOp& lhs, - const XlaOp& rhs) { return Or(lhs, rhs); }); + return CreateScalarComputation( + "or", type, builder, [](XlaOp lhs, XlaOp rhs) { return Or(lhs, rhs); }); } XlaComputation CreateScalarIdentityWithZeroComputation(PrimitiveType type, diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h index 270076a1586..350dcc5531d 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.h +++ b/tensorflow/compiler/xla/client/lib/arithmetic.h @@ -24,6 +24,13 @@ limitations under the License. namespace xla { +using XlaOpGenerator = std::function; + +// Creates a scalar computation based on a lambda and returns it. +XlaComputation CreateScalarComputation(const string& name, PrimitiveType type, + XlaBuilder* builder, + XlaOpGenerator generator); + // Creates a scalar add computation and returns it. 
XlaComputation CreateScalarAddComputation(PrimitiveType type, XlaBuilder* builder); diff --git a/tensorflow/compiler/xla/client/lib/qr.cc b/tensorflow/compiler/xla/client/lib/qr.cc index 5a7c826c389..6c4b9f9c973 100644 --- a/tensorflow/compiler/xla/client/lib/qr.cc +++ b/tensorflow/compiler/xla/client/lib/qr.cc @@ -207,14 +207,39 @@ StatusOr QRBlock(XlaOp a, PrecisionConfig::Precision precision) { auto new_x = Mul(x, predecessor_mask, /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) + Mul(beta, mask, /*broadcast_dimensions=*/batch_dim_indices); - a = DynamicUpdateSliceInMinorDims(a, new_x, {j}); + // Update a[:,j] + std::vector dim_ids(num_dims); + std::iota(dim_ids.begin(), dim_ids.end(), 0); + new_x = BroadcastInDim(new_x, ConcatVectors(batch_dims, {m, n}), + /*broadcast_dimensions=*/dim_ids); + const int64 minor_dim = batch_dims.size(); + auto iota_mn = Iota( + builder, ShapeUtil::MakeShape(S32, ConcatVectors(batch_dims, {m, n})), + minor_dim + 1); + a = Select(Eq(iota_mn, j), new_x, a); // vs[:, j] = v - vs = DynamicUpdateSliceInMinorDims( - vs, Reshape(v, ConcatVectors(batch_dims, {m, 1})), {j}); + std::vector vs_broadcast_dims(batch_dims.size() + 1); + std::iota(vs_broadcast_dims.begin(), vs_broadcast_dims.end(), 0); + auto vs_zeros = ZerosLike(vs); + auto vs_update = Select( + Eq(iota_mn, j), + Add(vs_zeros, v, /*broadcast_dimensions=*/vs_broadcast_dims), vs_zeros); + vs = vs + vs_update; + // taus[j] = tau - taus = DynamicUpdateSliceInMinorDims( - taus, Reshape(tau, ConcatVectors(batch_dims, {1})), {j}); + std::vector tau_broadcast_dims(batch_dims.size()); + std::iota(tau_broadcast_dims.begin(), tau_broadcast_dims.end(), 0); + + auto iota_n = + Iota(builder, ShapeUtil::MakeShape(S32, ConcatVectors(batch_dims, {n})), + minor_dim); + auto taus_zeros = ZerosLike(taus); + auto taus_update = Select( + Eq(iota_n, j), + Add(taus_zeros, tau, /*broadcast_dimensions=*/tau_broadcast_dims), + taus_zeros); + taus = taus + taus_update; return std::vector{a, vs, taus}; }; diff --git a/tensorflow/compiler/xla/client/lib/slicing.cc b/tensorflow/compiler/xla/client/lib/slicing.cc index 83c1045448d..b47ddb7919f 100644 --- a/tensorflow/compiler/xla/client/lib/slicing.cc +++ b/tensorflow/compiler/xla/client/lib/slicing.cc @@ -208,6 +208,43 @@ XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim, bool sparse) { }); } +XlaOp TorchScatterDense(XlaOp input, XlaOp index, XlaOp src, int64 dim, + const std::function& combiner) { + XlaBuilder* builder = input.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index)); + TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); + std::vector index_broacast_dims; + std::vector sizes; + for (int64 i = 0; i < index_shape.rank(); ++i) { + if (i < dim) { + index_broacast_dims.push_back(i); + } else { + if (i == dim) { + sizes.push_back(input_shape.dimensions(i)); + } + index_broacast_dims.push_back(i + 1); + } + sizes.push_back(index_shape.dimensions(i)); + } + auto mask = + Eq(BroadcastInDim(index, sizes, index_broacast_dims), + Iota(builder, + ShapeUtil::MakeShape(index_shape.element_type(), sizes), dim)); + auto masked_src = + Select(mask, BroadcastInDim(src, sizes, index_broacast_dims), + Zeros(builder, + ShapeUtil::MakeShape(input_shape.element_type(), sizes))); + + return combiner( + input, + Reduce(masked_src, Zero(builder, input_shape.element_type()), + CreateScalarComputation("reducer", input_shape.element_type(), + builder, combiner), + {dim + 1})); + }); +} + XlaOp 
TorchIndexSelect(XlaOp input, XlaOp index, int64 dim, int64 batch_dims) { XlaBuilder* builder = input.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { @@ -238,10 +275,8 @@ XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64 dim, int64 batch_dims) { } for (int64 i = 0; i < input_shape.rank(); ++i) { if (i < batch_dims || i == dim) { - if (slice_sizes[i] != 0) { - slice_sizes[i] = 1; - gather_dnums.add_collapsed_slice_dims(i); - } + slice_sizes[i] = std::min(slice_sizes[i], 1); + gather_dnums.add_collapsed_slice_dims(i); gather_dnums.add_start_index_map(i); } else { if (i < dim) { diff --git a/tensorflow/compiler/xla/client/lib/slicing.h b/tensorflow/compiler/xla/client/lib/slicing.h index 9a59a048b9f..cf83d63cec2 100644 --- a/tensorflow/compiler/xla/client/lib/slicing.h +++ b/tensorflow/compiler/xla/client/lib/slicing.h @@ -57,6 +57,13 @@ XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update, // `index`. XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim, bool sparse = true); +// idx = index[i][j][k] +// output[idx][j][k] = combiner(input[idx][j][k], src[i][j][k]) # if dim == 0 +// output[i][idx][k] = combiner(input[i][idx][k], src[i][j][k]) # if dim == 1 +// output[i][j][idx] = combiner(input[i][j][idx], src[i][j][k]) # if dim == 2 +XlaOp TorchScatterDense(XlaOp input, XlaOp index, XlaOp src, int64 dim, + const std::function& combiner); + // Returns a new tensor which indexes the input tensor along dimension dim using // the entries in index. // diff --git a/tensorflow/compiler/xla/client/lib/slicing_test.cc b/tensorflow/compiler/xla/client/lib/slicing_test.cc index 107cbae0a73..8e2e713c45c 100644 --- a/tensorflow/compiler/xla/client/lib/slicing_test.cc +++ b/tensorflow/compiler/xla/client/lib/slicing_test.cc @@ -16,7 +16,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -130,6 +132,24 @@ XLA_TEST_F(SlicingTest, TorchGatherDense) { {input_data.get(), index_data.get()}); } +XLA_TEST_F(SlicingTest, TorchScatterDense) { + xla::XlaBuilder builder(TestName()); + + xla::XlaOp src, index, input; + auto input_data = CreateR2Parameter({{0, 0, 0}, {0, 0, 0}}, 0, "input", + &builder, &input); + auto index_data = + CreateR2Parameter({{1, 0}, {1, 2}}, 1, "index", &builder, &index); + auto src_data = + CreateR2Parameter({{1, 2}, {3, 4}}, 2, "src", &builder, &src); + TorchScatterDense(input, index, src, 1, + [](XlaOp l, XlaOp r) { return l + r; }); + + ComputeAndCompareR2( + &builder, {{2, 1, 0}, {0, 3, 4}}, + {input_data.get(), index_data.get(), src_data.get()}); +} + XLA_TEST_F(SlicingTest, TorchIndexSelectOn0) { xla::XlaBuilder builder(TestName()); @@ -180,6 +200,35 @@ XLA_TEST_F(SlicingTest, EmptyIndexSelect) { {input_data.get(), index_data.get()}); } +XLA_TEST_F(SlicingTest, DoubleEmptyIndexSelect) { + xla::XlaBuilder builder(TestName()); + + xla::XlaOp input, index; + Literal l(ShapeUtil::MakeShape(F32, {0, 1, 2, 0})); + Literal i(ShapeUtil::MakeShape(S32, {0})); + auto input_data = + CreateParameterAndTransferLiteral(0, l, "input", &builder, &input); + auto index_data = + CreateParameterAndTransferLiteral(1, i, "index", &builder, &index); + TorchIndexSelect(input, index, 0); + ComputeAndCompareLiteral(&builder, l, {input_data.get(), index_data.get()}); +} + +XLA_TEST_F(SlicingTest, EmptyIndexSelectNonZero) { + xla::XlaBuilder builder(TestName()); + + xla::XlaOp input, index; + Literal l(ShapeUtil::MakeShape(F32, {0, 2})); + auto input_data = + CreateParameterAndTransferLiteral(0, l, "input", &builder, &input); + auto index_data = + CreateR1Parameter({0, 0, 0}, 1, "index", &builder, &index); + TorchIndexSelect(input, index, 0); + ComputeAndCompareR2(&builder, + {{0.0f, 0.0f}, {0.0f, 0.0f}, {0.0f, 0.0f}}, + {input_data.get(), index_data.get()}); +} + XLA_TEST_F(SlicingTest, BatchTorchIndexSelectOn0) { xla::XlaBuilder builder(TestName()); diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index b697fb031fd..f5e66c6d586 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -37,6 +37,11 @@ namespace xla { class LocalExecutable { public: + // Low-level constructor; LocalClient::Compile() is the usual way to create + // executables. + LocalExecutable(std::unique_ptr executable, Backend* backend, + ExecutableBuildOptions build_options); + // Run the compiled computation with the given arguments and options and // return the result. StatusOr Run( @@ -56,13 +61,6 @@ class LocalExecutable { Executable* executable() const { return executable_.get(); } private: - // Only a local client can construct these objects. - friend class LocalClient; - - // Constructor invoked by LocalClient. - LocalExecutable(std::unique_ptr executable, Backend* backend, - ExecutableBuildOptions build_options); - // Validates that the given arguments and options satisfy various constraints // of the computation. 
// diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 7fea245f69a..7a9b9856271 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -213,16 +213,10 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle, // TODO(b/32495713): We aren't checking the called computations. break; case HloOpcode::kGetDimensionSize: { - int64 dimension_number = instr.dimensions(0); - const HloInstructionProto& operand = - *(LookUpInstructionByHandle(instr.operand_ids(0)).ValueOrDie()); - Shape operand_shape(operand.shape()); - if (operand_shape.is_dynamic_dimension(dimension_number)) { - *is_constant = false; - } + // DimensionSize is always considered constant in XLA -- If a dynamic + // dimension is presented, uint_max is returned. break; } - // Non functional ops. case HloOpcode::kRng: case HloOpcode::kAllReduce: @@ -268,8 +262,8 @@ Status XlaBuilder::SetDynamicBinding(int64 dynamic_size_param_num, for (int64 index : target_param_index) { param_shape_ptr = param_shape_ptr->mutable_tuple_shapes(index); } - // TODO(b/121223198): Set `is_dynamic` to the parameter shape when XLA - // backend can handle dynamic dimensions. + param_shape_ptr->set_dynamic_dimension(target_dim_num, + /*is_dynamic=*/true); *instr.mutable_shape() = param_shape.ToProto(); } } @@ -435,6 +429,7 @@ StatusOr XlaBuilder::InDimBroadcast( for (int64 dim : broadcast_dimensions) { instr.add_dimensions(dim); } + return AddInstruction(std::move(instr), HloOpcode::kBroadcast, {operand}); } @@ -468,11 +463,21 @@ StatusOr XlaBuilder::AddBroadcastSequence(const Shape& output_shape, << operand_shape << "; output_shape: " << output_shape; } } + + Shape reshaped_shape = + ShapeUtil::MakeShape(operand_shape.element_type(), reshaped_dimensions); + + std::vector> unmodified_dims = + ShapeUtil::DimensionsUnmodifiedByReshape(operand_shape, reshaped_shape); + + for (auto& unmodified : unmodified_dims) { + if (operand_shape.is_dynamic_dimension(unmodified.first)) { + reshaped_shape.set_dynamic_dimension(unmodified.second, true); + } + } + // Eliminate the size one dimensions. - TF_ASSIGN_OR_RETURN(XlaOp reshaped_operand, - Reshape(ShapeUtil::MakeShape(operand_shape.element_type(), - reshaped_dimensions), - operand)); + TF_ASSIGN_OR_RETURN(XlaOp reshaped_operand, Reshape(reshaped_shape, operand)); // Broadcast 'reshape' up to the larger size. 
return InDimBroadcast(broadcast_shape, reshaped_operand, broadcast_dimensions); @@ -2428,7 +2433,7 @@ StatusOr XlaBuilder::IsConstant(const XlaOp& operand) const { } StatusOr XlaBuilder::BuildConstantSubGraph( - const XlaOp& root_op) { + XlaOp root_op, bool dynamic_dimension_is_minus_one) { TF_ASSIGN_OR_RETURN(bool is_constant, IsConstant(root_op)); if (!is_constant) { auto op_status = LookUpInstruction(root_op); @@ -2483,9 +2488,12 @@ StatusOr XlaBuilder::BuildConstantSubGraph( TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto, LookUpInstructionByHandle(operand_handle)); - TF_RET_CHECK(!operand_proto->shape().is_dynamic_dimension(dimension)); - auto constant_dimension_size = - static_cast(operand_proto->shape().dimensions(dimension)); + int32 constant_dimension_size = -1; + if (!(operand_proto->shape().is_dynamic_dimension(dimension) && + dynamic_dimension_is_minus_one)) { + constant_dimension_size = + static_cast(operand_proto->shape().dimensions(dimension)); + } Literal literal = LiteralUtil::CreateR0(constant_dimension_size); diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 187cd261833..693ea3c493e 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -258,7 +258,8 @@ class XlaBuilder { // compile-time constant (see `IsConstant`), returns an error. // // This will copy the needed ops/computations to the subgraph. - StatusOr BuildConstantSubGraph(const XlaOp& root_op); + StatusOr BuildConstantSubGraph( + XlaOp root_op, bool dynamic_dimension_is_uint_max = false); // Returns the first error that was encountered while building the // computation. When an error is encountered, by default we return a vacuous diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc index 701729b94f3..32a34c801f0 100644 --- a/tensorflow/compiler/xla/client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -917,10 +917,7 @@ TEST_F(XlaBuilderTest, DynamicSelectNotCompatible) { auto gte1 = GetTupleElement(p0, 1); // f32[4,5,<=6] Select(pred, gte0, gte1); Status status = BuildHloModule(&b).status(); - ASSERT_IS_NOT_OK(status); - EXPECT_THAT(status.error_message(), - ::testing::HasSubstr("Operands to select must be the same shape; " - "got f32[4,<=5,6] and f32[4,5,<=6]")); + ASSERT_IS_OK(status); } TEST_F(XlaBuilderTest, DynamicTranspose) { diff --git a/tensorflow/compiler/xla/cpu_function_runtime.h b/tensorflow/compiler/xla/cpu_function_runtime.h index 281ca5b2203..0c3355cbbfb 100644 --- a/tensorflow/compiler/xla/cpu_function_runtime.h +++ b/tensorflow/compiler/xla/cpu_function_runtime.h @@ -138,6 +138,17 @@ class BufferInfo { // Align to 64-bytes, to mimic tensorflow::Allocator::kAllocatorAlignment. constexpr size_t kAlign = 64; +// When declaring variables that will be passed to an XLA instance as input via +// set_arg_data(), be it a regular input or a resource variable in the graph, +// the C++ variables must be aligned. +// +// Example usage: +// XLA_ALIGN std::array arg_x; +// XLA_ALIGN float arg_y; +// xla_instance.set_arg_data(0, arg_x.date()); +// xla_instance.set_arg_data(0, &arg_y); +#define XLA_ALIGN alignas(xla::cpu_function_runtime::kAlign) + // AlignedBufferBytes returns the sum of the size of each buffer in // `buffer_infos`, skipping constants, on-stack buffers and, if // allocate_entry_params is false, entry parameters. 
There are `n` entries in diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index 13173e0dbc8..ec0059d37d9 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -436,10 +436,10 @@ static void AllocateFlags() { "behavior to help run tests on the host that run models in parallel " "across multiple devices."), tensorflow::Flag( - "xla_gpu_disable_ptxas_optimizations", + "xla_gpu_disable_gpuasm_optimizations", bool_setter_for( - &DebugOptions::set_xla_gpu_disable_ptxas_optimizations), - flag_values->xla_gpu_disable_ptxas_optimizations(), + &DebugOptions::set_xla_gpu_disable_gpuasm_optimizations), + flag_values->xla_gpu_disable_gpuasm_optimizations(), "In XLA:GPU run ptxas in -O0 (default is -O3)."), tensorflow::Flag( "xla_fuel", setter_for_xla_fuel, /*default_value_for_display=*/"", diff --git a/tensorflow/compiler/xla/layout.h b/tensorflow/compiler/xla/layout.h index 300abff395d..1234d01755b 100644 --- a/tensorflow/compiler/xla/layout.h +++ b/tensorflow/compiler/xla/layout.h @@ -214,6 +214,7 @@ class Layout { element_size_in_bits_ = value; return *this; } + static constexpr int64 kDefaultMemorySpace = 0; int64 memory_space() const { return memory_space_; } Layout& set_memory_space(int64 value) { memory_space_ = value; diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 03b47ba7089..c949cc6a5ba 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -944,6 +944,8 @@ absl::optional LiteralBase::GetAsComplex128( return {Get(multi_index)}; case C128: return {Get(multi_index)}; + case S8: + return {Get(multi_index)}; default: return absl::nullopt; } diff --git a/tensorflow/compiler/xla/python/local_client.cc b/tensorflow/compiler/xla/python/local_client.cc index 1d9bd1f0695..eb8be012176 100644 --- a/tensorflow/compiler/xla/python/local_client.cc +++ b/tensorflow/compiler/xla/python/local_client.cc @@ -242,10 +242,35 @@ PyLocalClient::PyLocalClient( allocator_ = client_->backend().memory_allocator(); } + local_devices_.resize(device_states_.size()); for (const std::shared_ptr& device : devices_) { CHECK(id_to_device_.insert({device->id(), device}).second) << "Duplicate device id: " << device->id(); + + if (device->local_device_ordinal() != -1) { + int idx = device->local_device_ordinal(); + CHECK(local_devices_[idx] == nullptr) << idx; + CHECK_LT(idx, local_devices_.size()); + local_devices_[idx] = device; + } } + for (int idx = 0; idx < local_devices_.size(); ++idx) { + CHECK(local_devices_[idx] != nullptr) << idx; + } +} + +StatusOr PyLocalClient::SerializeExecutable( + const PyLocalExecutable& executable) const { + return Unimplemented("Cannot serialize executables on platform '%s'", + platform_name()); +} + +StatusOr> +PyLocalClient::DeserializeExecutable( + const std::string& serialized, + std::shared_ptr this_shared) const { + return Unimplemented("Cannot deserialize executables on platform '%s'", + platform_name()); } Status PyLocalClient::TransferToInfeed(const LiteralSlice& literal, diff --git a/tensorflow/compiler/xla/python/local_client.h b/tensorflow/compiler/xla/python/local_client.h index 37b3c56b7d2..1cc8175b402 100644 --- a/tensorflow/compiler/xla/python/local_client.h +++ b/tensorflow/compiler/xla/python/local_client.h @@ -38,6 +38,8 @@ limitations under the License. 
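A small self-contained sketch of the alignment contract behind the XLA_ALIGN macro added to cpu_function_runtime.h above; the local kAlign constant and macro redefinition are for illustration only, and the set_arg_data() calls from the comment are assumed to live on a generated XLA instance:

#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

constexpr std::size_t kAlign = 64;  // mirrors xla::cpu_function_runtime::kAlign
#define XLA_ALIGN alignas(kAlign)

int main() {
  XLA_ALIGN std::array<float, 128> arg_x;
  XLA_ALIGN float arg_y;
  // Buffers later handed to set_arg_data(0, arg_x.data()) and
  // set_arg_data(1, &arg_y) must satisfy this 64-byte alignment.
  assert(reinterpret_cast<std::uintptr_t>(arg_x.data()) % kAlign == 0);
  assert(reinterpret_cast<std::uintptr_t>(&arg_y) % kAlign == 0);
  return 0;
}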
namespace xla { +class PyLocalExecutable; + class Device { public: explicit Device(int id, int local_device_ordinal, int host_id = 0) @@ -127,14 +129,17 @@ class PyLocalClient { int num_replicas) const; int device_count() const { return devices_.size(); } + int local_device_count() const { return local_devices_.size(); } const std::vector>& devices() { return devices_; } + const std::vector>& local_devices() { + return local_devices_; + } const std::map>& id_to_device() const { return id_to_device_; } int host_id() const { return host_id_; } const std::string& platform_name() const { return platform_name_; } - int local_device_count() const { return device_states_.size(); } DeviceState& device_state(int device_ordinal) const { return *device_states_.at(device_ordinal); } @@ -149,14 +154,29 @@ class PyLocalClient { return &h2d_transfer_pool_; } + // Returns a platform-specific serialization of `executable`. This is meant + // for transferring executables and not for storage, and the serialization is + // not guaranteed to be stable over time. + virtual StatusOr SerializeExecutable( + const PyLocalExecutable& executable) const; + + // Deserializes a serialized executable as produced by + // SerializeExecutable(). `serialized` must have been produced by client of + // the same platform. `this_shared` should point to this PyLocalClient. + virtual StatusOr> DeserializeExecutable( + const std::string& serialized, + std::shared_ptr this_shared) const; + protected: std::string platform_name_; LocalClient* client_; // Includes all devices, including non-local devices on multi-host platforms. std::vector> devices_; - // Maps Device::id() to the corresponding Device. + // Maps Device::id() to the corresponding Device. Includes all devices. std::map> id_to_device_; + // Local devices indexed by local device ordinal. + std::vector> local_devices_; int host_id_; // Device states local to this host. Indexed by local device ordinal. @@ -203,6 +223,7 @@ class PyLocalBuffer { const Shape& on_host_shape() const { return on_host_shape_; } int device_ordinal() const { return device_ordinal_; } const std::string& platform_name() const { return client_->platform_name(); } + std::shared_ptr client() const { return client_; } // Returns the buffer's value as a tuple DAG of Python arrays. 
If the value // has previously been prefetched to the host, then returns the prefetched @@ -299,6 +320,8 @@ class PyLocalExecutable { void Delete() { executable_ = nullptr; } + LocalExecutable* executable() const { return executable_.get(); } + private: StatusOr> ExecuteHelper( absl::Span argument_handles, int replica, diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index 078fee8f652..08bfe78c47b 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -344,6 +344,7 @@ PYBIND11_MODULE(xla_extension, m) { .def("device_count", &PyLocalClient::device_count) .def("local_device_count", &PyLocalClient::local_device_count) .def("devices", &PyLocalClient::devices) + .def("local_devices", &PyLocalClient::local_devices) .def("host_id", &PyLocalClient::host_id) .def("TransferToInfeed", [](PyLocalClient* client, const LiteralSlice& literal, @@ -364,7 +365,15 @@ PYBIND11_MODULE(xla_extension, m) { literal_shared = std::make_shared(std::move(literal)); } return LiteralToPython(std::move(literal_shared)); - }); + }) + .def("SerializeExecutable", + [](PyLocalClient* client, + PyLocalExecutable* executable) -> StatusOr { + TF_ASSIGN_OR_RETURN(std::string serialized, + client->SerializeExecutable(*executable)); + return py::bytes(serialized); + }) + .def("DeserializeExecutable", &PyLocalClient::DeserializeExecutable); py::class_(m, "PyLocalBuffer") .def_static( @@ -417,7 +426,12 @@ PYBIND11_MODULE(xla_extension, m) { return LiteralToPython(std::move(literal)); }) .def("shape", &PyLocalBuffer::on_host_shape) - .def("device", &PyLocalBuffer::device_ordinal) + .def("device", + [](PyLocalBuffer* buffer) -> std::shared_ptr { + return buffer->client()->local_devices()[buffer->device_ordinal()]; + }) + // TODO(skye): get rid of `device_ordinal` once everything uses `device` + .def("device_ordinal", &PyLocalBuffer::device_ordinal) .def("platform", &PyLocalBuffer::platform_name) .def("is_deleted", [](const PyLocalBuffer& buffer) { diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 7abd2f7429d..4dcf3a26301 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -137,6 +137,12 @@ class LocalBackend(Backend): options, self.client, compile_options.device_assignment) + def serialize(self, executable): + return self.client.SerializeExecutable(executable) + + def deserialize(self, serialized_executable): + return self.client.DeserializeExecutable(serialized_executable, self.client) + xla_platform_names = { 'cpu': 'Host', diff --git a/tensorflow/compiler/xla/python/xrt.cc b/tensorflow/compiler/xla/python/xrt.cc index 147aafc356a..b2d7bbb829a 100644 --- a/tensorflow/compiler/xla/python/xrt.cc +++ b/tensorflow/compiler/xla/python/xrt.cc @@ -148,7 +148,10 @@ void AddXrtSubmodule(py::module* module) { }) .def("delete", &XrtBuffer::Delete) .def("destructure", &XrtBuffer::DestructureTuple) + // TODO(skyewm): remove after we update jax to call device_ordinal instead + // of device. 
.def("device", &XrtBuffer::xrt_device_ordinal) + .def("device_ordinal", &XrtBuffer::xrt_device_ordinal) .def("shape", &XrtBuffer::shape) .def("is_deleted", [](const XrtBuffer& buffer) { return !buffer.handle().valid(); }) diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 25f88004f98..46d014f48d8 100755 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -945,9 +945,10 @@ cc_library( deps = [ ":service", "//tensorflow/compiler/xla/service/gpu:gpu_transfer_manager", - "//tensorflow/compiler/xla/service/mlir_gpu:mlir_compiler", "//tensorflow/core:stream_executor_no_cuda", - ], + ] + if_cuda_is_configured([ + "//tensorflow/compiler/xla/service/mlir_gpu:mlir_compiler", + ]), ) cc_library( @@ -2289,7 +2290,7 @@ cc_library( ], ) -tf_cc_test( +xla_test( name = "dynamic_padder_test", srcs = ["dynamic_padder_test.cc"], deps = [ @@ -2306,7 +2307,9 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", ], ) @@ -2589,9 +2592,13 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "//tensorflow/core/platform:logging", + "//tensorflow/core/platform:macros", + "//tensorflow/core/platform:types", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", ], ) @@ -2873,6 +2880,7 @@ cc_library( ":call_graph", ":computation_layout", ":hlo", + ":hlo_alias_analysis", ":hlo_casting_utils", ":hlo_dce", ":hlo_graph_dumper", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 0296805a24b..1cfd1196508 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -1878,14 +1878,17 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { Status AlgebraicSimplifierVisitor::HandleGather(HloInstruction* gather) { const Shape& operand_shape = gather->operand(0)->shape(); + if (ShapeUtil::IsZeroElementArray(operand_shape)) { + return ReplaceInstruction(gather, MakeScalarLike(gather, 0)); + } // If the operand of a gather is very small, it is easier to fuse a // sequence of selects. + const Shape& index_shape = gather->operand(1)->shape(); if (operand_shape.rank() == 1 && operand_shape.dimensions(0) <= options_.very_small_gather_size() && gather->gather_dimension_numbers().index_vector_dim() == - gather->operand(1)->shape().rank() && + index_shape.rank() && gather->gather_dimension_numbers().collapsed_slice_dims_size() == 1) { - const Shape& index_shape = gather->operand(1)->shape(); const int64 operand_elements = operand_shape.dimensions(0); auto get_value = [&](int64 i) { auto slice = computation_->AddInstruction(HloInstruction::CreateSlice( @@ -2165,13 +2168,34 @@ Status AlgebraicSimplifierVisitor::HandleLog(HloInstruction* log) { return Status::OK(); } - // ln(pow(A,B)) => B*ln(A) + // ln(pow(A,B)) => B*ln(abs(A)) + // or B*ln(A) if A is complex. if (Match(log, m::Log(m::Power(m::Op(&a), m::Op(&b))))) { + auto abs_a = ShapeUtil::ElementIsComplex(a->shape()) + ? 
a + : computation_->AddInstruction(HloInstruction::CreateUnary( + log->shape(), HloOpcode::kAbs, a)); + auto new_log = computation_->AddInstruction( + HloInstruction::CreateUnary(log->shape(), HloOpcode::kLog, abs_a)); + return ReplaceWithNewInstruction( + log, HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply, + new_log, b)); + } + + if (Match(log, m::Log(m::Sqrt(m::Op(&a))))) { auto new_log = computation_->AddInstruction( HloInstruction::CreateUnary(log->shape(), HloOpcode::kLog, a)); return ReplaceWithNewInstruction( log, HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply, - new_log, b)); + new_log, MakeScalarLike(log, 0.5))); + } + + if (Match(log, m::Log(m::Rsqrt(m::Op(&a))))) { + auto new_log = computation_->AddInstruction( + HloInstruction::CreateUnary(log->shape(), HloOpcode::kLog, a)); + return ReplaceWithNewInstruction( + log, HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply, + new_log, MakeScalarLike(log, -0.5))); } return Status::OK(); @@ -2574,6 +2598,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { power, HloInstruction::CreateUnary(power->shape(), HloOpcode::kExp, a_times_b)); } + VLOG(10) << "trying transform [pow(A, 2) => A*A]: " << power->ToString(); if (IsAll(rhs, 2)) { return ReplaceWithNewInstruction( @@ -3158,6 +3183,24 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { return Status::OK(); } + // Try to simplify concat -> slice to an operand of concat. + if (slice->operand(0)->opcode() == HloOpcode::kConcatenate && + IsUnstridedSlice(slice)) { + auto concat = slice->operand(0); + int64 concat_dim = concat->concatenate_dimension(); + int64 piece_start = 0; + for (auto piece : concat->operands()) { + if (!SameShape(piece, slice)) { + piece_start += piece->shape().dimensions(concat_dim); + continue; + } + if (slice->slice_starts(concat_dim) == piece_start) { + return ReplaceInstruction(slice, piece); + } + piece_start += piece->shape().dimensions(concat_dim); + } + } + // Do not try to reorder slices and reshapes after layout assignment as it may // be invalid. 
if (!options_.is_layout_sensitive()) { diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index f918634e075..3e4c906a4a5 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -1273,9 +1273,57 @@ TEST_F(AlgebraicSimplifierTest, LnPow) { AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Multiply(m::Log(m::Abs(m::Parameter(0))), + m::Parameter(1)))); +} + +TEST_F(AlgebraicSimplifierTest, LnSqrt) { + auto m = CreateNewVerifiedModule(); + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* sqrt = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kSqrt, param0)); + builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, sqrt)); + + auto computation = m->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Log(m::Sqrt(m::Parameter(0))))); + + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + EXPECT_THAT( computation->root_instruction(), - GmockMatch(m::Multiply(m::Log(m::Parameter(0)), m::Parameter(1)))); + GmockMatch(m::Multiply(m::Log(m::Parameter(0)), m::ConstantScalar(0.5)))); +} + +TEST_F(AlgebraicSimplifierTest, LnRsqrt) { + auto m = CreateNewVerifiedModule(); + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* rsqrt = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kRsqrt, param0)); + builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, rsqrt)); + + auto computation = m->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Log(m::Rsqrt(m::Parameter(0))))); + + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Multiply(m::Log(m::Parameter(0)), + m::ConstantScalar(-0.5)))); } // Test that ln(exp(A)) is simplified to A @@ -5639,5 +5687,21 @@ TEST_F(AlgebraicSimplifierTest, MaxOfClamp) { GmockMatch(m::Clamp(m::Parameter(0), m::Parameter(1), m::Parameter(2)))); } +TEST_F(AlgebraicSimplifierTest, SliceOfConcat) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[100,50] parameter(0) + p1 = f32[50,50] parameter(1) + c0 = f32[150,50] concatenate(p0, p1), dimensions={0} + ROOT s0 = f32[50,50] slice(c0), slice={[100:150], [0:50]} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Parameter(1))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index 131b50efc9c..de9c4f16efe 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -105,8 +105,8 @@ 
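The concat -> slice folding above only fires when an unstrided slice covers exactly one concatenated operand, as exercised by the new SliceOfConcat test. A minimal standalone sketch of that bookkeeping (FindCoveredPiece is an illustrative helper, not an XLA function):

#include <cassert>
#include <cstdint>
#include <vector>

// Given operand sizes along the concatenation dimension and an unstrided
// slice [start, start + len), return the operand the slice covers exactly,
// or -1 if the slice straddles operands.
int64_t FindCoveredPiece(const std::vector<int64_t>& piece_sizes,
                         int64_t slice_start, int64_t slice_len) {
  int64_t piece_start = 0;
  for (int64_t i = 0; i < static_cast<int64_t>(piece_sizes.size()); ++i) {
    if (piece_sizes[i] == slice_len && slice_start == piece_start) return i;
    piece_start += piece_sizes[i];
  }
  return -1;
}

int main() {
  // Mirrors the test: concatenate f32[100,50] and f32[50,50] along dim 0,
  // then slice [100:150] -> the second operand (parameter 1).
  assert(FindCoveredPiece({100, 50}, /*slice_start=*/100, /*slice_len=*/50) == 1);
  return 0;
}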
class BatchNormExpanderVisitor : public DfsHloRewriteVisitor { HloInstruction* operand, int64 feature_index, const std::function)>& add_instruction) { - auto elements_per_feature_u32 = add_instruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + auto elements_per_feature_s32 = add_instruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); for (int64 i = 0; i < operand->shape().rank(); ++i) { if (i == feature_index) { @@ -114,15 +114,15 @@ class BatchNormExpanderVisitor : public DfsHloRewriteVisitor { } auto dynamic_dimension_size = add_instruction(HloInstruction::CreateGetDimensionSize( - ShapeUtil::MakeShape(U32, {}), operand, i)); - elements_per_feature_u32 = add_instruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(U32, {}), HloOpcode::kMultiply, - dynamic_dimension_size, elements_per_feature_u32)); + ShapeUtil::MakeShape(S32, {}), operand, i)); + elements_per_feature_s32 = add_instruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(S32, {}), HloOpcode::kMultiply, + dynamic_dimension_size, elements_per_feature_s32)); } return HloInstruction::CreateConvert( ShapeUtil::MakeShape(operand->shape().element_type(), {}), - elements_per_feature_u32); + elements_per_feature_s32); } // Current HloComputation instance the BatchNormExpander is diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index d72a91f45df..6bf745df968 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -234,8 +234,9 @@ BufferAllocation::Slice BufferAllocation::GetSlice( void BufferAllocation::AddAssignment(const HloValue& buffer, int64 offset, int64 size) { - VLOG(4) << "Adding the following buffer to allocation #" << index() << " [" - << offset << ", " << size << "]: " << buffer; + VLOG(4) << "Adding the following buffer to allocation #" << index() + << absl::StrFormat(" (size=%d, offset=%d) %s", size, offset, + buffer.ToShortString()); CHECK(!assigned_buffers_.contains(&buffer)) << "LogicalBuffer " << buffer << " already assigned to allocation " << index_; @@ -291,6 +292,10 @@ BufferAllocationProto BufferAllocation::ToProto() const { return proto; } +static bool CompareHloValuesById(const HloValue* a, const HloValue* b) { + return a->id() < b->id(); +} + string BufferAllocation::ToString() const { string output; StrAppendFormat(&output, "allocation %d: %p, size %d", index_, this, size()); @@ -319,15 +324,14 @@ string BufferAllocation::ToString() const { for (const auto& buffer_offset_size : assigned_buffers_) { sorted_buffers.push_back(buffer_offset_size.first); } - absl::c_sort(sorted_buffers, [](const HloValue* a, const HloValue* b) { - return a->id() < b->id(); - }); + absl::c_sort(sorted_buffers, &CompareHloValuesById); for (const HloValue* buffer : sorted_buffers) { const OffsetSize& offset_size = FindOrDie(assigned_buffers_, buffer); - StrAppend(&output, absl::StrFormat( - " %s [%d,%d]: %s\n", buffer->ToString(), - offset_size.offset, offset_size.size, - ShapeUtil::HumanStringWithLayout(buffer->shape()))); + StrAppend(&output, + absl::StrFormat( + " value: %s (size=%d,offset=%d): %s\n", + buffer->ToShortString(), offset_size.size, offset_size.offset, + ShapeUtil::HumanStringWithLayout(buffer->shape()))); } return output; } @@ -715,8 +719,17 @@ string BufferAssignment::Stats::ToString() const { string BufferAssignment::ToString() const { string output; absl::StrAppend(&output, "BufferAssignment:\n"); + std::vector 
used_values; for (auto& allocation : allocations_) { absl::StrAppend(&output, allocation.ToString()); + for (const auto& p : allocation.assigned_buffers()) { + used_values.push_back(p.first); + } + } + absl::StrAppend(&output, "\nUsed values:\n"); + absl::c_sort(used_values, &CompareHloValuesById); + for (const HloValue* value : used_values) { + absl::StrAppend(&output, value->ToString()); } return output; } @@ -808,12 +821,18 @@ bool BufferAssigner::LiveRangeInterferes(const HloValue* buffer1, auto operand_value = buffer1; auto user_value = buffer2; if (!can_share_as_operand(user_value, operand_value)) { + VLOG(4) << "End of live range of " << buffer1->ToShortString() + << " is equal to the start of live range of " + << buffer2->ToShortString() << ", buffer cannot be shared."; return true; } } else if (live_range_2.end == live_range_1.start) { auto operand_value = buffer2; auto user_value = buffer1; if (!can_share_as_operand(user_value, operand_value)) { + VLOG(4) << "End of live range of " << buffer2->ToShortString() + << " is equal to the start of live range of " + << buffer1->ToShortString() << ", buffer cannot be shared."; return true; } } else { @@ -898,6 +917,9 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, for (const HloValue* new_value : hlo_buffer.values()) { if (assignment->hlo_live_range().total_order_scheduled()) { if (LiveRangeInterferes(new_value, &assigned_buffer, assignment)) { + VLOG(4) << "Can't assign: assignee " << assigned_buffer + << " live range interferes with " + << new_value->ToShortString(); return false; } } else if (assignment->hlo_ordering().MayInterfere( @@ -1235,20 +1257,24 @@ Status BufferAssigner::AssignBuffersForComputations( return a_size > b_size; // use ">" for decreasing size. } + // Values which live out the computation lifetime will be assigned + // first, as they can not be given to the heap simulator. const bool a_live_out = alias_analysis.BufferLivesOut(*a); const bool b_live_out = alias_analysis.BufferLivesOut(*b); if (a_live_out != b_live_out) { return a_live_out; } + + // Process values in the reverse postorder, since we have to start + // with the last value. 
auto compare = [&post_order_position](const HloValue* value1, const HloValue* value2) { - return post_order_position.at(value1->instruction()) < + return post_order_position.at(value1->instruction()) > post_order_position.at(value2->instruction()); }; const HloValue* a_min = *absl::c_min_element(a->values(), compare); const HloValue* b_min = *absl::c_min_element(b->values(), compare); - return post_order_position.at(a_min->instruction()) < - post_order_position.at(b_min->instruction()); + return compare(a_min, b_min); }); std::vector allocation_indices; diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 1c985485d43..3ec5c1e3d49 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -2432,6 +2432,50 @@ ENTRY Main { GetAllocation(*buffers, param0, {1, 1})); } +TEST_F(BufferAssignmentTest, ProcessingOrderTest) { + const char* hlo_text = R"( +HloModule nested_convolution + +ENTRY %nested_convolution (param: f32[200,32,32,1]) -> f32[200,32,32,1] { + %param = f32[200,32,32,1]{3,2,1,0} parameter(0) + %bitcast = f32[200,32,32,1]{2,1,3,0} bitcast(f32[200,32,32,1]{3,2,1,0} %param) + %one = f32[] constant(1) + %conv_window = f32[3,3,1,1]{1,0,2,3} broadcast(f32[] %one), dimensions={} + %conv0 = (f32[200,32,32,1]{2,1,3,0}, u8[6152]{0}) custom-call(f32[200,32,32,1]{2,1,3,0} %bitcast, f32[3,3,1,1]{1,0,2,3} %conv_window), + window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", backend_config="{algorithm:1,tensor_ops_enabled:true,conv_result_scale:1}" + %get-tuple-element.6 = f32[200,32,32,1]{2,1,3,0} get-tuple-element((f32[200,32,32,1]{2,1,3,0}, u8[6152]{0}) %conv0), index=0 + %conv1 = (f32[200,32,32,1]{2,1,3,0}, u8[6152]{0}) custom-call(f32[200,32,32,1]{2,1,3,0} %get-tuple-element.6, f32[3,3,1,1]{1,0,2,3} %conv_window), + window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", backend_config="{algorithm:1,tensor_ops_enabled:true,conv_result_scale:1}" + %get-tuple-element.7 = f32[200,32,32,1]{2,1,3,0} get-tuple-element((f32[200,32,32,1]{2,1,3,0}, u8[6152]{0}) %conv1), index=0 + %conv2 = (f32[200,32,32,1]{2,1,3,0}, u8[6152]{0}) custom-call(f32[200,32,32,1]{2,1,3,0} %get-tuple-element.7, f32[3,3,1,1]{1,0,2,3} %conv_window), + window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", backend_config="{algorithm:1,tensor_ops_enabled:true,conv_result_scale:1}" + %get-tuple-element.8 = f32[200,32,32,1]{2,1,3,0} get-tuple-element((f32[200,32,32,1]{2,1,3,0}, u8[6152]{0}) %conv2), index=0 + %conv3 = (f32[200,32,32,1]{2,1,3,0}, u8[6152]{0}) custom-call(f32[200,32,32,1]{2,1,3,0} %get-tuple-element.8, f32[3,3,1,1]{1,0,2,3} %conv_window), + window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", backend_config="{algorithm:1,tensor_ops_enabled:true,conv_result_scale:1}" + %get-tuple-element.9 = f32[200,32,32,1]{2,1,3,0} get-tuple-element((f32[200,32,32,1]{2,1,3,0}, u8[6152]{0}) %conv3), index=0 + %conv4 = (f32[200,32,32,1]{2,1,3,0}, u8[6152]{0}) custom-call(f32[200,32,32,1]{2,1,3,0} %get-tuple-element.9, f32[3,3,1,1]{1,0,2,3} %conv_window), + window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", backend_config="{algorithm:1,tensor_ops_enabled:true,conv_result_scale:1}" + %get-tuple-element.10 = f32[200,32,32,1]{2,1,3,0} 
get-tuple-element((f32[200,32,32,1]{2,1,3,0}, u8[6152]{0}) %conv4), index=0 + %conv5 = (f32[200,32,32,1]{2,1,3,0}, u8[6152]{0}) custom-call(f32[200,32,32,1]{2,1,3,0} %get-tuple-element.10, f32[3,3,1,1]{1,0,2,3} %conv_window), + window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", backend_config="{algorithm:1,tensor_ops_enabled:true,conv_result_scale:1}" + %get-tuple-element.11 = f32[200,32,32,1]{2,1,3,0} get-tuple-element((f32[200,32,32,1]{2,1,3,0}, u8[6152]{0}) %conv5), index=0 + ROOT %bitcast.1 = f32[200,32,32,1]{3,2,1,0} bitcast(f32[200,32,32,1]{2,1,3,0} %get-tuple-element.11) +} +)"; + + HloModuleConfig config; + config.set_debug_options(GetDebugOptionsFromFlags()); + TF_ASSERT_OK_AND_ASSIGN(auto m, + ParseAndReturnVerifiedModule(hlo_text, config)); + + std::unique_ptr buffers = RunBufferAssignment(m.get()); + + // We should occupy strictly less size than 4 * size of the buffer required + // for convolution. + int64 conv_size_bytes = 200 * 32 * 32 * 4; + EXPECT_LT(buffers->GetStats().total_allocation_bytes, conv_size_bytes * 4); +} + TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); diff --git a/tensorflow/compiler/xla/service/convolution_group_converter.cc b/tensorflow/compiler/xla/service/convolution_group_converter.cc index 20ebafcf780..cfcf059ba5f 100644 --- a/tensorflow/compiler/xla/service/convolution_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_group_converter.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/convolution_group_converter.h" +#include #include #include @@ -474,8 +475,6 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { new_convolution))); } } else { - int64 activation_input_feature_dim = dim_numbers.input_feature_dimension(); - int64 output_feature = filter->shape().dimensions(kernel_output_feature_dim); @@ -487,11 +486,62 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { // [3, 2, 4]{S, B, IF} depth conv [3, 1, 4]{S, IF, OF}, where S is the // additional spatial dimension. The generated convolution output will be // [1, 2, 4]{S, B, OF} and then reshape the output back to [2, 4] {B, OF}. - - if (group_count == output_feature && !filter_expansion_) { + // We only do this for b0..0f or f0..0b dimension labels on activations. + const int64 input_feature_dim = dim_numbers.input_feature_dimension(); + const int64 input_batch_dim = dim_numbers.input_batch_dimension(); + const int64 activations_dimension_count = + convolution->operand(0)->shape().dimensions().size(); + if (group_count == output_feature && !filter_expansion_ && + ((input_feature_dim == 0 && + input_batch_dim == activations_dimension_count - 1) || + (input_batch_dim == 0 && + input_feature_dim == activations_dimension_count - 1))) { auto filter = convolution->mutable_operand(1); auto activation = convolution->mutable_operand(0); + // We want b0..0f logical dimensions on activations. If they are f0..0b + // instead, we transpose the activations to have the right dimension + // ordering. + if (input_feature_dim < input_batch_dim) { + // Generate the required shape for activations by swapping batch and + // feature dimension sizes. 
+ Shape new_act_shape = activation->shape(); + new_act_shape.set_dimensions(dim_numbers.input_feature_dimension(), + activation->shape().dimensions( + dim_numbers.input_batch_dimension())); + new_act_shape.set_dimensions( + dim_numbers.input_batch_dimension(), + activation->shape().dimensions( + dim_numbers.input_feature_dimension())); + + // Generate dimension mapping. + std::vector transpose_dims(new_act_shape.dimensions_size()); + std::iota(transpose_dims.begin(), transpose_dims.end(), 0); + std::iter_swap(transpose_dims.begin(), transpose_dims.end() - 1); + + // Transpose the activations. Change the convolution input. + auto transposed_activations = + computation_->AddInstruction(HloInstruction::CreateTranspose( + new_act_shape, activation, transpose_dims)); + TF_CHECK_OK(convolution->ReplaceOperandWithDifferentShape( + 0, transposed_activations)); + + const int64 old_feature_dim = dim_numbers.input_feature_dimension(); + const int64 old_batch_dim = dim_numbers.input_batch_dimension(); + + // Rectify the convolution dimension numbers. + dim_numbers.set_input_feature_dimension(old_batch_dim); + dim_numbers.set_input_batch_dimension(old_feature_dim); + convolution->set_convolution_dimension_numbers(dim_numbers); + + // Update the data structures we'd use. + dim_numbers = convolution->convolution_dimension_numbers(); + activation = convolution->mutable_operand(0); + } + + const int64 activation_input_feature_dim = + dim_numbers.input_feature_dimension(); + // Add spatial dimension to the activation, and reshape. Shape reshaped_activation_shape = activation->shape(); ShapeUtil::AppendMajorDimension(group_size, &reshaped_activation_shape); @@ -534,12 +584,16 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { /*batch_group_count=*/1, new_window, dim_numbers, convolution->precision_config())); + VLOG(2) << "New convolution " << new_convolution->ToString(); + // Delete the extra spatial dimension, and reshape. 
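The activation transpose above builds its permutation with std::iota followed by std::iter_swap on the first and last entries; a tiny standalone check of what that yields for rank-4 f0..0b activations:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

int main() {
  std::vector<int64_t> transpose_dims(4);
  std::iota(transpose_dims.begin(), transpose_dims.end(), 0);
  std::iter_swap(transpose_dims.begin(), transpose_dims.end() - 1);
  // Only the outermost and innermost logical dimensions trade places,
  // which swaps the batch and feature dimensions for f0..0b layouts.
  assert((transpose_dims == std::vector<int64_t>{3, 1, 2, 0}));
  return 0;
}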
Shape reshaped_convolution_shape = ShapeUtil::DeleteDimension(new_spatial_dim, new_convolution->shape()); auto reshaped_convolution = HloInstruction::CreateReshape( reshaped_convolution_shape, new_convolution); + VLOG(2) << "Reshaped convolution " << reshaped_convolution->ToString(); + TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( convolution, std::move(reshaped_convolution))); diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 7606e31b24d..85cf5e70f55 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -637,19 +637,6 @@ class CopyRemover { DCHECK(src != nullptr); DCHECK(dest != nullptr); - auto is_live_range_before = [this](const ValueNode& a, const ValueNode& b) { - VLOG(3) << "Checking live range of " << *a.value << " WRT " << *b.value; - if (LiveRangeBefore(a, b)) { - VLOG(2) << " Live range of " << a.value->ToShortString() - << " is before " << b.value->ToShortString(); - return true; - } else { - VLOG(2) << " Live range of " << a.value->ToShortString() - << " is not before " << b.value->ToShortString(); - return false; - } - }; - VLOG(3) << copy->name() << " copies value " << src->value->ToShortString(); VLOG(3) << "Source buffer values: " << ValueListToString(src); VLOG(3) << "Dest buffer values: " << ValueListToString(dest); @@ -715,7 +702,7 @@ class CopyRemover { ValueNode* next_dest = Next(*dest); if (next_dest != nullptr) { // Live range of 'from' value (s_x) must be before 'next_dest' (d_1); - if (!is_live_range_before(*src, *next_dest)) { + if (!LiveRangeBefore(*src, *next_dest)) { return false; } } @@ -725,7 +712,7 @@ class CopyRemover { // Live range of 'last_dest' (d_m) must be before 'next_src' s_{x+1}. ValueNode* last_dest = dest->prev; DCHECK(IsTail(*last_dest)); - if (!is_live_range_before(*last_dest, *next_src)) { + if (!LiveRangeBefore(*last_dest, *next_src)) { return false; } } @@ -754,13 +741,13 @@ class CopyRemover { DCHECK(prev_dest != nullptr); ValueNode* first_src = src->next; DCHECK(IsHead(*first_src)); - if (!is_live_range_before(*prev_dest, *first_src)) { + if (!LiveRangeBefore(*prev_dest, *first_src)) { // Live range of value d_{y-1} is not before s_0. return false; } ValueNode* next_dest = Next(*dest); if (next_dest != nullptr) { - if (!is_live_range_before(*src, *next_dest)) { + if (!LiveRangeBefore(*src, *next_dest)) { // Live range of value s_n is not before d_{y+1}. return false; } @@ -829,19 +816,30 @@ class CopyRemover { // We cannot use LiveRangeStrictlyBefore because HloValue::uses() is not // updated as copies are removed. 
bool LiveRangeBefore(const ValueNode& a, const ValueNode& b) { - if (a.uses.empty()) { - VLOG(2) << "Empty uses for " << *a.value; - return ordering_.IsDefinedBefore(*a.value, *b.value); - } - for (const HloUse* use : a.uses) { - VLOG(2) << "Checking use " << *use << " against " << *b.value; - if (!ordering_.UseIsBeforeValueDefinition(*use, *b.value, dataflow_)) { - VLOG(2) << "Use " << *use << " is NOT before " << *b.value; - return false; + VLOG(3) << "Checking live range of " << *a.value << " WRT " << *b.value; + bool is_live_range_before = [&] { + if (a.uses.empty()) { + VLOG(2) << "Empty uses for " << *a.value; + return ordering_.IsDefinedBefore(*a.value, *b.value); } - VLOG(2) << "Use " << *use << " is before " << *b.value; + for (const HloUse* use : a.uses) { + VLOG(3) << "Checking use " << *use << " against " << *b.value; + if (!ordering_.UseIsBeforeValueDefinition(*use, *b.value, dataflow_)) { + VLOG(2) << "Use " << *use << " is NOT before " << *b.value; + return false; + } + VLOG(3) << "Use " << *use << " is before " << *b.value; + } + return true; + }(); + if (is_live_range_before) { + VLOG(2) << " Live range of " << a.value->ToShortString() << " is before " + << b.value->ToShortString(); + } else { + VLOG(2) << " Live range of " << a.value->ToShortString() + << " is not before " << b.value->ToShortString(); } - return true; + return is_live_range_before; } // Returns whether 'node' is the last node in its list. diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index e7371c79b39..7b15d49cc47 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -344,6 +344,9 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( TransposeFolding::NeverFoldTranspose); pipeline.AddPass(/*is_layout_sensitive=*/false); + // Layout assignment uses alias analysis, which requires the call graph to be + // flattened. + pipeline.AddPass(); pipeline.AddPass( module->mutable_entry_computation_layout(), LayoutAssignment::InstructionCanChangeLayout, target_machine_features); @@ -407,7 +410,6 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn( // before (and sometime after) copy insertion, to avoid dead code from // interfering with the rewrites. pipeline.AddPass(); - pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); return pipeline.Run(module).status(); diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc index 1f7d41c7b94..e02a58210c2 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc @@ -449,10 +449,15 @@ Status DynamicDimensionInferenceVisitor::HandleElementwiseBinary( Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) { return ForEachOperandDynamicDimension( - hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, - int64 operand_index, HloInstruction* dynamic_size, - DimensionConstraint constraint) { + hlo, + [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size, + DimensionConstraint constraint) -> Status { HloInstruction* reshape = hlo; + TF_RET_CHECK(reshape->shape().rank() > 0) + << "Reshaping a dynamic dimension into a scalar, which has " + "undefined behavior. 
The offending instruction is: " + << reshape->ToString(); // Reshape is supported as long as it is the most // major one and it is combining with other non-dynamic dimensions. const int64 output_most_major = reshape->shape().dimensions(0); @@ -463,7 +468,7 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) { reshape->shape().dimensions(0) / operand->shape().dimensions(0); HloInstruction* multiplier_hlo = hlo->parent()->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR0(multiplier))); + LiteralUtil::CreateR0(multiplier))); HloInstruction* new_dynamic_size = hlo->parent()->AddInstruction(HloInstruction::CreateBinary( @@ -638,7 +643,7 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) { reshape->shape().dimensions(dynamic_dimension); HloInstruction* divisor_hlo = hlo->parent()->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR0(divisor))); + LiteralUtil::CreateR0(divisor))); HloInstruction* new_dynamic_size = hlo->parent()->AddInstruction(HloInstruction::CreateBinary( @@ -828,20 +833,13 @@ Status DynamicDimensionInferenceVisitor::HandleScatter(HloInstruction* hlo) { int64 operand_index, HloInstruction* operand_dynamic_size, DimensionConstraint constraint) { if (operand_index == 0) { - return Unimplemented( - "Detects a dynamic dimension on the data input of scatter, which " - "is not supported: %s", - hlo->ToString()); - } - - const ScatterDimensionNumbers& scatter_dims = - hlo->scatter_dimension_numbers(); - if (operand_index == 1) { parent_->SetDynamicSize(hlo, {}, dimension, operand_dynamic_size, constraint); return Status::OK(); } + const ScatterDimensionNumbers& scatter_dims = + hlo->scatter_dimension_numbers(); if (operand_index == 2 && absl::c_linear_search(scatter_dims.update_window_dims(), dimension)) { diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.h b/tensorflow/compiler/xla/service/dynamic_dimension_inference.h index 12af09fee4a..e8e89c8357b 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.h +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.h @@ -164,6 +164,8 @@ class DynamicDimensionInference { // by a scalar instruction `size`. 
void SetDynamicSize(HloInstruction* inst, const ShapeIndex& index, int64 dim, HloInstruction* size, DimensionConstraint constraint) { + VLOG(1) << "Set dimension inst " << inst->name() << " index " + << index.ToString() << "@" << dim << " to " << size->ToString(); Shape subshape = ShapeUtil::GetSubshape(inst->shape(), index); CHECK(!subshape.IsTuple()) << "Can't set a tuple shape to dynamic dimension"; diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc index 7a13307ffbf..264263570cb 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc @@ -94,7 +94,7 @@ class DynamicDimensionInferenceTest : public HloTestBase { std::unique_ptr module_; std::unique_ptr inference_; - const Shape scalar_shape_ = ShapeUtil::MakeShape(U32, {}); + const Shape scalar_shape_ = ShapeUtil::MakeShape(S32, {}); }; TEST_F(DynamicDimensionInferenceTest, ParamTest) { @@ -557,7 +557,7 @@ TEST_F(DynamicDimensionInferenceTest, ReshapeTestMajorDimension) { EXPECT_NE(inference_->GetDynamicSize(reshape, {}, 0), nullptr); const Literal& multiplier = inference_->GetDynamicSize(reshape, {}, 0)->operand(1)->literal(); - LiteralTestUtil::ExpectR0Equal(10, multiplier); + LiteralTestUtil::ExpectR0Equal(10, multiplier); } TEST_F(DynamicDimensionInferenceTest, GatherTest) { @@ -895,7 +895,7 @@ TEST_F(DynamicDimensionInferenceTest, DynamicSliceTest) { std::vector params; for (int i = 0; i < 2; ++i) { params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( - i + 2, ShapeUtil::MakeShape(U32, {}), "slice_indices"))); + i + 2, ShapeUtil::MakeShape(S32, {}), "slice_indices"))); } auto* slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice( @@ -997,7 +997,7 @@ TEST_F(DynamicDimensionInferenceTest, DynamicSliceSingleElementTest) { std::vector params; for (int i = 0; i < 2; ++i) { params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( - i + 2, ShapeUtil::MakeShape(U32, {}), "slice_indices"))); + i + 2, ShapeUtil::MakeShape(S32, {}), "slice_indices"))); } auto* slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice( diff --git a/tensorflow/compiler/xla/service/dynamic_padder.cc b/tensorflow/compiler/xla/service/dynamic_padder.cc index 5fea5d823de..dc16ef4d9f9 100644 --- a/tensorflow/compiler/xla/service/dynamic_padder.cc +++ b/tensorflow/compiler/xla/service/dynamic_padder.cc @@ -78,9 +78,18 @@ StatusOr ChooseIdentityValue(HloInstruction* inst, case HloOpcode::kSelectAndScatter: { return inst->mutable_operand(2); } + case HloOpcode::kScatter: { + if (operand_number != 1) { + return nullptr; + } + PrimitiveType indices_ptype = + inst->operand(operand_number)->shape().element_type(); + + return comp->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::MaxValue(indices_ptype))); + } case HloOpcode::kParameter: case HloOpcode::kGather: - case HloOpcode::kScatter: case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: case HloOpcode::kGetDimensionSize: @@ -128,17 +137,19 @@ StatusOr DynamicPadder::Run(HloModule* module) { for (HloInstruction* inst : computation->instructions()) { for (int64 operand_num = 0; operand_num < inst->operand_count(); ++operand_num) { - HloInstruction* operand = inst->mutable_operand(operand_num); + HloInstruction* original_operand = inst->mutable_operand(operand_num); + HloInstruction* operand = original_operand; if (!operand->shape().IsArray()) { 
continue; } for (int64 dim = 0; dim < operand->shape().rank(); ++dim) { HloInstruction* dynamic_size = - dynamic_dimension_inference.GetDynamicSize(operand, {}, dim); + dynamic_dimension_inference.GetDynamicSize(original_operand, {}, + dim); if (dynamic_size == nullptr) { continue; } - VLOG(1) << "Has dynamic dimension of operand" << operand_num << " @" + VLOG(2) << "Has dynamic dimension of operand" << operand_num << " @" << dim; if (ShouldSkipPadOnOperand(inst, operand_num, dim)) { @@ -164,7 +175,7 @@ StatusOr DynamicPadder::Run(HloModule* module) { // mask and pad value. // const Shape mask_shape = - ShapeUtil::ChangeElementType(operand->shape(), xla::U32); + ShapeUtil::ChangeElementType(operand->shape(), xla::S32); const Shape pred_shape = ShapeUtil::ChangeElementType(operand->shape(), xla::PRED); HloInstruction* iota = computation->AddInstruction( diff --git a/tensorflow/compiler/xla/service/dynamic_padder_test.cc b/tensorflow/compiler/xla/service/dynamic_padder_test.cc index 2963deaa317..4dfb93ee7d8 100644 --- a/tensorflow/compiler/xla/service/dynamic_padder_test.cc +++ b/tensorflow/compiler/xla/service/dynamic_padder_test.cc @@ -28,7 +28,10 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test_benchmark.h" @@ -65,7 +68,7 @@ class DynamicPadderTest : public HloTestBase { } std::unique_ptr module_; - const Shape scalar_shape_ = ShapeUtil::MakeShape(U32, {}); + const Shape scalar_shape_ = ShapeUtil::MakeShape(S32, {}); }; TEST_F(DynamicPadderTest, ReduceTest) { @@ -212,5 +215,189 @@ TEST_F(DynamicPadderTest, ReduceWindowNoPadForTrivialWindow) { EXPECT_THAT(output->operand(0), op::Parameter()); } +// Test that dynamic padder has the same result as if not padded. 
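The padded-scatter tests below rely on the new kScatter identity value above: entries of the indices operand beyond the dynamic size are rewritten to MaxValue of the index type, and scatter drops updates whose indices fall out of bounds, so the padding does not change the result. A standalone sketch of that effect (ScatterAdd1D is an illustrative stand-in, not the XLA op):

#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// Out-of-bounds updates are skipped, mirroring the scatter semantics the
// padder relies on.
void ScatterAdd1D(std::vector<int32_t>& operand,
                  const std::vector<int64_t>& indices,
                  const std::vector<int32_t>& updates) {
  for (size_t i = 0; i < indices.size(); ++i) {
    if (indices[i] >= 0 &&
        indices[i] < static_cast<int64_t>(operand.size())) {
      operand[indices[i]] += updates[i];
    }
  }
}

int main() {
  std::vector<int32_t> unpadded = {1, 2, 3};
  std::vector<int32_t> padded = unpadded;
  ScatterAdd1D(unpadded, {0, 2}, {10, 20});
  // Two extra padded updates whose indices were rewritten to INT32_MAX.
  const int64_t kPad = std::numeric_limits<int32_t>::max();
  ScatterAdd1D(padded, {0, 2, kPad, kPad}, {10, 20, 99, 99});
  assert(unpadded == padded);
  return 0;
}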
+class ExecutionTest : public HloTestBase { + protected: + std::unique_ptr GetHloModule(const string& hlo_text) { + HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + std::unique_ptr module = + ParseAndReturnUnverifiedModule(hlo_text, config).ValueOrDie(); + return module; + } +}; + +XLA_TEST_F(ExecutionTest, ScatterUpdate) { + // Test that scattering on indices=[2] is same as scattering on indices=[4] + // and dynamic dimension = 2 + const string hlo_text = R"( +HloModule TensorFlowScatterV1 + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[INDICES_BOUND] parameter(1) + updates = s32[INDICES_BOUND,3] parameter(2) + dynamic_size = s32[] parameter(3) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 + +} +)"; + const string hlo_text_not_padded = + absl::StrReplaceAll(hlo_text, {{"INDICES_BOUND", "2"}}); + auto module_not_padded = GetHloModule(hlo_text_not_padded); + + Literal operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + Literal dynamic_size = LiteralUtil::CreateR0(2); + + Literal not_padded = + ExecuteAndTransfer(std::move(module_not_padded), + {&operand, &scatter_indices, &updates, &dynamic_size}); + + // Pad input to 4. + const string hlo_text_padded = + absl::StrReplaceAll(hlo_text, {{"INDICES_BOUND", "4"}}); + auto module_padded = GetHloModule(hlo_text_padded); + // Set up dynamic parameter binding. + TF_CHECK_OK(module_padded->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{3, {}}, + DynamicParameterBinding::DynamicDimension{1, {}, 0})); + TF_CHECK_OK(module_padded->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{3, {}}, + DynamicParameterBinding::DynamicDimension{2, {}, 0})); + // Pad the rest of input with garbage data. 
+ Literal scatter_indices_padded = LiteralUtil::CreateR1({0, 2, 0, 4}); + Literal updates_padded = LiteralUtil::CreateR2( + {{10, 20, 30}, {70, 80, 90}, {30, 22, 11}, {-1, 20, -1}}); + DynamicPadder padder; + TF_CHECK_OK(padder.Run(module_padded.get()).status()); + Literal padded = ExecuteAndTransfer( + std::move(module_padded), + {&operand, &scatter_indices_padded, &updates_padded, &dynamic_size}); + + EXPECT_EQ(padded, not_padded); +} + +XLA_TEST_F(ExecutionTest, ScatterUpdateF32) { + // Test that scattering on indices=[2] is same as scattering on indices=[4] + // and dynamic dimension = 2 + const string hlo_text = R"( +HloModule TensorFlowScatterV1 + +update_f32 (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + ROOT rhs = f32[] parameter(1) +} + +ENTRY main { + operand = f32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = f32[2,3] parameter(2) + dynamic_size = s32[] parameter(3) + ROOT scatter = f32[3,3] scatter(operand, indices, updates), + to_apply=update_f32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 + +} +)"; + + auto module_not_padded = GetHloModule(hlo_text); + + Literal operand = LiteralUtil::CreateR2( + {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = + LiteralUtil::CreateR2({{10.0, 20.0, 30.0}, {70.0, 80.0, 90.0}}); + // Dynamic Size is 1, pad to 2 + Literal dynamic_size = LiteralUtil::CreateR0(1); + + auto module_padded = GetHloModule(hlo_text); + // Set up dynamic parameter binding. + TF_CHECK_OK(module_padded->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{3, {}}, + DynamicParameterBinding::DynamicDimension{1, {}, 0})); + TF_CHECK_OK(module_padded->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{3, {}}, + DynamicParameterBinding::DynamicDimension{2, {}, 0})); + DynamicPadder padder; + TF_CHECK_OK(padder.Run(module_padded.get()).status()); + Literal not_padded = + ExecuteAndTransfer(std::move(module_padded), + {&operand, &scatter_indices, &updates, &dynamic_size}); + // Although we have two indices, only the first element is updated because of + // padding. + EXPECT_EQ(LiteralUtil::CreateR2( + {{10.0, 20.0, 30.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}}), + not_padded); +} + +XLA_TEST_F(ExecutionTest, TwoDimensionReduce) { + // Test that reducing on operand=[2,2] is same as reducing on operand=[4,4] + // and dynamic dimension = 2 + const string hlo_text = R"( +HloModule TensorFlowScatterV1 + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(lhs, rhs) +} + +ENTRY main { + param = s32[INDICES_BOUND, INDICES_BOUND] parameter(0) + dynamic_size = s32[] parameter(1) + const = s32[] constant(0) + ROOT reduce = s32[] reduce(param, const), + dimensions={0, 1}, + to_apply=update_s32 +} +)"; + const string hlo_text_not_padded = + absl::StrReplaceAll(hlo_text, {{"INDICES_BOUND", "2"}}); + auto module_not_padded = GetHloModule(hlo_text_not_padded); + + Literal operand = LiteralUtil::CreateR2({{1, 2}, {4, 5}}); + Literal dynamic_size = LiteralUtil::CreateR0(2); + + Literal not_padded = ExecuteAndTransfer(std::move(module_not_padded), + {&operand, &dynamic_size}); + + // Pad input to 4. + const string hlo_text_padded = + absl::StrReplaceAll(hlo_text, {{"INDICES_BOUND", "4"}}); + auto module_padded = GetHloModule(hlo_text_padded); + // Set up dynamic parameter binding. 
+ TF_CHECK_OK(module_padded->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + TF_CHECK_OK(module_padded->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + // Pad the rest of input with garbage data. + Literal operand_padded = LiteralUtil::CreateR2( + {{1, 2, 3, 4}, {4, 5, 6, 7}, {1, 2, 3, 4}, {4, 5, 6, 7}}); + DynamicPadder padder; + TF_CHECK_OK(padder.Run(module_padded.get()).status()); + Literal padded = ExecuteAndTransfer(std::move(module_padded), + {&operand_padded, &dynamic_size}); + + EXPECT_EQ(padded, not_padded); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 21476c7e921..7b871951ed0 100755 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -456,6 +456,7 @@ tf_cc_test( cc_library( name = "gpu_executable", srcs = [ + "cholesky_thunk.cc", "collective_permute_thunk.cc", "conditional_thunk.cc", "convolution_thunk.cc", @@ -476,10 +477,9 @@ cc_library( "triangular_solve_thunk.cc", "tuple_thunk.cc", "while_thunk.cc", - ] + if_cuda_is_configured([ - "cholesky_thunk.cc", - ]), + ], hdrs = [ + "cholesky_thunk.h", "collective_permute_thunk.h", "conditional_thunk.h", "convolution_thunk.h", @@ -500,12 +500,11 @@ cc_library( "triangular_solve_thunk.h", "tuple_thunk.h", "while_thunk.h", - ] + if_cuda_is_configured([ - "cholesky_thunk.h", - ]), + ], deps = [ ":backend_configs", ":buffer_allocations", + ":cusolver_context", ":cudnn_conv_runner", ":gpu_debug_info_manager", ":gpu_types", @@ -559,7 +558,6 @@ cc_library( "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ] + if_cuda_is_configured([ - ":cusolver_context", "//tensorflow/stream_executor/cuda:cuda_stream", "//tensorflow/core/platform/default/build_config:cublas_plugin", "//tensorflow/core/platform/default/build_config:cudnn_plugin", @@ -633,7 +631,7 @@ cc_library( "//tensorflow/stream_executor:blas", "//tensorflow/stream_executor:device_memory", "//tensorflow/stream_executor:device_memory_allocator", - "//tensorflow/stream_executor/cuda:redzone_allocator", + "//tensorflow/stream_executor/gpu:redzone_allocator", "@com_google_absl//absl/types:optional", ], ) @@ -664,7 +662,7 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/util/proto:proto_utils", "//tensorflow/stream_executor:device_memory_allocator", - "//tensorflow/stream_executor/cuda:redzone_allocator", + "//tensorflow/stream_executor/gpu:redzone_allocator", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -718,6 +716,7 @@ tf_cc_test( deps = [ ":cudnn_conv_rewriter", ":ir_emission_utils", + "//tensorflow/compiler/jit:xla_gpu_jit", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/service:hlo", @@ -731,21 +730,22 @@ tf_cc_test( cc_library( name = "cusolver_context", - srcs = ["cusolver_context.cc"], + srcs = if_cuda_is_configured(["cusolver_context.cc"]), hdrs = ["cusolver_context.h"], deps = [ # LINT.IfChange "@local_config_cuda//cuda:cublas_headers", # LINT.ThenChange(//tensorflow/copy.bara.sky:cublas_headers) - "@local_config_cuda//cuda:cuda_headers", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", 
"//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/stream_executor:blas", + ] + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", "//tensorflow/stream_executor/cuda:cusolver_lib", - ], + ]), ) cc_library( @@ -950,11 +950,12 @@ cc_library( ) cc_library( - name = "cudnn_conv_pad_for_tensor_cores", - srcs = ["cudnn_conv_pad_for_tensor_cores.cc"], - hdrs = ["cudnn_conv_pad_for_tensor_cores.h"], + name = "cudnn_pad_for_convolutions", + srcs = ["cudnn_pad_for_convolutions.cc"], + hdrs = ["cudnn_pad_for_convolutions.h"], deps = [ ":ir_emission_utils", + ":stream_executor_util", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", @@ -964,10 +965,10 @@ cc_library( ) tf_cc_test( - name = "cudnn_conv_pad_for_tensor_cores_test", - srcs = ["cudnn_conv_pad_for_tensor_cores_test.cc"], + name = "cudnn_pad_for_convolutions_test", + srcs = ["cudnn_pad_for_convolutions_test.cc"], deps = [ - ":cudnn_conv_pad_for_tensor_cores", + ":cudnn_pad_for_convolutions", ":ir_emission_utils", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", @@ -1053,9 +1054,9 @@ cc_library( deps = [ ":alias_passthrough_params", ":cudnn_batchnorm_rewriter", - ":cudnn_conv_algorithm_picker", ":cudnn_conv_padding_legalization", ":cudnn_conv_rewriter", + ":cudnn_pad_for_convolutions", ":fusion_merger", ":gpu_constants", ":gpu_copy_insertion", @@ -1156,10 +1157,10 @@ cc_library( deps = [ ":cublas_gemm_pad_for_tensor_cores", ":cudnn_conv_algorithm_picker", - ":cudnn_conv_pad_for_tensor_cores", ":cudnn_conv_padding_legalization", ":cudnn_conv_rewriter", ":cudnn_fused_conv_rewriter", + ":cudnn_pad_for_convolutions", ":cusolver_rewriter", ":gemm_algorithm_picker", ":gemm_rewriter", @@ -1190,7 +1191,7 @@ cc_library( "//tensorflow/core/profiler/lib:traceme", "//tensorflow/stream_executor:stream_executor_headers", "//tensorflow/stream_executor/cuda:cuda_diagnostics", - "//tensorflow/stream_executor/cuda:ptxas_utils", + "//tensorflow/stream_executor/gpu:asm_compiler", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/types:optional", ], @@ -1413,7 +1414,7 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/stream_executor:kernel_spec", - "//tensorflow/stream_executor/cuda:ptxas_utils", + "//tensorflow/stream_executor/gpu:asm_compiler", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -1509,9 +1510,15 @@ cc_library( tf_cc_test( name = "cudnn_fused_conv_rewriter_test", srcs = ["cudnn_fused_conv_rewriter_test.cc"], - tags = tf_cuda_tests_tags(), + tags = [ + "noasan", + "nomsan", + "requires-gpu-sm70", + ], deps = [ + ":cudnn_fused_conv_rewriter", ":ir_emission_utils", + "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test", "//tensorflow/compiler/xla/tests:hlo_test_base", diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc index 30108315e4d..37095adf7c6 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc @@ -34,7 +34,7 @@ namespace gpu { static constexpr double kTolerance = 0.1f; -// Comparison kernel code: compare two buffers of fp16/fp32/fp64 of length +// Comparison kernel code: compare two buffers of 
fp16/fp32/fp64/int8 of length // buffer_length where the relative error does not exceed the passed // rel_error_threshold. Write the number of mismatches into out parameter // mismatch_count. @@ -46,12 +46,20 @@ static constexpr double kTolerance = 0.1f; // // #include // extern "C" { // avoid name mangling -// __device__ float canonicalize(float input) { +// __device__ float __xla_buffer_comparator_canonicalize(float input) { // // All fp16 infinities are treated as 65505 or -65505, in order to avoid // // differences due to overflows. // return isnan(input) ? input : max(-65505.0f, min(input, 65505.0f)); // } -// + +// __device__ float __xla_buffer_comparator_extract_int8(int pack) { +// // Extract the lower 8 bits from pack and convert it to float +// const unsigned int bit_mask = 0xff; +// unsigned int bits = pack & bit_mask; +// char* int8_ptr = (char*)&bits; +// return __int2float_rn(*int8_ptr); +// } + // __global__ void __xla_fp16_comparison(__half* buffer_a, __half* buffer_b, // float rel_error_threshold, // unsigned long long buffer_length, @@ -60,15 +68,15 @@ static constexpr double kTolerance = 0.1f; // if (idx >= buffer_length) return; // float elem_a = __half2float(buffer_a[idx]); // float elem_b = __half2float(buffer_b[idx]); -// elem_a = canonicalize(elem_a); -// elem_b = canonicalize(elem_b); +// elem_a = __xla_buffer_comparator_canonicalize(elem_a); +// elem_b = __xla_buffer_comparator_canonicalize(elem_b); // if (isnan(elem_a) && isnan(elem_b)) return; // float rel_error = abs(elem_a - elem_b) // / (max(abs(elem_a), abs(elem_b)) + 1); // if (rel_error > rel_error_threshold || isnan(rel_error)) // atomicAdd(mismatch_count, 1); // } -// + // __global__ void __xla_fp32_comparison(float* buffer_a, float* buffer_b, // float rel_error_threshold, // unsigned long long buffer_length, @@ -85,7 +93,7 @@ static constexpr double kTolerance = 0.1f; // if (rel_error > rel_error_threshold || isnan(rel_error)) // atomicAdd(mismatch_count, 1); // } -// + // __global__ void __xla_fp64_comparison(double* buffer_a, double* buffer_b, // float rel_error_threshold, // unsigned long long buffer_length, @@ -102,234 +110,440 @@ static constexpr double kTolerance = 0.1f; // if (rel_error > rel_error_threshold || isnan(rel_error)) // atomicAdd(mismatch_count, 1); // } + +// __global__ void __xla_int8_comparison(int* buffer_a, int* buffer_b, +// float rel_error_threshold, +// unsigned long long buffer_length, +// int* mismatch_count) { +// int idx = threadIdx.x + blockIdx.x * blockDim.x; +// if (idx >= buffer_length) return; +// int pack_a = buffer_a[idx]; +// int pack_b = buffer_b[idx]; +// for(int i = 0; i < 4; ++i) { +// float elem_a = __xla_buffer_comparator_extract_int8(pack_a); +// float elem_b = __xla_buffer_comparator_extract_int8(pack_b); +// float rel_error = abs(elem_a - elem_b) +// / (max(abs(elem_a), abs(elem_b)) + 1); +// if (rel_error > rel_error_threshold || isnan(rel_error)) +// atomicAdd(mismatch_count, 1); +// pack_a >>= 8; +// pack_b >>= 8; +// } +// } // } // end extern declaration. 
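The PTX string below is the compiled form of the commented CUDA source above. As a plain-C++ reference for what the new __xla_int8_comparison entry point computes, the following host-side sketch performs the same per-element check; the helper name and the sequential loop are illustrative only.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Host-side reference for the device-side check: each 32-bit word packs four
// int8 elements; every element pair is widened to float and compared with the
// same relative-error rule used by the fp16/fp32/fp64 kernels.
int64_t CountInt8Mismatches(const int32_t* buffer_a, const int32_t* buffer_b,
                            int64_t num_packs, float rel_error_threshold) {
  int64_t mismatch_count = 0;
  for (int64_t i = 0; i < num_packs; ++i) {
    int32_t pack_a = buffer_a[i];
    int32_t pack_b = buffer_b[i];
    for (int k = 0; k < 4; ++k) {
      // Extract the low byte of each pack and convert it to float, as
      // __xla_buffer_comparator_extract_int8 does on the device.
      float elem_a = static_cast<float>(static_cast<int8_t>(pack_a & 0xff));
      float elem_b = static_cast<float>(static_cast<int8_t>(pack_b & 0xff));
      float rel_error = std::abs(elem_a - elem_b) /
                        (std::max(std::abs(elem_a), std::abs(elem_b)) + 1.0f);
      if (rel_error > rel_error_threshold || std::isnan(rel_error)) {
        ++mismatch_count;
      }
      pack_a >>= 8;
      pack_b >>= 8;
    }
  }
  return mismatch_count;
}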
static const char* buffer_compare_ptx = R"( .version 4.2 .target sm_30 .address_size 64 + // .globl __xla_fp16_comparison + .visible .entry __xla_fp16_comparison( - .param .u64 __xla_fp16_comparison_param_0, - .param .u64 __xla_fp16_comparison_param_1, - .param .f32 __xla_fp16_comparison_param_2, - .param .u64 __xla_fp16_comparison_param_3, - .param .u64 __xla_fp16_comparison_param_4 + .param .u64 __xla_fp16_comparison_param_0, + .param .u64 __xla_fp16_comparison_param_1, + .param .f32 __xla_fp16_comparison_param_2, + .param .u64 __xla_fp16_comparison_param_3, + .param .u64 __xla_fp16_comparison_param_4 ) { - .reg .pred %p<10>; - .reg .b16 %rs<3>; - .reg .f32 %f<20>; - .reg .b32 %r<6>; - .reg .b64 %rd<12>; - ld.param.u64 %rd8, [__xla_fp16_comparison_param_3]; - mov.u32 %r1, %tid.x; - mov.u32 %r2, %ctaid.x; - mov.u32 %r3, %ntid.x; - mad.lo.s32 %r4, %r3, %r2, %r1; - cvt.s64.s32 %rd4, %r4; - setp.ge.u64 %p1, %rd4, %rd8; - @%p1 bra LBB7_4; - ld.param.u64 %rd5, [__xla_fp16_comparison_param_0]; - ld.param.u64 %rd7, [__xla_fp16_comparison_param_1]; - cvta.to.global.u64 %rd2, %rd7; - cvta.to.global.u64 %rd3, %rd5; - shl.b64 %rd9, %rd4, 1; - add.s64 %rd10, %rd3, %rd9; - ld.global.u16 %rs1, [%rd10]; - // begin inline asm - { cvt.f32.f16 %f6, %rs1;} + .reg .pred %p<9>; + .reg .b16 %rs<3>; + .reg .f32 %f<28>; + .reg .b32 %r<6>; + .reg .b64 %rd<12>; - // end inline asm - add.s64 %rd11, %rd2, %rd9; - ld.global.u16 %rs2, [%rd11]; - // begin inline asm - { cvt.f32.f16 %f7, %rs2;} - // end inline asm - abs.f32 %f8, %f6; - setp.gtu.f32 %p2, %f8, 0f7F800000; - min.f32 %f9, %f6, 0f477FE100; - max.f32 %f10, %f9, 0fC77FE100; - selp.f32 %f1, %f6, %f10, %p2; - abs.f32 %f11, %f7; - setp.gtu.f32 %p3, %f11, 0f7F800000; - min.f32 %f12, %f7, 0f477FE100; - max.f32 %f13, %f12, 0fC77FE100; - selp.f32 %f2, %f7, %f13, %p3; - abs.f32 %f3, %f1; - setp.gtu.f32 %p4, %f3, 0f7F800000; - abs.f32 %f4, %f2; - setp.gtu.f32 %p5, %f4, 0f7F800000; - and.pred %p6, %p4, %p5; - @%p6 bra LBB7_4; - ld.param.f32 %f5, [__xla_fp16_comparison_param_2]; - sub.f32 %f14, %f1, %f2; - abs.f32 %f15, %f14; - max.f32 %f16, %f3, %f4; - add.f32 %f17, %f16, 0f3F800000; - div.rn.f32 %f18, %f15, %f17; - setp.leu.f32 %p7, %f18, %f5; - abs.f32 %f19, %f18; - setp.le.f32 %p8, %f19, 0f7F800000; - and.pred %p9, %p7, %p8; - @%p9 bra LBB7_4; - ld.param.u64 %rd6, [__xla_fp16_comparison_param_4]; - cvta.to.global.u64 %rd1, %rd6; - atom.global.add.u32 %r5, [%rd1], 1; -LBB7_4: - ret; + ld.param.u64 %rd1, [__xla_fp16_comparison_param_0]; + ld.param.u64 %rd2, [__xla_fp16_comparison_param_1]; + ld.param.f32 %f10, [__xla_fp16_comparison_param_2]; + ld.param.u64 %rd4, [__xla_fp16_comparison_param_3]; + ld.param.u64 %rd3, [__xla_fp16_comparison_param_4]; + mov.u32 %r2, %ntid.x; + mov.u32 %r3, %ctaid.x; + mov.u32 %r4, %tid.x; + mad.lo.s32 %r1, %r2, %r3, %r4; + cvt.s64.s32 %rd5, %r1; + setp.ge.u64 %p1, %rd5, %rd4; + @%p1 bra BB0_9; + cvta.to.global.u64 %rd6, %rd1; + mul.wide.s32 %rd7, %r1, 2; + add.s64 %rd8, %rd6, %rd7; + ld.global.u16 %rs1, [%rd8]; + // inline asm + { cvt.f32.f16 %f26, %rs1;} + + // inline asm + cvta.to.global.u64 %rd9, %rd2; + add.s64 %rd10, %rd9, %rd7; + ld.global.u16 %rs2, [%rd10]; + // inline asm + { cvt.f32.f16 %f27, %rs2;} + + // inline asm + abs.f32 %f13, %f26; + setp.gtu.f32 %p2, %f13, 0f7F800000; + @%p2 bra BB0_3; + + mov.f32 %f14, 0f477FE100; + min.f32 %f15, %f26, %f14; + mov.f32 %f16, 0fC77FE100; + max.f32 %f26, %f16, %f15; + +BB0_3: + abs.f32 %f17, %f27; + setp.gtu.f32 %p3, %f17, 0f7F800000; + @%p3 bra BB0_5; + + mov.f32 %f18, 0f477FE100; + 
min.f32 %f19, %f27, %f18; + mov.f32 %f20, 0fC77FE100; + max.f32 %f27, %f20, %f19; + +BB0_5: + abs.f32 %f7, %f26; + setp.gtu.f32 %p4, %f7, 0f7F800000; + abs.f32 %f8, %f27; + setp.gtu.f32 %p5, %f8, 0f7F800000; + and.pred %p6, %p4, %p5; + @%p6 bra BB0_9; + + sub.f32 %f21, %f26, %f27; + abs.f32 %f22, %f21; + max.f32 %f23, %f7, %f8; + add.f32 %f24, %f23, 0f3F800000; + div.rn.f32 %f9, %f22, %f24; + setp.gt.f32 %p7, %f9, %f10; + @%p7 bra BB0_8; + + abs.f32 %f25, %f9; + setp.le.f32 %p8, %f25, 0f7F800000; + @%p8 bra BB0_9; + +BB0_8: + cvta.to.global.u64 %rd11, %rd3; + atom.global.add.u32 %r5, [%rd11], 1; + +BB0_9: + ret; } - // .globl __xla_fp32_comparison + + // .globl __xla_fp32_comparison .visible .entry __xla_fp32_comparison( - .param .u64 __xla_fp32_comparison_param_0, - .param .u64 __xla_fp32_comparison_param_1, - .param .f32 __xla_fp32_comparison_param_2, - .param .u64 __xla_fp32_comparison_param_3, - .param .u64 __xla_fp32_comparison_param_4 + .param .u64 __xla_fp32_comparison_param_0, + .param .u64 __xla_fp32_comparison_param_1, + .param .f32 __xla_fp32_comparison_param_2, + .param .u64 __xla_fp32_comparison_param_3, + .param .u64 __xla_fp32_comparison_param_4 ) { - .reg .pred %p<12>; - .reg .f32 %f<12>; - .reg .b32 %r<9>; - .reg .b64 %rd<12>; + .reg .pred %p<10>; + .reg .b16 %rs<3>; + .reg .f32 %f<13>; + .reg .b32 %r<10>; + .reg .b64 %rd<12>; - ld.param.u64 %rd8, [__xla_fp32_comparison_param_3]; - mov.u32 %r1, %tid.x; - mov.u32 %r2, %ctaid.x; - mov.u32 %r3, %ntid.x; - mad.lo.s32 %r4, %r3, %r2, %r1; - cvt.s64.s32 %rd4, %r4; - setp.ge.u64 %p1, %rd4, %rd8; - @%p1 bra LBB8_6; - ld.param.u64 %rd5, [__xla_fp32_comparison_param_0]; - ld.param.u64 %rd7, [__xla_fp32_comparison_param_1]; - cvta.to.global.u64 %rd2, %rd7; - cvta.to.global.u64 %rd3, %rd5; - shl.b64 %rd9, %rd4, 2; - add.s64 %rd10, %rd3, %rd9; - ld.global.f32 %f1, [%rd10]; - add.s64 %rd11, %rd2, %rd9; - ld.global.f32 %f2, [%rd11]; - abs.f32 %f3, %f1; - setp.gtu.f32 %p2, %f3, 0f7F800000; - abs.f32 %f4, %f2; - setp.gtu.f32 %p3, %f4, 0f7F800000; - and.pred %p4, %p2, %p3; - @%p4 bra LBB8_6; - setp.neu.f32 %p5, %f3, 0f7F800000; - setp.neu.f32 %p6, %f4, 0f7F800000; - or.pred %p7, %p5, %p6; - @%p7 bra LBB8_4; - mov.b32 %r5, %f1; - mov.b32 %r6, %f2; - xor.b32 %r7, %r6, %r5; - setp.gt.s32 %p8, %r7, -1; - @%p8 bra LBB8_6; -LBB8_4: - ld.param.f32 %f5, [__xla_fp32_comparison_param_2]; - sub.f32 %f6, %f1, %f2; - abs.f32 %f7, %f6; - max.f32 %f8, %f3, %f4; - add.f32 %f9, %f8, 0f3F800000; - div.rn.f32 %f10, %f7, %f9; - setp.leu.f32 %p9, %f10, %f5; - abs.f32 %f11, %f10; - setp.le.f32 %p10, %f11, 0f7F800000; - and.pred %p11, %p9, %p10; - @%p11 bra LBB8_6; - ld.param.u64 %rd6, [__xla_fp32_comparison_param_4]; - cvta.to.global.u64 %rd1, %rd6; - atom.global.add.u32 %r8, [%rd1], 1; -LBB8_6: - ret; + ld.param.u64 %rd1, [__xla_fp32_comparison_param_0]; + ld.param.u64 %rd2, [__xla_fp32_comparison_param_1]; + ld.param.f32 %f6, [__xla_fp32_comparison_param_2]; + ld.param.u64 %rd4, [__xla_fp32_comparison_param_3]; + ld.param.u64 %rd3, [__xla_fp32_comparison_param_4]; + mov.u32 %r2, %ntid.x; + mov.u32 %r3, %ctaid.x; + mov.u32 %r4, %tid.x; + mad.lo.s32 %r1, %r2, %r3, %r4; + cvt.s64.s32 %rd5, %r1; + setp.ge.u64 %p1, %rd5, %rd4; + @%p1 bra BB1_8; + + cvta.to.global.u64 %rd6, %rd1; + mul.wide.s32 %rd7, %r1, 4; + add.s64 %rd8, %rd6, %rd7; + cvta.to.global.u64 %rd9, %rd2; + add.s64 %rd10, %rd9, %rd7; + ld.global.f32 %f1, [%rd10]; + ld.global.f32 %f2, [%rd8]; + abs.f32 %f3, %f2; + setp.le.f32 %p2, %f3, 0f7F800000; + @%p2 bra BB1_3; + + abs.f32 %f7, %f1; + setp.gtu.f32 
%p3, %f7, 0f7F800000; + @%p3 bra BB1_8; + +BB1_3: + setp.neu.f32 %p4, %f3, 0f7F800000; + abs.f32 %f4, %f1; + setp.neu.f32 %p5, %f4, 0f7F800000; + or.pred %p6, %p4, %p5; + @%p6 bra BB1_5; + + mov.b32 %r5, %f2; + shr.u32 %r6, %r5, 31; + cvt.u16.u32 %rs1, %r6; + mov.b32 %r7, %f1; + shr.u32 %r8, %r7, 31; + cvt.u16.u32 %rs2, %r8; + setp.eq.s16 %p7, %rs1, %rs2; + @%p7 bra BB1_8; + +BB1_5: + sub.f32 %f8, %f2, %f1; + abs.f32 %f9, %f8; + max.f32 %f10, %f3, %f4; + add.f32 %f11, %f10, 0f3F800000; + div.rn.f32 %f5, %f9, %f11; + setp.gt.f32 %p8, %f5, %f6; + @%p8 bra BB1_7; + + abs.f32 %f12, %f5; + setp.le.f32 %p9, %f12, 0f7F800000; + @%p9 bra BB1_8; + +BB1_7: + cvta.to.global.u64 %rd11, %rd3; + atom.global.add.u32 %r9, [%rd11], 1; + +BB1_8: + ret; } - // .globl __xla_fp64_comparison + + // .globl __xla_fp64_comparison .visible .entry __xla_fp64_comparison( - .param .u64 __xla_fp64_comparison_param_0, - .param .u64 __xla_fp64_comparison_param_1, - .param .f32 __xla_fp64_comparison_param_2, - .param .u64 __xla_fp64_comparison_param_3, - .param .u64 __xla_fp64_comparison_param_4 + .param .u64 __xla_fp64_comparison_param_0, + .param .u64 __xla_fp64_comparison_param_1, + .param .f32 __xla_fp64_comparison_param_2, + .param .u64 __xla_fp64_comparison_param_3, + .param .u64 __xla_fp64_comparison_param_4 ) { - .reg .pred %p<16>; - .reg .f32 %f<2>; - .reg .b32 %r<13>; - .reg .f64 %fd<12>; - .reg .b64 %rd<12>; + .reg .pred %p<11>; + .reg .b16 %rs<3>; + .reg .f32 %f<2>; + .reg .b32 %r<14>; + .reg .f64 %fd<13>; + .reg .b64 %rd<12>; - ld.param.u64 %rd8, [__xla_fp64_comparison_param_3]; - mov.u32 %r2, %tid.x; - mov.u32 %r3, %ctaid.x; - mov.u32 %r4, %ntid.x; - mad.lo.s32 %r5, %r4, %r3, %r2; - cvt.s64.s32 %rd4, %r5; - setp.ge.u64 %p1, %rd4, %rd8; - @%p1 bra LBB9_6; - ld.param.u64 %rd5, [__xla_fp64_comparison_param_0]; - ld.param.u64 %rd7, [__xla_fp64_comparison_param_1]; - cvta.to.global.u64 %rd2, %rd7; - cvta.to.global.u64 %rd3, %rd5; - shl.b64 %rd9, %rd4, 3; - add.s64 %rd10, %rd3, %rd9; - ld.global.f64 %fd1, [%rd10]; - add.s64 %rd11, %rd2, %rd9; - ld.global.f64 %fd2, [%rd11]; - abs.f64 %fd3, %fd1; - setp.gtu.f64 %p2, %fd3, 0d7FF0000000000000; - abs.f64 %fd4, %fd2; - setp.gtu.f64 %p3, %fd4, 0d7FF0000000000000; - and.pred %p4, %p2, %p3; - @%p4 bra LBB9_6; - { - .reg .b32 %temp; - mov.b64 {%r6, %temp}, %fd1; - } - { - .reg .b32 %temp; - mov.b64 {%temp, %r1}, %fd1; - } - and.b32 %r7, %r1, 2147483647; - setp.ne.s32 %p5, %r7, 2146435072; - setp.ne.s32 %p6, %r6, 0; - or.pred %p7, %p6, %p5; - @%p7 bra LBB9_4; - { - .reg .b32 %temp; - mov.b64 {%r8, %temp}, %fd2; - } - { - .reg .b32 %temp; - mov.b64 {%temp, %r9}, %fd2; - } - and.b32 %r10, %r9, 2147483647; - setp.eq.s32 %p8, %r10, 2146435072; - setp.eq.s32 %p9, %r8, 0; - and.pred %p10, %p8, %p9; - xor.b32 %r11, %r9, %r1; - setp.gt.s32 %p11, %r11, -1; - and.pred %p12, %p11, %p10; - @%p12 bra LBB9_6; -LBB9_4: - ld.param.f32 %f1, [__xla_fp64_comparison_param_2]; - sub.f64 %fd5, %fd1, %fd2; - abs.f64 %fd6, %fd5; - max.f64 %fd7, %fd3, %fd4; - add.f64 %fd8, %fd7, 0d3FF0000000000000; - div.rn.f64 %fd9, %fd6, %fd8; - cvt.f64.f32 %fd10, %f1; - setp.leu.f64 %p13, %fd9, %fd10; - abs.f64 %fd11, %fd9; - setp.le.f64 %p14, %fd11, 0d7FF0000000000000; - and.pred %p15, %p13, %p14; - @%p15 bra LBB9_6; - ld.param.u64 %rd6, [__xla_fp64_comparison_param_4]; - cvta.to.global.u64 %rd1, %rd6; - atom.global.add.u32 %r12, [%rd1], 1; -LBB9_6: - ret; + + ld.param.u64 %rd1, [__xla_fp64_comparison_param_0]; + ld.param.u64 %rd2, [__xla_fp64_comparison_param_1]; + ld.param.f32 %f1, 
[__xla_fp64_comparison_param_2]; + ld.param.u64 %rd4, [__xla_fp64_comparison_param_3]; + ld.param.u64 %rd3, [__xla_fp64_comparison_param_4]; + mov.u32 %r4, %ntid.x; + mov.u32 %r5, %ctaid.x; + mov.u32 %r6, %tid.x; + mad.lo.s32 %r1, %r4, %r5, %r6; + cvt.s64.s32 %rd5, %r1; + setp.ge.u64 %p1, %rd5, %rd4; + @%p1 bra BB2_11; + + cvta.to.global.u64 %rd6, %rd1; + mul.wide.s32 %rd7, %r1, 8; + add.s64 %rd8, %rd6, %rd7; + cvta.to.global.u64 %rd9, %rd2; + add.s64 %rd10, %rd9, %rd7; + ld.global.f64 %fd1, [%rd10]; + ld.global.f64 %fd2, [%rd8]; + abs.f64 %fd3, %fd2; + setp.le.f64 %p2, %fd3, 0d7FF0000000000000; + @%p2 bra BB2_3; + + abs.f64 %fd5, %fd1; + setp.gtu.f64 %p3, %fd5, 0d7FF0000000000000; + @%p3 bra BB2_11; + +BB2_3: + { + .reg .b32 %temp; + mov.b64 {%temp, %r2}, %fd2; + } + and.b32 %r7, %r2, 2147483647; + setp.ne.s32 %p4, %r7, 2146435072; + @%p4 bra BB2_8; + + { + .reg .b32 %temp; + mov.b64 {%r8, %temp}, %fd2; + } + setp.ne.s32 %p5, %r8, 0; + @%p5 bra BB2_8; + + { + .reg .b32 %temp; + mov.b64 {%temp, %r3}, %fd1; + } + and.b32 %r9, %r3, 2147483647; + setp.ne.s32 %p6, %r9, 2146435072; + @%p6 bra BB2_8; + + { + .reg .b32 %temp; + mov.b64 {%r10, %temp}, %fd1; + } + setp.ne.s32 %p7, %r10, 0; + @%p7 bra BB2_8; + + shr.u32 %r11, %r2, 31; + cvt.u16.u32 %rs1, %r11; + shr.u32 %r12, %r3, 31; + cvt.u16.u32 %rs2, %r12; + setp.eq.s16 %p8, %rs1, %rs2; + @%p8 bra BB2_11; + +BB2_8: + sub.f64 %fd6, %fd2, %fd1; + abs.f64 %fd7, %fd6; + abs.f64 %fd8, %fd1; + max.f64 %fd9, %fd3, %fd8; + add.f64 %fd10, %fd9, 0d3FF0000000000000; + div.rn.f64 %fd4, %fd7, %fd10; + cvt.f64.f32 %fd11, %f1; + setp.gt.f64 %p9, %fd4, %fd11; + @%p9 bra BB2_10; + + abs.f64 %fd12, %fd4; + setp.le.f64 %p10, %fd12, 0d7FF0000000000000; + @%p10 bra BB2_11; + +BB2_10: + cvta.to.global.u64 %rd11, %rd3; + atom.global.add.u32 %r13, [%rd11], 1; + +BB2_11: + ret; +} + + // .globl __xla_int8_comparison +.visible .entry __xla_int8_comparison( + .param .u64 __xla_int8_comparison_param_0, + .param .u64 __xla_int8_comparison_param_1, + .param .f32 __xla_int8_comparison_param_2, + .param .u64 __xla_int8_comparison_param_3, + .param .u64 __xla_int8_comparison_param_4 +) +{ + .reg .pred %p<10>; + .reg .f32 %f<42>; + .reg .b32 %r<23>; + .reg .b64 %rd<12>; + + + ld.param.u64 %rd2, [__xla_int8_comparison_param_0]; + ld.param.u64 %rd3, [__xla_int8_comparison_param_1]; + ld.param.f32 %f5, [__xla_int8_comparison_param_2]; + ld.param.u64 %rd4, [__xla_int8_comparison_param_3]; + ld.param.u64 %rd5, [__xla_int8_comparison_param_4]; + cvta.to.global.u64 %rd1, %rd5; + mov.u32 %r4, %ntid.x; + mov.u32 %r5, %ctaid.x; + mov.u32 %r6, %tid.x; + mad.lo.s32 %r1, %r4, %r5, %r6; + cvt.s64.s32 %rd6, %r1; + setp.ge.u64 %p1, %rd6, %rd4; + @%p1 bra BB3_13; + + cvta.to.global.u64 %rd7, %rd2; + mul.wide.s32 %rd8, %r1, 4; + add.s64 %rd9, %rd7, %rd8; + cvta.to.global.u64 %rd10, %rd3; + add.s64 %rd11, %rd10, %rd8; + ld.global.u32 %r2, [%rd9]; + cvt.s32.s8 %r7, %r2; + cvt.rn.f32.s32 %f6, %r7; + ld.global.u32 %r3, [%rd11]; + cvt.s32.s8 %r8, %r3; + cvt.rn.f32.s32 %f7, %r8; + sub.f32 %f8, %f6, %f7; + abs.f32 %f9, %f8; + abs.f32 %f10, %f6; + abs.f32 %f11, %f7; + max.f32 %f12, %f10, %f11; + add.f32 %f13, %f12, 0f3F800000; + div.rn.f32 %f1, %f9, %f13; + setp.gt.f32 %p2, %f1, %f5; + @%p2 bra BB3_3; + + abs.f32 %f14, %f1; + setp.le.f32 %p3, %f14, 0f7F800000; + @%p3 bra BB3_4; + +BB3_3: + atom.global.add.u32 %r9, [%rd1], 1; + +BB3_4: + shr.u32 %r10, %r3, 8; + shr.u32 %r11, %r2, 8; + cvt.s32.s8 %r12, %r11; + cvt.rn.f32.s32 %f15, %r12; + cvt.s32.s8 %r13, %r10; + cvt.rn.f32.s32 %f16, %r13; + sub.f32 %f17, 
%f15, %f16; + abs.f32 %f18, %f17; + abs.f32 %f19, %f15; + abs.f32 %f20, %f16; + max.f32 %f21, %f19, %f20; + add.f32 %f22, %f21, 0f3F800000; + div.rn.f32 %f2, %f18, %f22; + setp.gt.f32 %p4, %f2, %f5; + @%p4 bra BB3_6; + + abs.f32 %f23, %f2; + setp.le.f32 %p5, %f23, 0f7F800000; + @%p5 bra BB3_7; + +BB3_6: + atom.global.add.u32 %r14, [%rd1], 1; + +BB3_7: + shr.u32 %r15, %r3, 16; + shr.u32 %r16, %r2, 16; + cvt.s32.s8 %r17, %r16; + cvt.rn.f32.s32 %f24, %r17; + cvt.s32.s8 %r18, %r15; + cvt.rn.f32.s32 %f25, %r18; + sub.f32 %f26, %f24, %f25; + abs.f32 %f27, %f26; + abs.f32 %f28, %f24; + abs.f32 %f29, %f25; + max.f32 %f30, %f28, %f29; + add.f32 %f31, %f30, 0f3F800000; + div.rn.f32 %f3, %f27, %f31; + setp.gt.f32 %p6, %f3, %f5; + @%p6 bra BB3_9; + + abs.f32 %f32, %f3; + setp.le.f32 %p7, %f32, 0f7F800000; + @%p7 bra BB3_10; + +BB3_9: + atom.global.add.u32 %r19, [%rd1], 1; + +BB3_10: + shr.s32 %r20, %r2, 24; + cvt.rn.f32.s32 %f33, %r20; + shr.s32 %r21, %r3, 24; + cvt.rn.f32.s32 %f34, %r21; + sub.f32 %f35, %f33, %f34; + abs.f32 %f36, %f35; + abs.f32 %f37, %f33; + abs.f32 %f38, %f34; + max.f32 %f39, %f37, %f38; + add.f32 %f40, %f39, 0f3F800000; + div.rn.f32 %f4, %f36, %f40; + setp.gt.f32 %p8, %f4, %f5; + @%p8 bra BB3_12; + + abs.f32 %f41, %f4; + setp.le.f32 %p9, %f41, 0f7F800000; + @%p9 bra BB3_13; + +BB3_12: + atom.global.add.u32 %r22, [%rd1], 1; + +BB3_13: + ret; } )"; @@ -364,9 +578,9 @@ static StatusOr DeviceCompare(se::Stream* stream, uint64 buffer_size = lhs_typed.ElementCount(); TF_ASSIGN_OR_RETURN(absl::Span compiled_ptx, - se::cuda::CompilePtxOrGetCached( - executor->device_ordinal(), buffer_compare_ptx, - PtxOptsFromConfig(config))); + se::CompileGpuAsmOrGetCached(executor->device_ordinal(), + buffer_compare_ptx, + PtxOptsFromConfig(config))); TF_ASSIGN_OR_RETURN( std::unique_ptr> comparison_kernel, @@ -472,6 +686,9 @@ StatusOr BufferComparator::CompareEqual(se::Stream* stream, case xla::F64: return CompareEqualParameterized( stream, lhs, rhs, shape_, config_, "__xla_fp64_comparison"); + case xla::S8: + return CompareEqualParameterized( + stream, lhs, rhs, shape_, config_, "__xla_int8_comparison"); default: return Unimplemented("Unimplemented element type"); } diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc index 139e4204304..0f547111096 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc @@ -178,6 +178,13 @@ TEST_F(BufferComparatorTest, TestNumbers) { EXPECT_TRUE(CompareEqualFloatBuffers({0.9}, {1})); EXPECT_TRUE(CompareEqualFloatBuffers({9}, {10})); EXPECT_TRUE(CompareEqualFloatBuffers({10}, {9})); + + EXPECT_TRUE(CompareEqualFloatBuffers({200}, {201})); + EXPECT_FALSE(CompareEqualFloatBuffers({0}, {10})); + EXPECT_TRUE(CompareEqualFloatBuffers({9}, {10})); + EXPECT_TRUE(CompareEqualFloatBuffers({90}, {100})); + EXPECT_TRUE(CompareEqualFloatBuffers({100}, {90})); + EXPECT_FALSE(CompareEqualFloatBuffers({-128}, {127})); } TEST_F(BufferComparatorTest, TestMultiple) { @@ -231,6 +238,23 @@ TEST_F(BufferComparatorTest, TestMultiple) { rhs[i] = 0; } } + + { + EXPECT_TRUE(CompareEqualFloatBuffers({20, 30, 40, 50, 60}, + {21, 31, 41, 51, 61})); + std::vector lhs(200); + std::vector rhs(200); + for (int i = 0; i < 200; i++) { + EXPECT_TRUE(CompareEqualFloatBuffers(lhs, rhs)) + << "should be the same at index " << i; + lhs[i] = 3; + rhs[i] = 5; + EXPECT_FALSE(CompareEqualFloatBuffers(lhs, rhs)) + << "should be the different 
at index " << i; + lhs[i] = 0; + rhs[i] = 0; + } + } } } // namespace diff --git a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc index 60301b4de64..2fe359861f8 100644 --- a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc @@ -22,7 +22,6 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/memory/memory.h" -#include "third_party/gpus/cuda/include/cuda_runtime_api.h" #include "tensorflow/compiler/xla/refcounting_hash_map.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc index 7a7ab6ba05f..c829fc92c87 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc @@ -37,7 +37,7 @@ limitations under the License. #include "tensorflow/core/platform/logger.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/util/proto/proto_utils.h" -#include "tensorflow/stream_executor/cuda/redzone_allocator.h" +#include "tensorflow/stream_executor/gpu/redzone_allocator.h" namespace xla { namespace gpu { @@ -136,13 +136,13 @@ void PrintPlatformInfo(const se::Stream* stream) { // // `name` is a user-friendly name for the set of redzones being checked, e.g. // "input/output" or "scratch". -StatusOr CheckRedzones(const se::cuda::RedzoneAllocator& allocator, +StatusOr CheckRedzones(const se::RedzoneAllocator& allocator, se::Stream* stream, absl::string_view name, const HloInstruction* instr, AutotuneResult* result) { XLA_SCOPED_LOGGING_TIMER_LEVEL("CudnnConvAlgorithmPicker checking redzones", 2); - using RedzoneCheckStatus = se::cuda::RedzoneAllocator::RedzoneCheckStatus; + using RedzoneCheckStatus = se::RedzoneAllocator::RedzoneCheckStatus; TF_ASSIGN_OR_RETURN(RedzoneCheckStatus redzone_check, allocator.CheckRedzones()); if (redzone_check.ok()) { @@ -271,29 +271,29 @@ StatusOr CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache( int64 rng_state = 0; - const auto initialize_buffer = [stream, &result_shape, - &rng_state](DeviceMemoryBase buffer) { - InitializeFloatBuffer(stream, result_shape.element_type(), &rng_state, - buffer); + const auto initialize_buffer = [&stream, &rng_state]( + DeviceMemoryBase buffer, + const Shape& buffer_shape) { + InitializeBuffer(stream, buffer_shape.element_type(), &rng_state, buffer); }; const HloModuleConfig& hlo_module_config = instr->GetModule()->config(); // Allocate space for the input, filter, and output of the convolution. 
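  // (The redzone allocator surrounds every allocation with guard bytes;
  // CheckRedzones above re-reads those bytes after a candidate algorithm runs,
  // so out-of-bounds writes by a bad algorithm are detected.)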
- se::cuda::RedzoneAllocator input_output_allocator( + se::RedzoneAllocator input_output_allocator( stream, allocator, PtxOptsFromConfig(hlo_module_config)); std::vector operand_buffers; for (const auto* operand : instr->operands()) { TF_ASSIGN_OR_RETURN(auto buffer, input_output_allocator.AllocateBytes( ShapeUtil::ByteSizeOf(operand->shape()))); - initialize_buffer(buffer); + initialize_buffer(buffer, operand->shape()); operand_buffers.push_back(buffer); } TF_ASSIGN_OR_RETURN(auto result_buffer, input_output_allocator.AllocateBytes( ShapeUtil::ByteSizeOf(result_shape))); - initialize_buffer(result_buffer); + initialize_buffer(result_buffer, result_shape); TF_ASSIGN_OR_RETURN(auto backend_config, instr->backend_config()); @@ -339,7 +339,7 @@ StatusOr CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache( continue; } - se::cuda::RedzoneAllocator scratch_allocator( + se::RedzoneAllocator scratch_allocator( stream, allocator, PtxOptsFromConfig(hlo_module_config)); se::dnn::ProfileResult profile_result; VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for " diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.h b/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.h deleted file mode 100644 index d4e51e86c1b..00000000000 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.h +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_PAD_FOR_TENSOR_CORES_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_PAD_FOR_TENSOR_CORES_H_ - -#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" - -namespace xla { -namespace gpu { - -// Adds padding to cudnn convolutions to make them run faster on GPUs with -// tensor cores. -// -// - f16 convolutions are padded to have input/output channel dimensions that -// are multiples of 8, so that we can use tensor cores. -// -// - f16 convolutions with 3 input channels and 32 or 64 output channels are -// padded to 4 input channels. There's a special-cased cudnn algorithm just -// for this. -// -// Don't run this pass on GPUs without tensor cores -- it will make them slower! -// -// TODO(jlebar): Also pad dots. 
-class CudnnConvPadForTensorCores : public HloModulePass { - public: - absl::string_view name() const override { return "cudnn-conv-pad-for-speed"; } - - StatusOr Run(HloModule* module) override; -}; - -} // namespace gpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_PAD_FOR_TENSOR_CORES_H_ diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc index 8596f640fc2..7b23935fbac 100755 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc @@ -619,43 +619,85 @@ CudnnConvBackendConfig GetDefaultBackendConfig() { return config; } +// Helper function to create a custom_call instruction to replace the given +// conv instruction +static StatusOr CreateCustomCallHelper(HloInstruction* conv) { + bool match; + Window window; + ConvolutionDimensionNumbers dnums; + HloInstruction* rhs; + HloInstruction* lhs; + + std::tie(match, window, dnums, rhs) = MatchBackwardInput(conv); + if (match) { + return CreateCudnnConv(kCudnnConvBackwardInputCallTarget, conv->shape(), + conv->mutable_operand(0), rhs, window, dnums, + conv->feature_group_count(), conv->metadata()); + } + + std::tie(match, window, dnums, lhs) = MatchBackwardFilter(conv); + if (match) { + return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, conv->shape(), + lhs, conv->mutable_operand(1), window, dnums, + conv->feature_group_count(), conv->metadata()); + } + + // If all else fails, try a forward convolution. + if (CanImplementAsCudnnForwardConv(conv)) { + if (primitive_util::IsIntegralType( + conv->operand(0)->shape().element_type())) { + // In addition to replacing a convolution instruction with + // a custom call, integer convolutions must have this pattern to match + // CuDNN semantics: + // conv( + // convert(int8_x), convert(int8_y)) + // We transform it to: + // custom_call(int8_x, int8_y, target=cudnnConvolutionForward) + // + // We will error out, if the pattern is not found for integer + // convolution. + const auto is_int8_to_int32_cast = + [](const HloInstruction* instr) -> bool { + return (instr->opcode() == HloOpcode::kConvert && + instr->operand(0)->shape().element_type() == S8 && + instr->shape().element_type() == S32); + }; + HloInstruction* input_convert = conv->mutable_operand(0); + HloInstruction* kernel_convert = conv->mutable_operand(1); + if (conv->shape().element_type() != S32 || + !is_int8_to_int32_cast(input_convert) || + !is_int8_to_int32_cast(kernel_convert)) { + return Unimplemented( + "Integer convolutions for CuDNN must have this pattern: " + "conv(convert(int8_x), " + "convert(int8_y))"); + } + // Bypass the convert for both inputs. + TF_RETURN_IF_ERROR(conv->ReplaceOperandWithDifferentShape( + 0, input_convert->mutable_operand(0))); + TF_RETURN_IF_ERROR( + conv->parent()->RemoveInstructionAndUnusedOperands(input_convert)); + TF_RETURN_IF_ERROR(conv->ReplaceOperandWithDifferentShape( + 1, kernel_convert->mutable_operand(0))); + TF_RETURN_IF_ERROR( + conv->parent()->RemoveInstructionAndUnusedOperands(kernel_convert)); + } + return CreateCudnnConv(kCudnnConvForwardCallTarget, conv->shape(), + conv->mutable_operand(0), conv->mutable_operand(1), + conv->window(), + conv->convolution_dimension_numbers(), + conv->feature_group_count(), conv->metadata()); + } + + return nullptr; +} + // Tries to rewrite a single convolution into a call to cudnn. 
StatusOr RunOnInstruction(HloInstruction* conv) { CHECK_EQ(conv->opcode(), HloOpcode::kConvolution); - HloInstruction* custom_call = [&]() -> HloInstruction* { - bool match; - Window window; - ConvolutionDimensionNumbers dnums; - HloInstruction* rhs; - HloInstruction* lhs; - - std::tie(match, window, dnums, rhs) = MatchBackwardInput(conv); - if (match) { - return CreateCudnnConv(kCudnnConvBackwardInputCallTarget, conv->shape(), - conv->mutable_operand(0), rhs, window, dnums, - conv->feature_group_count(), conv->metadata()); - } - - std::tie(match, window, dnums, lhs) = MatchBackwardFilter(conv); - if (match) { - return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, conv->shape(), - lhs, conv->mutable_operand(1), window, dnums, - conv->feature_group_count(), conv->metadata()); - } - - // If all else fails, try a forward convolution. - if (CanImplementAsCudnnForwardConv(conv)) { - return CreateCudnnConv(kCudnnConvForwardCallTarget, conv->shape(), - conv->mutable_operand(0), conv->mutable_operand(1), - conv->window(), - conv->convolution_dimension_numbers(), - conv->feature_group_count(), conv->metadata()); - } - - return nullptr; - }(); - + TF_ASSIGN_OR_RETURN(HloInstruction * custom_call, + CreateCustomCallHelper(conv)); if (custom_call == nullptr) { return false; } @@ -666,8 +708,8 @@ StatusOr RunOnInstruction(HloInstruction* conv) { VLOG(1) << "Replacing convolution " << conv->ToString() << " with " << custom_call->ToString(); - // The CustomCall returns a tuple (conv_result, scratch_memory). Extract out - // the conv result and replace `conv` with it. + // The CustomCall returns a tuple (conv_result, scratch_memory). Extract + // out the conv result and replace `conv` with it. TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction( conv, HloInstruction::CreateGetTupleElement(conv->shape(), custom_call, 0))); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h index d8ec72c27ba..77b57c910c9 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h @@ -24,6 +24,14 @@ namespace gpu { // Rewrites plain convolutions, backwards-filter convolutions, and // backwards-input convolutions into CustomCall HLOs that call into cuDNN. +// For integer convolution, it requires the following pattern: +// conv( +// convert(int8_x), convert(int8_y)) +// We transform it to: +// custom_call(int8_x, int8_y, target=cudnnForwardConvolution) +// Note that this pattern is necessary but not sufficient to map convolutions +// to CuDNN. More patterns will be matched in cudnn_fused_conv_rewriter. 
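// For example (shapes illustrative), a graph of the form
//   %conv = s32[1,32,9,9] convolution(convert(%int8_input), convert(%int8_filter)), ...
// becomes
//   %custom-call = (s32[1,32,9,9], u8[...]) custom-call(%int8_input, %int8_filter),
//       custom_call_target="__cudnn$convForward"
// followed by a get-tuple-element that extracts the convolution result.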
+ class CudnnConvRewriter : public HloModulePass { public: absl::string_view name() const override { return "cudnn-conv-rewriter"; } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc index 362d8d13aab..815963bfa9f 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc @@ -711,6 +711,21 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveConstantFilter) { 0)); } +// Check that a forward convolution instruction with int8 inputs is not allowed +TEST_F(CudnnConvRewriterTest, TestForwardInt8Convolution) { + const string module_str = absl::StrFormat(R"( + HloModule Test + + ENTRY Test { + input = s8[1,2,3,3] parameter(0) + filter = s8[3,3,2,5] parameter(1) + + ROOT conv = s8[1,5,3,3] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 + })"); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + + ASSERT_FALSE(CudnnConvRewriter().Run(m.get()).ok()); +} } // anonymous namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc index aca7307e0c2..b2cac986761 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc @@ -26,8 +26,12 @@ namespace xla { namespace gpu { namespace { -// Describes a matched pattern: +// Describes matched patterns: // max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias)); +// for floating point types or +// max(0, alpha1 * conv(int8_x, int8_w) + alpha2 * +// * side_input + broadcast(bias)); +// for int8. // Where side_input has the shape of output buffer, and bias is a 1D array with // the dimension of number of output features. struct ConvWithRelu { @@ -39,6 +43,13 @@ struct ConvWithRelu { HloConstantInstruction* alpha_side_input; }; +// The pattern we want to match: +// max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias)); +// or +// max(0, alpha1 * conv(int8_x, int8_w) + alpha2 * +// * side_input + broadcast(bias)); +// With its variants involving commute/reassociation of adds, multiplies, and +// max, and omission of alpha1, side_input, alpha2, or bias. absl::optional FindConvWithRelu(HloInstruction* instr) { using match::Add; using match::AddAnyOrder; @@ -50,12 +61,6 @@ absl::optional FindConvWithRelu(HloInstruction* instr) { using match::MultiplyAnyOrder; using match::Op; - // The pattern we want to match: - // max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias)); - // - // With its variants involving commute/reassociation of adds, multiplies, and - // max, and omission of alpha1, side_input, alpha2, or bias. - HloInstruction* relu_input; // Match max(0, relu_input). @@ -149,6 +154,14 @@ absl::optional FindConvWithRelu(HloInstruction* instr) { return absl::nullopt; } + // In order to map to cudnnConvolutionBiasActivationForward for int8, the + // convolution output is float, i.e. conv(int8_x, int8_w) + if (conv->operand(0)->shape().element_type() == xla::S8) { + if (conv->shape().tuple_shapes(0).element_type() != xla::F32) { + return absl::nullopt; + } + } + if (bias_broadcast) { // TODO(timshen): handle bias_broadcast_instr->dimensions() == {}. 
if (bias_broadcast_instr->dimensions().size() != 1) { @@ -174,7 +187,6 @@ StatusOr> TryRewriteToCudnnForwardRelu( auto conv = match.conv; HloComputation* computation = conv->parent(); - PrimitiveType element_type = conv->operand(0)->shape().element_type(); const auto get_alpha_value = [](HloConstantInstruction* instr) -> StatusOr { @@ -204,13 +216,15 @@ StatusOr> TryRewriteToCudnnForwardRelu( auto bias = match.bias; if (!bias) { + PrimitiveType conv_output_type = + conv->shape().tuple_shapes(0).element_type(); auto zero = computation->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); + HloInstruction::CreateConstant(LiteralUtil::Zero(conv_output_type))); int64 num_output_feature = conv->shape().tuple_shapes(0).dimensions( conv->convolution_dimension_numbers().output_feature_dimension()); bias = computation->AddInstruction(HloInstruction::CreateBroadcast( - ShapeUtil::MakeShapeWithDescendingLayout(element_type, + ShapeUtil::MakeShapeWithDescendingLayout(conv_output_type, {num_output_feature}), zero, {})); } @@ -242,9 +256,9 @@ StatusOr> TryRewriteToCudnnForwardRelu( new_conv, 0); } -} // namespace - -StatusOr CudnnFusedConvRewriter::Run(HloModule* module) { +// Fuse bias/scaling/ReLU with convolution custom call with floating point +// output +StatusOr RunFuseBiasSideActivation(HloModule* module) { bool changed = false; for (HloComputation* computation : module->MakeNonfusionComputations()) { std::vector matches; @@ -277,5 +291,201 @@ StatusOr CudnnFusedConvRewriter::Run(HloModule* module) { return changed; } +// Describes a matched pattern: +// convert_or_clamp(get_tuple_element(custom_call(x,w, ...))); +// where the custom_call targets CuDNN convolution (either pure convolution or +// fused convolution). +struct ConvWithConvertOrClamp { + HloInstruction* convert_or_clamp; + HloInstruction* gte; + HloCustomCallInstruction* conv; +}; + +// The pattern we want to match: +// convert(clamp(broadcast(-128), (get_tuple_element(custom_call(int8_x, +// int8_w, ...)), broadcast(127)); +absl::optional FindConvWithClampAndConvertToInt8( + HloInstruction* instr) { + using match::Broadcast; + using match::Clamp; + using match::Convert; + using match::GetTupleElement; + using match::Op; + + HloInstruction* gte = nullptr; + HloInstruction* conv_instr = nullptr; + auto lower_pattern = Broadcast(match::ConstantScalar(-128)); + auto upper_pattern = Broadcast(match::ConstantScalar(127)); + auto pattern = Convert( + Clamp(lower_pattern, + GetTupleElement( + >e, Op(&conv_instr).WithOpcode(HloOpcode::kCustomCall), 0), + upper_pattern)); + + if (Match(instr, pattern)) { + if (conv_instr->operand(0)->shape().element_type() == xla::S8 && + instr->shape().element_type() == xla::S8) { + HloCustomCallInstruction* conv = + CastOrNull(conv_instr); + return ConvWithConvertOrClamp{instr, gte, conv}; + } + } + return absl::nullopt; +} + +// A help function to rewrite convert_or_clamp_or_other(gte(conv())) +// to gte(conv()). It bypasses convert_or_clamp_or_other +// and set the output data type on gte and conv. 
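// For example (illustrative), the rewrite turns
//   %cast = s8[...] convert(%gte)   ; %gte = s32[...] get-tuple-element(%custom-call), index=0
// into just
//   %gte = s8[...] get-tuple-element(%custom-call), index=0
// with the element type of the custom call's first tuple element updated to
// match.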
+Status RewriteForConvertOrClampImpl(ConvWithConvertOrClamp match) { + auto conv = match.conv; + auto gte = match.gte; + auto convert_or_clamp = match.convert_or_clamp; + + // Change type on conv and gte + auto convert_out_type = convert_or_clamp->shape().element_type(); + conv->mutable_shape()->mutable_tuple_shapes(0)->set_element_type( + convert_out_type); + gte->mutable_shape()->set_element_type(convert_out_type); + + // Remove clamp/convert and so on and just keep + // get_tuple_element(custom_call(x,w, ...)) + TF_RETURN_IF_ERROR(convert_or_clamp->ReplaceAllUsesWithDifferentShape(gte)); + TF_RETURN_IF_ERROR( + conv->parent()->RemoveInstructionAndUnusedOperands(convert_or_clamp)); + return Status::OK(); +} + +Status RewriteForFinalOutput(ConvWithConvertOrClamp match) { + // When the matched clamp has a single user, which is convert, we + // will absorb it, if + // 1. the side_input matches a convert(int8_side_input), or + // 2. there is no side input + const auto is_one_to_one_X_to_Y_cast = [](const HloInstruction* instr, + PrimitiveType X, + PrimitiveType Y) -> bool { + return (instr->opcode() == HloOpcode::kConvert && + instr->shape().element_type() == Y && instr->operand_count() == 1 && + instr->operand(0)->user_count() == 1 && + instr->operand(0)->shape().element_type() == X); + }; + + if (match.conv->operand_count() < 4) { + // Conv input #3 (zero based) is side_input, after x, w, and bias. + // Side input doesn't exist in this case. + TF_RETURN_IF_ERROR(RewriteForConvertOrClampImpl(match)); + } else if (is_one_to_one_X_to_Y_cast(match.conv->operand(3), S8, F32)) { + // If side_input has a convert_float_to_int8, absorb it as well. + auto side_converter = match.conv->mutable_operand(3); + TF_RETURN_IF_ERROR(side_converter->ReplaceAllUsesWithDifferentShape( + side_converter->mutable_operand(0))); + TF_RETURN_IF_ERROR( + side_converter->parent()->RemoveInstructionAndUnusedOperands( + side_converter)); + + TF_RETURN_IF_ERROR(RewriteForConvertOrClampImpl(match)); + } + return Status::OK(); +} + +// Fuse the clamp/convert pattern with the int8 convolution custom call +// (either pure or fused) for int8 output +StatusOr RunFuseClamp(HloModule* module) { + bool changed = false; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + std::vector matches; + for (auto instr : computation->instructions()) { + auto match = FindConvWithClampAndConvertToInt8(instr); + if (match.has_value()) { + matches.push_back(*match); + } + } + for (const ConvWithConvertOrClamp& match : matches) { + TF_RETURN_IF_ERROR(RewriteForFinalOutput(match)); + changed = true; + } + + // Report error for any convolution still having int32 output. + // Although int32 output convolution will trigger other sanity check errors + // later, we want to give specific error message here. + for (auto instr : computation->instructions()) { + if (auto call = DynCast(instr)) { + if ((call->custom_call_target() == kCudnnConvForwardCallTarget || + call->custom_call_target() == + kCudnnConvBiasActivationForwardCallTarget) && + call->shape().tuple_shapes(0).element_type() == xla::S32) { + return Unimplemented( + "Integer convolutions for CuDNN must have float or int8 output. 
" + "Use convert to cast output to float or the following pattern to " + "int8: " + "clamp(broadcast(-128), conv(int8_x, int8_w, ...), " + "broadcast(127))."); + } + } + } + } + return changed; +} + +// The pattern we want to match: +// convert(get_tuple_element(custom_call())); +absl::optional FindConvWithConvertToFloat( + HloInstruction* instr) { + using match::Convert; + using match::GetTupleElement; + using match::Op; + + HloInstruction* gte = nullptr; + HloInstruction* conv_instr = nullptr; + auto pattern = + Convert(GetTupleElement( + >e, + Op(&conv_instr) + .WithOpcode(HloOpcode::kCustomCall) + .WithCustomCallTarget(kCudnnConvForwardCallTarget), + 0) + .WithShape(match::Shape().WithElementType(xla::S32))) + .WithShape(match::Shape().WithElementType(xla::F32)); + if (Match(instr, pattern)) { + HloCustomCallInstruction* conv = + CastOrNull(conv_instr); + return ConvWithConvertOrClamp{instr, gte, conv}; + } + return absl::nullopt; +} + +// Transform +// convert(GetTupleElement(custom_call(int8_x, int8_w))) +// to +// GetTupleElement(custom_call(int8_x, int8_w)) +StatusOr RunFuseConvertToFloat(HloModule* module) { + bool changed = false; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + std::vector matches; + for (auto instr : computation->instructions()) { + auto match = FindConvWithConvertToFloat(instr); + if (match.has_value()) { + matches.push_back(*match); + } + } + + for (const ConvWithConvertOrClamp& match : matches) { + TF_RETURN_IF_ERROR(RewriteForConvertOrClampImpl(match)); + changed = true; + } + } + return changed; +} +} // namespace + +StatusOr CudnnFusedConvRewriter::Run(HloModule* module) { + TF_ASSIGN_OR_RETURN(bool fused_for_convert_to_float, + RunFuseConvertToFloat(module)); + + TF_ASSIGN_OR_RETURN(bool fused_for_bias, RunFuseBiasSideActivation(module)); + + TF_ASSIGN_OR_RETURN(bool fused_for_clamp, RunFuseClamp(module)); + + return fused_for_convert_to_float || fused_for_bias || fused_for_clamp; +} } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h index 613ed8dbdc3..e3602b70d29 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h @@ -22,6 +22,40 @@ limitations under the License. namespace xla { namespace gpu { +// Rewrite the custom call targeting cudnnConvolutionForward to +// cudnnConvolutionBiasActivationForward by fusing applicable point-wise +// operations following forward convolution. This transform must run after +// cudnn_conv_rewriter. 
+// It is straightforward for floating point convolutions: +// transforming +// max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias)) +// to +// cudnnConvolutionBiasActivationForward(x, w, bias, alpha1, alpha2, side) +// +// Integer convolution requires additional patterns to match CuDNN semantics: +// #1 from +// cast(clamp<-128, 127>(conv(int8_x, int8_w))) +// to +// cudnnConvolutionForward(int8_x, int8_w) +// or #2 from +// cast(conv(int8_x, int8_w)) +// to +// cudnnConvolutionForward(int8_x, int8_w) +// or #3 from +// cast(clamp<-128, 127>(max(0, alpha1 * +// cast(conv(int8_x, int8_w)) + +// alpha2 * cast(int8_side) + +// broadcast(bias))) +// to +// cudnnConvolutionBiasActivationForward(int8_x, int8_w, bias, alpha1, +// alpha2, int8_side) +// or #4 from +// max(0, alpha1 * cast(conv(int8_x, int8_w)) + +// alpha2 * float_side + broadcast(bias)) +// to +// cudnnConvolutionBiasActivationForward(int8_x, int8_w, bias, alpha1, +// alpha2, float_side) + class CudnnFusedConvRewriter : public HloModulePass { public: absl::string_view name() const override { diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc index b621880f639..bd6aa6e715a 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc @@ -13,9 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h" + #include "absl/strings/str_replace.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/platform/test.h" @@ -26,7 +30,7 @@ namespace { using ::testing::HasSubstr; using ::testing::Not; -class CudnnFusedConvRewriterTest : public HloTestBase { +class CudnnFusedConvRewriterTest : public GpuCodegenTest { protected: string GetOptimizedHlo(absl::string_view hlo_string) { return backend() @@ -53,6 +57,19 @@ class CudnnFusedConvRewriterTest : public HloTestBase { } } + void TestClamp(absl::string_view pre_hlo_string, + absl::string_view post_hlo_string) { + string alpha_conv_scalar, alpha_side_input_scalar; + string elementwise_type; + + string optimized_hlo_string = GetOptimizedHlo(pre_hlo_string); + EXPECT_THAT(optimized_hlo_string, Not(HasSubstr("Convert"))); + EXPECT_THAT(optimized_hlo_string, HasSubstr("__cudnn$conv")); + EXPECT_TRUE(RunAndCompare(pre_hlo_string, ErrorSpec{0.01})) + << pre_hlo_string; + MatchOptimizedHlo(pre_hlo_string, post_hlo_string); + } + void TestNotMatchWithAllTypes(absl::string_view hlo_string) { for (absl::string_view type : {"f16", "f32", "f64"}) { const string hlo_with_new_type = @@ -349,6 +366,350 @@ TEST_F(CudnnFusedConvRewriterTest, TestPreservesFeatureGroupCount) { EXPECT_TRUE(RunAndCompare(kHloString, ErrorSpec{0.01})); } +TEST_F(CudnnFusedConvRewriterTest, TestConvInt8ToInt8) { + // max(0, clamp(conv(x, w)))); for int8 + TestClamp( + // pre_hlo + R"( + HloModule Test + + ENTRY Test { + zero = s8[] constant(0) + zeros = s8[1,32,9,9] broadcast(zero), dimensions={} + + input = s8[1,17,9,9] parameter(0) + filter = 
s8[3,3,17,32] parameter(1) + + inputs32 = s32[1,17,9,9] convert(input) + filters32 = s32[3,3,17,32] convert(filter) + + conv = s32[1,32,9,9] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 + + lower = s32[] constant(-128) + lowers = s32[1,32,9,9] broadcast(lower), dimensions={} + upper = s32[] constant(127) + uppers = s32[1,32,9,9] broadcast(upper), dimensions={} + + clamp = s32[1,32,9,9] clamp(lowers, conv, uppers) + + convert = s8[1,32,9,9] convert(clamp) + ROOT relu = s8[1,32,9,9] maximum(zeros, convert) + })", + // post_hlo + R"( + ; CHECK-LABEL: ENTRY %Test (input: s8[1,17,9,9], filter: s8[3,3,17,32]) -> s8[1,32,9,9] { + ; CHECK: %custom-call{{(\.[0-9])?}} = (s8[1,32,9,9]{1,3,2,0}, u8[{{[0-9]*}}]{0}) custom-call(%fusion{{(\.[0-9])?}}, %fusion{{(\.[0-9])?}}), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convForward", backend_config= + )"); +} + +TEST_F(CudnnFusedConvRewriterTest, TestConvInt8ToFloat) { + // convert(conv(convert(int8_x), + // convert(int8_w))); + TestClamp( + // pre_hlo + R"( + HloModule Test + + ENTRY Test { + input = s8[1,17,9,9] parameter(0) + filter = s8[3,3,17,32] parameter(1) + + inputs32 = s32[1,17,9,9] convert(input) + filters32 = s32[3,3,17,32] convert(filter) + + conv = s32[1,32,9,9] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 + + ROOT convert = f32[1,32,9,9] convert(conv) + })", + // post_hlo + R"( + ; CHECK-LABEL: ENTRY %Test (input: s8[1,17,9,9], filter: s8[3,3,17,32]) -> f32[1,32,9,9] { + ; CHECK: %custom-call{{(\.[0-9])?}} = (f32[1,32,9,9]{1,3,2,0}, u8[{{[0-9]+}}]{0}) custom-call(%fusion{{(\.[0-9])?}}, %fusion{{(\.[0-9])?}}), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convForward", backend_config= + )"); +} + +TEST_F(CudnnFusedConvRewriterTest, TestFusedConvInt8ToInt8) { + // clamp(max(0, conv(x, w)+bias)); for int8 + TestClamp( + // pre_hlo + R"( + HloModule Test + + ENTRY Test { + zero = f32[] constant(0) + zeros = f32[1,3,3,64] broadcast(zero), dimensions={} + + input = s8[1,3,3,64] parameter(0) + filter = s8[3,3,64,64] parameter(1) + bias = f32[64] parameter(2) + + inputs32 = s32[1,3,3,64] convert(input) + filters32 = s32[3,3,64,64] convert(filter) + + conv = s32[1,3,3,64] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + + convfloat = f32[1,3,3,64] convert(conv) + broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3} + add1 = f32[1,3,3,64] add(convfloat, broadcasted_bias) + relu = f32[1,3,3,64] maximum(zeros, add1) + + lower = f32[] constant(-128) + lowers = f32[1,3,3,64] broadcast(lower), dimensions={} + upper = f32[] constant(127) + uppers = f32[1,3,3,64] broadcast(upper), dimensions={} + + clamp = f32[1,3,3,64] clamp(lowers, relu, uppers) + + ROOT convert = s8[1,3,3,64] convert(clamp) + })", + // post_hlo + R"( + ; CHECK-LABEL: ENTRY %Test (input: s8[1,3,3,64], filter: s8[3,3,64,64], bias: f32[64]) -> s8[1,3,3,64] + ; CHECK: %custom-call{{(\.[0-9])?}} = (s8[1,3,3,64]{3,2,1,0}, u8[{{[0-9]+}}]{0}) custom-call(%input, %copy{{(\.[0-9])?}}, %bias), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBiasActivationForward", backend_config= + ; CHECK-NEXT: ROOT %get-tuple-element{{(\.[0-9])?}} = s8[1,3,3,64]{3,2,1,0} get-tuple-element(%custom-call{{(\.[0-9])?}}), index=0 + )"); +} + +TEST_F(CudnnFusedConvRewriterTest, 
TestFusedConvInt8ToFloat) { + // max(0, convert(conv(int8_x), + // conv(int8_w))+float_bias)); int8 to float via bias. + TestClamp( + // pre_hlo + R"( + HloModule Test + + ENTRY Test { + zero = f32[] constant(0) + zeros = f32[1,3,3,64] broadcast(zero), dimensions={} + + input = s8[1,3,3,64] parameter(0) + filter = s8[3,3,64,64] parameter(1) + bias = f32[64] parameter(2) + + inputs32 = s32[1,3,3,64] convert(input) + filters32 = s32[3,3,64,64] convert(filter) + + conv = s32[1,3,3,64] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + + convfloat = f32[1,3,3,64] convert(conv) + broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3} + add1 = f32[1,3,3,64] add(convfloat, broadcasted_bias) + ROOT relu = f32[1,3,3,64] maximum(zeros, add1) + })", + // post_hlo + R"( + ; CHECK-LABEL: ENTRY %Test (input: s8[1,3,3,64], filter: s8[3,3,64,64], bias: f32[64]) -> f32[1,3,3,64] { + ; CHECK: %custom-call{{(\.[0-9])?}} = (f32[1,3,3,64]{3,2,1,0}, u8[{{[0-9]*}}]{0}) custom-call(%input, %copy{{(\.[0-9])?}}, %bias), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBiasActivationForward", backend_config= + ; CHECK-NEXT: ROOT %get-tuple-element{{(\.[0-9])?}} = f32[1,3,3,64]{3,2,1,0} get-tuple-element(%custom-call{{(\.[0-9])?}}), index=0 + )"); +} + +TEST_F(CudnnFusedConvRewriterTest, + TestFusedConvWithScaledInt8SideInputBiasInt8ToInt8) { + // clamp(max(0, alpha_conv * conv(x, w) + alpha_side * + // convert(int8_side_input) + bias)); for int8 + TestClamp( + // pre_hlo + R"( + HloModule Test + + ENTRY Test { + zero = f32[] constant(0) + zeros = f32[1,3,3,64] broadcast(zero), dimensions={} + alpha_conv_scalar = f32[] constant(0.999994934) + alpha_conv = f32[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={} + alpha_side_input_scalar = f32[] constant(0.899994934) + alpha_side_input = f32[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={} + + input = s8[1,3,3,64] parameter(0) + filter = s8[3,3,64,64] parameter(1) + side_input = s8[1,3,3,64] parameter(2) + bias = f32[64] parameter(3) + + inputs32 = s32[1,3,3,64] convert(input) + filters32 = s32[3,3,64,64] convert(filter) + side_input_f32 = f32[1,3,3,64] convert(side_input) + + conv = s32[1,3,3,64] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + + convfloat = f32[1,3,3,64] convert(conv) + scaled_conv = f32[1,3,3,64] multiply(convfloat, alpha_conv) + scaled_side_input = f32[1,3,3,64] multiply(side_input_f32, alpha_side_input) + broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3} + add1 = f32[1,3,3,64] add(scaled_conv, broadcasted_bias) + add2 = f32[1,3,3,64] add(add1, scaled_side_input) + relu = f32[1,3,3,64] maximum(zeros, add2) + + lower = f32[] constant(-128) + lowers = f32[1,3,3,64] broadcast(lower), dimensions={} + upper = f32[] constant(127) + uppers = f32[1,3,3,64] broadcast(upper), dimensions={} + + clamp = f32[1,3,3,64] clamp(lowers, relu, uppers) + + ROOT convert = s8[1,3,3,64] convert(clamp) + })", + // post_hlo + R"( + ; CHECK-LABEL: ENTRY %Test (input: s8[1,3,3,64], filter: s8[3,3,64,64], side_input: s8[1,3,3,64], bias: f32[64]) -> s8[1,3,3,64] { + ; CHECK: %custom-call{{(\.[0-9])?}} = (s8[1,3,3,64]{3,2,1,0}, u8[{{[0-9]+}}]{0}) custom-call(%input, %copy{{(\.[0-9])?}}, %bias, %side_input), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBiasActivationForward", backend_config= + ; CHECK-NEXT: ROOT 
%get-tuple-element{{(\.[0-9])?}} = s8[1,3,3,64]{3,2,1,0} get-tuple-element(%custom-call{{(\.[0-9])?}}), index=0 + )"); +} + +TEST_F(CudnnFusedConvRewriterTest, + TestFusedConvWithScaledFloatSideInputBiasInt8ToInt8) { + // From: + // convert(clamp(max(0, alpha_conv * conv(x, w) + alpha_side * + // float_side_input + bias))); To: convert(clamp(conv(int8_x, int8_w, + // float_alpha_side, float_side_input, float_bias))); + TestClamp( + // pre_hlo + R"( + HloModule Test + + ENTRY Test { + zero = f32[] constant(0) + zeros = f32[1,3,3,64] broadcast(zero), dimensions={} + alpha_conv_scalar = f32[] constant(0.999994934) + alpha_conv = f32[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={} + alpha_side_input_scalar = f32[] constant(0.899994934) + alpha_side_input = f32[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={} + + input = s8[1,3,3,64] parameter(0) + filter = s8[3,3,64,64] parameter(1) + side_input = f32[1,3,3,64] parameter(2) + bias = f32[64] parameter(3) + + inputs32 = s32[1,3,3,64] convert(input) + filters32 = s32[3,3,64,64] convert(filter) + + conv = s32[1,3,3,64] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + + convfloat = f32[1,3,3,64] convert(conv) + scaled_conv = f32[1,3,3,64] multiply(convfloat, alpha_conv) + scaled_side_input = f32[1,3,3,64] multiply(side_input, alpha_side_input) + broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3} + add1 = f32[1,3,3,64] add(scaled_conv, broadcasted_bias) + add2 = f32[1,3,3,64] add(add1, scaled_side_input) + relu = f32[1,3,3,64] maximum(zeros, add2) + + lower = f32[] constant(-128) + lowers = f32[1,3,3,64] broadcast(lower), dimensions={} + upper = f32[] constant(127) + uppers = f32[1,3,3,64] broadcast(upper), dimensions={} + + clamp = f32[1,3,3,64] clamp(lowers, relu, uppers) + + ROOT convert = s8[1,3,3,64] convert(clamp) + })", + // post_hlo + R"( + ; CHECK-LABEL: ENTRY %Test (input: s8[1,3,3,64], filter: s8[3,3,64,64], side_input: f32[1,3,3,64], bias: f32[64]) -> s8[1,3,3,64] { + ; CHECK: %custom-call{{(\.[0-9])?}} = (f32[1,3,3,64]{3,2,1,0}, u8[{{[0-9]+}}]{0}) custom-call(%input, %copy{{(\.[0-9])?}}, %bias, %side_input), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBiasActivationForward", backend_config= + ; CHECK: ROOT %fusion = s8[1,3,3,64]{3,2,1,0} fusion(%get-tuple-element{{(\.[0-9])?}}), kind=kLoop, calls=%fused_computation + )"); +} + +TEST_F(CudnnFusedConvRewriterTest, + TestFusedConvWithScaledInt8SideInputBiasInt8ToFloat) { + // From: + // clamp(max(0, alpha_conv * conv(x, w) + alpha_side * + // convert(int8_side_input) + bias)); To: clamp(conv(int8_x, int8_w, + // float_alpha_side, convert(int8_side_input), float_bias)); + TestClamp( + // pre_hlo + R"( + HloModule Test + + ENTRY Test { + zero = f32[] constant(0) + zeros = f32[1,3,3,64] broadcast(zero), dimensions={} + alpha_conv_scalar = f32[] constant(0.999994934) + alpha_conv = f32[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={} + alpha_side_input_scalar = f32[] constant(0.899994934) + alpha_side_input = f32[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={} + + input = s8[1,3,3,64] parameter(0) + filter = s8[3,3,64,64] parameter(1) + side_input = s8[1,3,3,64] parameter(2) + bias = f32[64] parameter(3) + + inputs32 = s32[1,3,3,64] convert(input) + filters32 = s32[3,3,64,64] convert(filter) + side_input_f32 = f32[1,3,3,64] convert(side_input) + + conv = s32[1,3,3,64] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, 
dim_labels=b01f_01io->b01f, feature_group_count=1 + + convfloat = f32[1,3,3,64] convert(conv) + scaled_conv = f32[1,3,3,64] multiply(convfloat, alpha_conv) + scaled_side_input = f32[1,3,3,64] multiply(side_input_f32, alpha_side_input) + broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3} + add1 = f32[1,3,3,64] add(scaled_conv, broadcasted_bias) + add2 = f32[1,3,3,64] add(add1, scaled_side_input) + relu = f32[1,3,3,64] maximum(zeros, add2) + + lower = f32[] constant(-128) + lowers = f32[1,3,3,64] broadcast(lower), dimensions={} + upper = f32[] constant(127) + uppers = f32[1,3,3,64] broadcast(upper), dimensions={} + + ROOT clamp = f32[1,3,3,64] clamp(lowers, relu, uppers) + })", + // post_hlo + R"( + ; CHECK-LABEL: ENTRY %Test (input: s8[1,3,3,64], filter: s8[3,3,64,64], side_input: s8[1,3,3,64], bias: f32[64]) -> f32[1,3,3,64] { + ; CHECK: %side_input_f32 = f32[1,3,3,64]{3,2,1,0} convert(%side_input) + ; CHECK: %custom-call{{(\.[0-9])?}} = (f32[1,3,3,64]{3,2,1,0}, u8[{{[0-9]*}}]{0}) custom-call(%input, %copy{{(\.[0-9])?}}, %bias, %side_input_f32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBiasActivationForward", backend_config= + ; CHECK: ROOT %fusion = f32[1,3,3,64]{3,2,1,0} fusion(%get-tuple-element{{(\.[0-9])?}}), kind=kLoop, calls=%fused_computation + )"); +} + +TEST_F(CudnnFusedConvRewriterTest, TestConvInt8ToInt8NoClamp) { + // Check that integer convolution without clamp to int8 is not allowed. + // convert(custom_call(int32_x, int32_w, + // cudnnConvolutionForward)) + const string module_str = absl::StrFormat(R"( + HloModule Test + + ENTRY Test (input: s8[1,17,9,9], filter: s8[3,3,17,32]) -> s8[1,32,9,9] { + zero = s8[] constant(0) + zeros = s8[1,32,9,9]{3,2,1,0} broadcast(s8[] zero), dimensions={} + input = s8[1,17,9,9]{3,2,1,0} parameter(0) + filter = s8[3,3,17,32]{3,2,1,0} parameter(1) + custom-call = (s32[1,32,9,9]{3,2,1,0}, u8[0]{0}) custom-call(s8[1,17,9,9]{3,2,1,0} input, s8[3,3,17,32]{3,2,1,0} filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convForward", backend_config="{\"convResultScale\":1}" + get-tuple-element = s32[1,32,9,9]{3,2,1,0} get-tuple-element((s32[1,32,9,9]{3,2,1,0}, u8[0]{0}) custom-call), index=0 + convert = s8[1,32,9,9]{3,2,1,0} convert(s32[1,32,9,9]{3,2,1,0} get-tuple-element) + ROOT relu = s8[1,32,9,9]{3,2,1,0} maximum(s8[1,32,9,9]{3,2,1,0} zeros, s8[1,32,9,9]{3,2,1,0} convert) + })"); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + + ASSERT_FALSE(CudnnFusedConvRewriter().Run(m.get()).ok()); +} + +TEST_F(CudnnFusedConvRewriterTest, TestFusedConvInt8ToInt8NoClamp) { + // Although bias and so on are fused with forward convolution, + // it is still not allowed if the output is not clampped/converted to int8 + // max(0, alpha_conv * conv(x, w) + alpha_side * side_input + bias); for int8 + + const string module_str = absl::StrFormat(R"( + HloModule Test + + ENTRY Test (input: s8[1,17,9,9], filter: s8[3,3,17,32]) -> s8[1,32,9,9] { + zero = s8[] constant(0) + zeros = s8[1,32,9,9]{3,2,1,0} broadcast(s8[] zero), dimensions={} + input = s8[1,17,9,9]{3,2,1,0} parameter(0) + filter = s8[3,3,17,32]{3,2,1,0} parameter(1) + custom-call = (s32[1,32,9,9]{3,2,1,0}, u8[0]{0}) custom-call(s8[1,17,9,9]{3,2,1,0} input, s8[3,3,17,32]{3,2,1,0} filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convForward", backend_config="{\"convResultScale\":1}" + get-tuple-element = s32[1,32,9,9]{3,2,1,0} 
get-tuple-element((s32[1,32,9,9]{3,2,1,0}, u8[0]{0}) custom-call), index=0 + convert = s8[1,32,9,9]{3,2,1,0} convert(s32[1,32,9,9]{3,2,1,0} get-tuple-element) + ROOT relu = s8[1,32,9,9]{3,2,1,0} maximum(s8[1,32,9,9]{3,2,1,0} zeros, s8[1,32,9,9]{3,2,1,0} convert) + })"); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + + ASSERT_FALSE(CudnnFusedConvRewriter().Run(m.get()).ok()); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/cudnn_pad_for_convolutions.cc similarity index 52% rename from tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.cc rename to tensorflow/compiler/xla/service/gpu/cudnn_pad_for_convolutions.cc index 958e0b9c6e7..17c02b64db5 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_pad_for_convolutions.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_pad_for_convolutions.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -70,21 +70,26 @@ static HloInstruction* PadInstruction(HloInstruction* instr, HloInstruction::CreatePad(new_shape, instr, zero, pad_config)); } -// Modifies the given convolution to have the given LHS/RHS/result shapes. +// Modifies the given convolution to have the given input and result shapes. static Status PadConv(HloCustomCallInstruction* conv, - const Shape& new_lhs_shape, const Shape& new_rhs_shape, + absl::Span new_input_shapes, const Shape& new_result_shape) { CHECK_EQ(0, conv->shape().tuple_shapes(1).dimensions(0)) << "conv must use 0 scratch bytes, i.e. this pass must be run " "before CudnnConvAlgorithmPicker."; - - auto* lhs = conv->mutable_operand(0); - auto* rhs = conv->mutable_operand(1); - auto* new_lhs = PadInstruction(lhs, new_lhs_shape); - auto* new_rhs = PadInstruction(rhs, new_rhs_shape); + std::vector new_operands; + new_operands.reserve(conv->operand_count()); + for (int i = 0; i < conv->operand_count(); ++i) { + new_operands.push_back( + PadInstruction(conv->mutable_operand(i), new_input_shapes[i])); + } const Shape& result_shape = conv->shape().tuple_shapes(0); - CHECK(new_lhs != lhs || new_rhs != rhs) - << "We should have had to pad either LHS or RHS."; + + bool changed = false; + for (int i = 0; i < conv->operand_count(); ++i) { + changed |= (new_operands[i] != conv->mutable_operand(i)); + } + CHECK(changed) << "We should have had to pad at least one input operand."; auto add = [&](std::unique_ptr new_instr) { return conv->parent()->AddInstruction(std::move(new_instr)); @@ -93,10 +98,10 @@ static Status PadConv(HloCustomCallInstruction* conv, Shape new_conv_shape = ShapeUtil::MakeTupleShape( {new_result_shape, ShapeUtil::MakeShape(U8, {0})}); auto* new_conv = - add(conv->CloneWithNewOperands(new_conv_shape, {new_lhs, new_rhs})); + add(conv->CloneWithNewOperands(new_conv_shape, new_operands)); - // Slice the new conv result if necessary, keeping in mind that new_conv has - // tuple shape (new_result_shape, u8[0]). 
+ // Slice the new conv result if necessary, keeping in mind that new_conv + // has tuple shape (new_result_shape, u8[0]). if (!ShapeUtil::Equal(result_shape, new_result_shape)) { std::vector<int64> start_indices(result_shape.dimensions_size(), 0); std::vector<int64> end_indices(result_shape.dimensions().begin(), @@ -118,7 +123,61 @@ static Status PadConv(HloCustomCallInstruction* conv, return conv->parent()->ReplaceInstruction(conv, new_conv); } -static StatusOr<bool> PadForTensorCores(HloCustomCallInstruction* conv) { +static std::vector<HloCustomCallInstruction*> GetRelevantConvs( + HloComputation* comp) { + std::vector<HloCustomCallInstruction*> convs; + for (HloInstruction* instr : comp->instructions()) { + if (IsCustomCallToDnnConvolution(*instr)) { + convs.push_back(Cast<HloCustomCallInstruction>(instr)); + } + } + return convs; +} + +// This is the main function of the transform. It runs on a given custom-call +// node to a cuDNN convolution, calls resolve_pad_shapes to resolve the +// desired input/output feature map shapes, and adds the necessary padding and +// slicing nodes around them. +// +// resolve_pad_shapes points to a function. It takes conv, a custom-call +// instruction to a cuDNN convolution that may need padding, figures out the +// desired padded input and output tensor shapes, and stores them in +// new_input_shapes and new_result_shape. Note that new_input_shapes is a +// vector because there may be multiple input tensors. In addition to the +// status, the function returns true if padding is necessary and false +// otherwise. +static StatusOr<bool> ResolveAndPad( + HloCustomCallInstruction* conv, + StatusOr<bool> (*resolve_pad_shapes)(HloCustomCallInstruction* conv, + std::vector<Shape>* new_input_shapes, + Shape* new_result_shape)) { + std::vector<Shape> new_input_shapes; + Shape new_result_shape; + TF_ASSIGN_OR_RETURN(bool result, resolve_pad_shapes(conv, &new_input_shapes, + &new_result_shape)); + if (result) { + TF_RETURN_IF_ERROR(PadConv(conv, new_input_shapes, new_result_shape)); + return true; + } + return false; +} + +// Adds padding to cudnn convolutions to make them run faster on GPUs with +// tensor cores. +// +// - f16 convolutions are padded to have input/output channel dimensions that +// are multiples of 8, so that we can use tensor cores. +// +// - f16 convolutions with 3 input channels and 32 or 64 output channels are +// padded to 4 input channels. There's a special-cased cudnn algorithm just +// for this. +// +// Don't run this pass on GPUs without tensor cores -- it will make them slower! +// +// TODO(jlebar): Also pad dots. +static StatusOr<bool> TryResolvePadedShapesForTensorCore( + HloCustomCallInstruction* conv, std::vector<Shape>* new_input_shapes_ptr, + Shape* new_result_shape_ptr) { TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(conv)); const auto& dnums = conv->convolution_dimension_numbers(); auto* lhs = conv->mutable_operand(0); @@ -138,7 +197,8 @@ static StatusOr<bool> PadForTensorCores(HloCustomCallInstruction* conv) { Shape new_lhs_shape = lhs->shape(); Shape new_rhs_shape = rhs->shape(); - Shape new_result_shape = conv->shape().tuple_shapes(0); + Shape& new_result_shape = *new_result_shape_ptr; + new_result_shape = conv->shape().tuple_shapes(0); // new_{input,filter_output}_shape points to the appropriate one of // new_{lhs,rhs,result}_shape. @@ -211,29 +271,135 @@ static StatusOr<bool> PadForTensorCores(HloCustomCallInstruction* conv) { return false; } - // OK, let's do the transformation! 
- TF_RETURN_IF_ERROR( - PadConv(conv, new_lhs_shape, new_rhs_shape, new_result_shape)); + new_input_shapes_ptr->push_back(new_lhs_shape); + new_input_shapes_ptr->push_back(new_rhs_shape); return true; } -static std::vector<HloCustomCallInstruction*> GetRelevantConvs( - HloComputation* comp) { - std::vector<HloCustomCallInstruction*> convs; - for (HloInstruction* instr : comp->instructions()) { - if (IsCustomCallToDnnConvolution(*instr)) { - convs.push_back(Cast<HloCustomCallInstruction>(instr)); - } + +// Adds padding to cudnn integer convolutions to make the input and output +// feature maps multiples of 4. +static StatusOr<bool> TryResolvePadedShapesForIntegerConvolution( + HloCustomCallInstruction* conv, std::vector<Shape>* new_input_shapes_ptr, + Shape* new_result_shape_ptr) { + TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(conv)); + const Shape& input_shape = conv->operand(0)->shape(); + const Shape& result_shape = conv->shape().tuple_shapes(0); + + // Integer convolution only + if (!primitive_util::IsIntegralType(input_shape.element_type())) { + return false; } - return convs; + + // kForward and kForwardActivation only + if (kind != CudnnConvKind::kForward && + kind != CudnnConvKind::kForwardActivation) { + return false; + } + + const auto& dnums = conv->convolution_dimension_numbers(); + std::vector<Shape>& new_input_shapes = *new_input_shapes_ptr; + for (auto operand : conv->operands()) { + new_input_shapes.push_back(operand->shape()); + } + Shape& new_result_shape = *new_result_shape_ptr; + new_result_shape = conv->shape().tuple_shapes(0); + + // Pad the features to multiples of 4 and, for debugging purposes, check + // how much the conv buffer sizes change. + { + auto pad_dim = [](Shape* s, int64 dim) { + s->set_dimensions(dim, RoundUpToNearest(s->dimensions(dim), 4)); + }; + + switch (kind) { + case CudnnConvKind::kForward: + CHECK_EQ(new_input_shapes.size(), 2); + pad_dim(&new_input_shapes[0], + dnums.input_feature_dimension()); // Input feature maps + pad_dim(&new_input_shapes[1], + dnums.kernel_input_feature_dimension()); // Kernel for the + // input feature maps + pad_dim( + &new_input_shapes[1], + dnums.kernel_output_feature_dimension()); // Kernel for the output + // feature maps + pad_dim(&new_result_shape, + dnums.output_feature_dimension()); // Output feature maps + break; + case CudnnConvKind::kForwardActivation: + CHECK(new_input_shapes.size() == 3 || new_input_shapes.size() == 4); + pad_dim(&new_input_shapes[0], + dnums.input_feature_dimension()); // Input feature maps + pad_dim(&new_input_shapes[1], + dnums.kernel_input_feature_dimension()); // Kernel for the + // input feature maps + pad_dim( + &new_input_shapes[1], + dnums.kernel_output_feature_dimension()); // Kernel for the output + // feature maps + pad_dim(&new_input_shapes[2], 0); // Bias + if (new_input_shapes.size() == 4) { + pad_dim(&new_input_shapes[3], + dnums.output_feature_dimension()); // Optional side input + } + pad_dim(&new_result_shape, + dnums.output_feature_dimension()); // Output feature maps + break; + default: + CHECK(false); + } + // Check that padding wouldn't increase the total bytes read/written by this + // operation too much. 
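+ // Illustrative arithmetic (not from this change): padding an s8[10,20,30,41] + // operand to s8[10,20,30,44] grows ShapeUtil::ByteSizeOf from 246000 to + // 264000 bytes, a ~1.07x increase, which is what the check below compares + // against kMaxBytesTouchedIncrease.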
+ auto check_size_increase = [&](const Shape& old_shape, + const Shape& new_shape) { + int64 old_bytes = ShapeUtil::ByteSizeOf(old_shape); + int64 new_bytes = ShapeUtil::ByteSizeOf(new_shape); + if (new_bytes <= old_bytes * kMaxBytesTouchedIncrease) { + return; + } + VLOG(3) + << "Not padding convolution; doing so would change input / result " + "shape from " + << ShapeUtil::HumanString(old_shape) << " to " + << ShapeUtil::HumanString(new_shape) << ", a size increase of " + << new_bytes / static_cast(old_bytes) << "x > " + << kMaxBytesTouchedIncrease << "x: " << conv->ToString(); + }; + + for (int64 i = 0; i < conv->operand_count(); ++i) { + check_size_increase(conv->operand(i)->shape(), new_input_shapes[i]); + } + check_size_increase(result_shape, new_result_shape); + } + + bool changed = false; + for (int64 i = 0; i < conv->operand_count(); ++i) { + changed |= + !ShapeUtil::Equal(conv->operand(i)->shape(), new_input_shapes[i]); + } + if (!changed) { + VLOG(3) << "No need to pad features of " << conv->ToString(); + } + + return changed; } -StatusOr CudnnConvPadForTensorCores::Run(HloModule* module) { +StatusOr CudnnPadForConvolutions::Run(HloModule* module) { bool changed = false; for (HloComputation* comp : module->MakeNonfusionComputations()) { for (HloCustomCallInstruction* conv : GetRelevantConvs(comp)) { - TF_ASSIGN_OR_RETURN(bool result, PadForTensorCores(conv)); - changed |= result; + TF_ASSIGN_OR_RETURN( + bool local_changed, + ResolveAndPad(conv, TryResolvePadedShapesForIntegerConvolution)); + changed |= local_changed; + } + for (HloCustomCallInstruction* conv : GetRelevantConvs(comp)) { + if (is_volta_or_later_) { + TF_ASSIGN_OR_RETURN( + bool local_changed, + ResolveAndPad(conv, TryResolvePadedShapesForTensorCore)); + changed |= local_changed; + } } } return changed; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_pad_for_convolutions.h b/tensorflow/compiler/xla/service/gpu/cudnn_pad_for_convolutions.h new file mode 100644 index 00000000000..b065c6e4bd4 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_pad_for_convolutions.h @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_PAD_FOR_CONVOLUTIONS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_PAD_FOR_CONVOLUTIONS_H_ + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" + +namespace xla { +namespace gpu { + +// Two zero-paddings for CuDNN thunking are done in this transform: padding for +// tensor cores and padding for integer convolutions. This transform also +// add slice instruction to remove unnecessary output features. 
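+// +// Minimal usage sketch (illustrative, not part of this change), assuming the +// caller already knows the device generation: +//   HloPassPipeline pipeline("cudnn-conv-padding"); +//   pipeline.AddPass<CudnnPadForConvolutions>(/*is_volta_or_later=*/true); +//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));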
+class CudnnPadForConvolutions : public HloModulePass { + public: + explicit CudnnPadForConvolutions(bool is_volta_or_later) + : is_volta_or_later_(is_volta_or_later) {} + absl::string_view name() const override { + return "cudnn_pad_for_convolutions"; + } + // Run PadForConvolutions on the given module and return if any change is made + StatusOr Run(HloModule* module) override; + + private: + const bool is_volta_or_later_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_PAD_FOR_CONVOLUTIONS_H_ diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_pad_for_convolutions_test.cc similarity index 60% rename from tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores_test.cc rename to tensorflow/compiler/xla/service/gpu/cudnn_pad_for_convolutions_test.cc index af9303a5b76..3d0780aedd8 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_pad_for_convolutions_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_pad_for_convolutions.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -29,9 +29,9 @@ namespace { namespace op = xla::testing::opcode_matchers; using ::testing::_; -class CudnnConvPadForTensorCoresTest : public HloTestBase {}; +class CudnnPadForConvolutionsTest : public HloTestBase {}; -TEST_F(CudnnConvPadForTensorCoresTest, PadF16ForwardConvInputChannels) { +TEST_F(CudnnPadForConvolutionsTest, PadF16ForwardConvInputChannels) { auto module = ParseAndReturnVerifiedModule(R"( HloModule TestModule @@ -43,7 +43,7 @@ TEST_F(CudnnConvPadForTensorCoresTest, PadF16ForwardConvInputChannels) { custom_call_target="__cudnn$convForward" })") .ValueOrDie(); - EXPECT_TRUE(CudnnConvPadForTensorCores().Run(module.get()).ValueOrDie()); + EXPECT_TRUE(CudnnPadForConvolutions(true).Run(module.get()).ValueOrDie()); auto* root = module->entry_computation()->root_instruction(); SCOPED_TRACE(module->ToString()); @@ -56,7 +56,7 @@ TEST_F(CudnnConvPadForTensorCoresTest, PadF16ForwardConvInputChannels) { ShapeUtil::MakeShape(F16, {2, 2, 48, 40}))); } -TEST_F(CudnnConvPadForTensorCoresTest, PadF16BackwardInputConvOutputChannels) { +TEST_F(CudnnPadForConvolutionsTest, PadF16BackwardInputConvOutputChannels) { auto module = ParseAndReturnVerifiedModule(R"( HloModule TestModule @@ -68,7 +68,7 @@ TEST_F(CudnnConvPadForTensorCoresTest, PadF16BackwardInputConvOutputChannels) { custom_call_target="__cudnn$convBackwardInput" })") .ValueOrDie(); - EXPECT_TRUE(CudnnConvPadForTensorCores().Run(module.get()).ValueOrDie()); + EXPECT_TRUE(CudnnPadForConvolutions(true).Run(module.get()).ValueOrDie()); auto* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::CustomCall(kCudnnConvBackwardInputCallTarget, op::Pad(op::Parameter(0), _), @@ -79,7 +79,7 @@ TEST_F(CudnnConvPadForTensorCoresTest, PadF16BackwardInputConvOutputChannels) { ShapeUtil::MakeShape(F16, {2, 2, 40, 48}))); } -TEST_F(CudnnConvPadForTensorCoresTest, PadF16ForwardConvOutputChannels) { +TEST_F(CudnnPadForConvolutionsTest, PadF16ForwardConvOutputChannels) { 
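  // The pass should pad the filter's output-feature dimension up to a multiple // of 8 and slice the conv result back to its original shape, which is what // the matchers below verify.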
auto module = ParseAndReturnVerifiedModule(R"( HloModule TestModule @@ -91,7 +91,7 @@ TEST_F(CudnnConvPadForTensorCoresTest, PadF16ForwardConvOutputChannels) { custom_call_target="__cudnn$convForward" })") .ValueOrDie(); - EXPECT_TRUE(CudnnConvPadForTensorCores().Run(module.get()).ValueOrDie()); + EXPECT_TRUE(CudnnPadForConvolutions(true).Run(module.get()).ValueOrDie()); auto* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Tuple(op::Slice(op::GetTupleElement(op::CustomCall( kCudnnConvForwardCallTarget, op::Parameter(0), @@ -99,7 +99,7 @@ TEST_F(CudnnConvPadForTensorCoresTest, PadF16ForwardConvOutputChannels) { _)); } -TEST_F(CudnnConvPadForTensorCoresTest, PadF16BackwardInputConvInputChannels) { +TEST_F(CudnnPadForConvolutionsTest, PadF16BackwardInputConvInputChannels) { auto module = ParseAndReturnVerifiedModule(R"( HloModule TestModule @@ -112,7 +112,7 @@ TEST_F(CudnnConvPadForTensorCoresTest, PadF16BackwardInputConvInputChannels) { ROOT gte = f16[10,20,30,41] get-tuple-element(result), index=0 })") .ValueOrDie(); - EXPECT_TRUE(CudnnConvPadForTensorCores().Run(module.get()).ValueOrDie()); + EXPECT_TRUE(CudnnPadForConvolutions(true).Run(module.get()).ValueOrDie()); auto* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::GetTupleElement(op::Tuple( op::Slice(op::GetTupleElement(op::CustomCall( @@ -121,7 +121,7 @@ TEST_F(CudnnConvPadForTensorCoresTest, PadF16BackwardInputConvInputChannels) { _))); } -TEST_F(CudnnConvPadForTensorCoresTest, PadF16BackwardFilterConvInputChannels) { +TEST_F(CudnnPadForConvolutionsTest, PadF16BackwardFilterConvInputChannels) { auto module = ParseAndReturnVerifiedModule(R"( HloModule TestModule @@ -134,7 +134,7 @@ TEST_F(CudnnConvPadForTensorCoresTest, PadF16BackwardFilterConvInputChannels) { ROOT gte = f16[2,2,41,40] get-tuple-element(result), index=0 })") .ValueOrDie(); - EXPECT_TRUE(CudnnConvPadForTensorCores().Run(module.get()).ValueOrDie()); + EXPECT_TRUE(CudnnPadForConvolutions(true).Run(module.get()).ValueOrDie()); auto* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::GetTupleElement(op::Tuple( op::Slice(op::GetTupleElement(op::CustomCall( @@ -143,7 +143,7 @@ TEST_F(CudnnConvPadForTensorCoresTest, PadF16BackwardFilterConvInputChannels) { _))); } -TEST_F(CudnnConvPadForTensorCoresTest, PadF16BackwardFilterConvOutputChannels) { +TEST_F(CudnnPadForConvolutionsTest, PadF16BackwardFilterConvOutputChannels) { auto module = ParseAndReturnVerifiedModule(R"( HloModule TestModule @@ -156,7 +156,7 @@ TEST_F(CudnnConvPadForTensorCoresTest, PadF16BackwardFilterConvOutputChannels) { ROOT gte = f16[2,2,40,41] get-tuple-element(result), index=0 })") .ValueOrDie(); - EXPECT_TRUE(CudnnConvPadForTensorCores().Run(module.get()).ValueOrDie()); + EXPECT_TRUE(CudnnPadForConvolutions(true).Run(module.get()).ValueOrDie()); auto* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::GetTupleElement(op::Tuple( op::Slice(op::GetTupleElement(op::CustomCall( @@ -165,7 +165,7 @@ TEST_F(CudnnConvPadForTensorCoresTest, PadF16BackwardFilterConvOutputChannels) { _))); } -TEST_F(CudnnConvPadForTensorCoresTest, PadInputFeatures3To4) { +TEST_F(CudnnPadForConvolutionsTest, PadInputFeatures3To4) { auto module = ParseAndReturnVerifiedModule(R"( HloModule TestModule @@ -177,7 +177,7 @@ TEST_F(CudnnConvPadForTensorCoresTest, PadInputFeatures3To4) { custom_call_target="__cudnn$convForward" })") .ValueOrDie(); - EXPECT_TRUE(CudnnConvPadForTensorCores().Run(module.get()).ValueOrDie()); + 
EXPECT_TRUE(CudnnPadForConvolutions(true).Run(module.get()).ValueOrDie()); auto* root = module->entry_computation()->root_instruction(); SCOPED_TRACE(module->ToString()); @@ -190,6 +190,78 @@ TEST_F(CudnnConvPadForTensorCoresTest, PadInputFeatures3To4) { ShapeUtil::MakeShape(F16, {2, 2, 4, 32}))); } +TEST_F(CudnnPadForConvolutionsTest, PadIntForwardConvInputChannels) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule TestModule + + ENTRY TestComputation { + input = s8[10,20,30,41] parameter(0) + filter = s8[2,2,41,40] parameter(1) + ROOT result = (f32[10,20,30,40], u8[0]) custom-call(input, filter), + window={size=2x2}, dim_labels=b01f_01io->b01f, + custom_call_target="__cudnn$convForward" + })") + .ValueOrDie(); + EXPECT_TRUE(CudnnPadForConvolutions(true).Run(module.get()).ValueOrDie()); + auto* root = module->entry_computation()->root_instruction(); + + SCOPED_TRACE(module->ToString()); + EXPECT_THAT(root, op::CustomCall(kCudnnConvForwardCallTarget, + op::Pad(op::Parameter(0), _), + op::Pad(op::Parameter(1), _))); + EXPECT_TRUE(ShapeUtil::Equal(root->operand(0)->shape(), + + ShapeUtil::MakeShape(S8, {10, 20, 30, 44}))); + EXPECT_TRUE(ShapeUtil::Equal(root->operand(1)->shape(), + ShapeUtil::MakeShape(S8, {2, 2, 44, 40}))); +} + +TEST_F(CudnnPadForConvolutionsTest, PadIntForwardConvOutputChannels) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule TestModule + + ENTRY TestComputation { + input = s8[10,20,30,40] parameter(0) + filter = s8[2,2,40,41] parameter(1) + ROOT result = (f32[10,20,30,41], u8[0]) custom-call(input, filter), + window={size=2x2}, dim_labels=b01f_01io->b01f, + custom_call_target="__cudnn$convForward" + })") + .ValueOrDie(); + EXPECT_TRUE(CudnnPadForConvolutions(true).Run(module.get()).ValueOrDie()); + auto* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Tuple(op::Slice(op::GetTupleElement(op::CustomCall( + kCudnnConvForwardCallTarget, op::Parameter(0), + op::Pad(op::Parameter(1), _)))), + _)); +} + +TEST_F(CudnnPadForConvolutionsTest, + PadIntFusedForwardConvInputAndOutputChannels) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule Test + + ENTRY %Test (input: s8[1,3,3,2], filter: s8[3,3,2,5], side_input: s8[1,3,3,5], bias: s8[5]) -> f32[1,3,3,5] { + %input = s8[1,3,3,2]{3,2,1,0} parameter(0) + %filter = s8[3,3,2,5]{3,2,1,0} parameter(1) + %bias = s8[5]{0} parameter(3) + %convert = f32[5]{0} convert(s8[5]{0} %bias) + %side_input = f32[1,3,3,5]{3,2,1,0} parameter(2) + %custom-call.1 = (f32[1,3,3,5]{3,2,1,0}, u8[0]{0}) custom-call(s8[1,3,3,2]{3,2,1,0} %input, s8[3,3,2,5]{3,2,1,0} %filter, f32[5]{0} %convert, f32[1,3,3,5]{3,2,1,0} %side_input), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBiasActivationForward", backend_config="{\"activationMode\":\"2\",\"convResultScale\":1,\"sideInputScale\":1}" + ROOT %get-tuple-element.1 = f32[1,3,3,5]{3,2,1,0} get-tuple-element((f32[1,3,3,5]{3,2,1,0}, u8[0]{0}) %custom-call.1), index=0 + })") + .ValueOrDie(); + EXPECT_TRUE(CudnnPadForConvolutions(true).Run(module.get()).ValueOrDie()); + auto* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, op::GetTupleElement(op::Tuple( + op::Slice(op::GetTupleElement(op::CustomCall( + kCudnnConvBiasActivationForwardCallTarget, + op::Pad(op::Parameter(0), _), op::Pad(op::Parameter(1), _), + op::Pad(op::Convert(op::Parameter(3)), _), + op::Pad(op::Parameter(2), _)))), + _))); +} } // anonymous namespace } // namespace gpu } // namespace xla diff --git 
a/tensorflow/compiler/xla/service/gpu/cusolver_context.cc b/tensorflow/compiler/xla/service/gpu/cusolver_context.cc index 4103a720c98..b18170b00e4 100644 --- a/tensorflow/compiler/xla/service/gpu/cusolver_context.cc +++ b/tensorflow/compiler/xla/service/gpu/cusolver_context.cc @@ -169,7 +169,8 @@ StatusOr CusolverContext::PotrfBufferSize(PrimitiveType type, } #define POTRF_INSTANCE(T, type_prefix) \ - Status CusolverContext::Potrf( \ + template <> \ + Status CusolverContext::Potrf( \ se::blas::UpperLower uplo, int n, se::DeviceMemory A, int lda, \ se::DeviceMemory lapack_info, se::DeviceMemory workspace) { \ return CusolverStatusToStatus(DN_SOLVER_FN(potrf, type_prefix)( \ diff --git a/tensorflow/compiler/xla/service/gpu/cusolver_context.h b/tensorflow/compiler/xla/service/gpu/cusolver_context.h index c3d075c47c7..dfe55188b18 100644 --- a/tensorflow/compiler/xla/service/gpu/cusolver_context.h +++ b/tensorflow/compiler/xla/service/gpu/cusolver_context.h @@ -18,8 +18,10 @@ limitations under the License. #include -#include "third_party/gpus/cuda/include/cublas_v2.h" +#if !TENSORFLOW_USE_ROCM #include "third_party/gpus/cuda/include/cusolverDn.h" +#endif + #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" @@ -30,6 +32,8 @@ limitations under the License. namespace xla { namespace gpu { +#if !TENSORFLOW_USE_ROCM + class CusolverContext { public: // stream may be nullptr, in which case the context can only be used for @@ -43,26 +47,17 @@ class CusolverContext { CusolverContext& operator=(const CusolverContext&) = delete; CusolverContext& operator=(CusolverContext&&); - se::Stream* stream() const { return stream_; } - cusolverDnHandle_t handle() const { return handle_; } - // Computes the Cholesky factorization A = L * L^T for a single matrix. // Returns Status::OK() if the kernel was launched successfully. See: // http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-potrf - Status Potrf(se::blas::UpperLower uplo, int n, se::DeviceMemory dev_A, + template ::value || + std::is_same::value || + std::is_same>::value || + std::is_same>::value>> + Status Potrf(se::blas::UpperLower uplo, int n, se::DeviceMemory dev_A, int lda, se::DeviceMemory dev_lapack_info, - se::DeviceMemory workspace); - Status Potrf(se::blas::UpperLower uplo, int n, se::DeviceMemory dev_A, - int lda, se::DeviceMemory dev_lapack_info, - se::DeviceMemory workspace); - Status Potrf(se::blas::UpperLower uplo, int n, - se::DeviceMemory> dev_A, int lda, - se::DeviceMemory dev_lapack_info, - se::DeviceMemory> workspace); - Status Potrf(se::blas::UpperLower uplo, int n, - se::DeviceMemory> dev_A, int lda, - se::DeviceMemory dev_lapack_info, - se::DeviceMemory> workspace); + se::DeviceMemory workspace); // Returns the size of the `workspace` required by Potrf, in number of // elements of `type`. @@ -72,10 +67,42 @@ class CusolverContext { private: CusolverContext(se::Stream* stream, cusolverDnHandle_t handle); + cusolverDnHandle_t handle() const { return handle_; } + se::Stream* stream_ = nullptr; cusolverDnHandle_t handle_ = nullptr; }; +#else + +typedef void* cusolverDnHandle_t; + +// TODO(cheshire): Remove this hack once we have ROCM implementation. 
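+// Call sites keep using the templated interface unchanged, e.g. (illustrative) +//   context.Potrf(se::blas::UpperLower::kLower, n, dev_a, lda, dev_info, ws); +// with T deduced from the se::DeviceMemory<T> arguments; on ROCm these stubs +// simply LOG(FATAL).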
+class CusolverContext { + public: + static StatusOr Create(se::Stream* stream) { + LOG(FATAL) << "Unimplemented"; + } + + template ::value || + std::is_same::value || + std::is_same>::value || + std::is_same>::value>> + Status Potrf(se::blas::UpperLower uplo, int n, se::DeviceMemory dev_A, + int lda, se::DeviceMemory dev_lapack_info, + se::DeviceMemory workspace) { + LOG(FATAL) << "Unimplemented"; + } + + StatusOr PotrfBufferSize(PrimitiveType type, se::blas::UpperLower uplo, + int n, int lda) { + LOG(FATAL) << "Unimplemented"; + } +}; + +#endif + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc index 65673106391..85571804315 100644 --- a/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/custom_call_thunk.h" #include "absl/strings/str_format.h" -#include "tensorflow/stream_executor/cuda/cuda_stream.h" #include "tensorflow/stream_executor/gpu/gpu_stream.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc index 98d8d00b62c..21a0ffab5e0 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc @@ -30,9 +30,9 @@ limitations under the License. #include "tensorflow/core/protobuf/autotuning.pb.h" #include "tensorflow/core/util/proto/proto_utils.h" #include "tensorflow/stream_executor/blas.h" -#include "tensorflow/stream_executor/cuda/redzone_allocator.h" #include "tensorflow/stream_executor/device_memory.h" #include "tensorflow/stream_executor/device_memory_allocator.h" +#include "tensorflow/stream_executor/gpu/redzone_allocator.h" namespace xla { namespace gpu { @@ -59,8 +59,8 @@ static StatusOr> DoUncachedGemmAutotune( const HloInstruction* gemm, se::DeviceMemoryBase lhs_buffer, se::DeviceMemoryBase rhs_buffer, se::DeviceMemoryBase output_buffer, se::DeviceMemoryBase reference_result_buffer, se::Stream* stream, - const se::cuda::RedzoneAllocator& allocator, - const BufferComparator& comparator, bool crash_on_checking_failure) { + const se::RedzoneAllocator& allocator, const BufferComparator& comparator, + bool crash_on_checking_failure) { if (!stream->parent()->SynchronizeAllActivity()) { return InternalError("Failed to synchronize GPU for autotuning."); } @@ -81,8 +81,8 @@ static StatusOr> DoUncachedGemmAutotune( // the bias parameter. 
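 // (Recall the GEMM contract: output = alpha * lhs * rhs + beta * output, so // when beta != 0 the existing contents of the output buffer feed into the // result and must be initialized before profiling.)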
if (backend_config.beta() != 0) { int64 rng_state = 0; - InitializeFloatBuffer(stream, gemm->shape().element_type(), &rng_state, - output_buffer); + InitializeBuffer(stream, gemm->shape().element_type(), &rng_state, + output_buffer); } se::blas::ProfileResult profile_result; @@ -113,7 +113,7 @@ static StatusOr> DoUncachedGemmAutotune( absl::Milliseconds(profile_result.elapsed_time_in_ms())); TF_ASSIGN_OR_RETURN( - se::cuda::RedzoneAllocator::RedzoneCheckStatus rz_check_status, + se::RedzoneAllocator::RedzoneCheckStatus rz_check_status, allocator.CheckRedzones()); if (!rz_check_status.ok()) { result.mutable_failure()->set_kind(AutotuneResult::REDZONE_MODIFIED); @@ -188,7 +188,7 @@ static StatusOr> DoGemmAutotune( const HloInstruction* rhs, se::DeviceMemoryBase lhs_buffer, se::DeviceMemoryBase rhs_buffer, se::DeviceMemoryBase output_buffer, se::DeviceMemoryBase reference_result_buffer, se::Stream* stream, - bool crash_on_checking_failure, const se::cuda::RedzoneAllocator& allocator, + bool crash_on_checking_failure, const se::RedzoneAllocator& allocator, const BufferComparator& comparator) { // Don't run autotuning concurrently on the same GPU. tensorflow::mutex_lock gpu_lock = LockGpu(stream->parent()); @@ -253,7 +253,7 @@ static StatusOr RunOnInstruction(HloInstruction* instr, }(); const HloModuleConfig& hlo_module_config = instr->GetModule()->config(); - se::cuda::RedzoneAllocator input_output_allocator( + se::RedzoneAllocator input_output_allocator( stream, allocator, PtxOptsFromConfig(hlo_module_config)); BufferComparator comparator(instr->shape(), hlo_module_config); @@ -264,8 +264,7 @@ static StatusOr RunOnInstruction(HloInstruction* instr, TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buffer, input_output_allocator.AllocateBytes( ShapeUtil::ByteSizeOf(op->shape()))); - InitializeFloatBuffer(stream, op->shape().element_type(), &rng_state, - buffer); + InitializeBuffer(stream, op->shape().element_type(), &rng_state, buffer); return buffer; }; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 40ccf7a820b..95e21a84f29 100755 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -231,6 +231,9 @@ Status GpuCompiler::OptimizeHloModule( // run, meaning, the pipeline that contains layout assignment cannot contain // a layout-sensitive verifier! HloPassPipeline pipeline("layout assignment"); + // Layout assignment uses alias analysis, which requires the call graph to + // be flattened. + pipeline.AddPass(); pipeline.AddPass( hlo_module->mutable_entry_computation_layout(), LayoutAssignment::InstructionCanChangeLayout, stream_exec); @@ -306,7 +309,6 @@ Status GpuCompiler::PrepareHloModuleForIrEmitting(HloModule* hlo_module) { // (and sometime after) copy insertion, to avoid dead code from interfering // with the rewrites. pipeline.AddPass(); - pipeline.AddPass(); if (hlo_module->config().alias_passthrough_params()) { pipeline.AddPass(); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc index 97fa275a2e7..3d307cc8993 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc @@ -162,8 +162,7 @@ bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1, return false; } // The elementwise output shapes must be the same (including layout). - // TODO(tjoerg): Further relax the constraint. The datatype does not matter. 
- return ShapeUtil::EqualIgnoringFpPrecision(get_loop_shape(instr_1), + return ShapeUtil::EqualIgnoringElementType(get_loop_shape(instr_1), get_loop_shape(instr_2)); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc index dc4e54c74d2..ae31b10deb3 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc @@ -465,7 +465,43 @@ TEST_F(GpuFusibleTest, const HloInstruction* fusion_1 = module->entry_computation()->root_instruction()->operand(0)->operand(0); const HloInstruction* fusion_2 = - module->entry_computation()->root_instruction()->operand(1)->operand(0); + module->entry_computation()->root_instruction()->operand(2); + EXPECT_NE(fusion_1, fusion_2); + EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*fusion_1, *fusion_2)); +} + +TEST_F(GpuFusibleTest, + ShapesCompatibleForMultiOutputFusion_DifferentElementType) { + auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1) + exp = f32[8,1,5,16,1,1]{5,4,3,2,1,0} exponential(p0.1) + ROOT tuple = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(mul, exp) + } + + fused_computation_2 { + p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + const.2 = f32[] constant(0) + broadcast = f32[8,1,5,16,1,1]{5,4,3,2,1,0} broadcast(const.2), dimensions={} + add = f32[8,1,5,16,1,1]{5,4,3,2,1,0} add(p0.2, broadcast) + ROOT convert = s32[8,1,5,16,1,1]{5,4,3,2,1,0} convert(add) + } + + ENTRY entry { + p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + fusion.1 = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) fusion(p0), kind=kLoop, calls=fused_computation_1 + fusion.2 = s32[8,1,5,16,1,1]{5,4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2 + gte0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0 + gte1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1 + ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(gte0, gte1, fusion.2) + })")) + .ValueOrDie(); + const HloInstruction* fusion_1 = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + const HloInstruction* fusion_2 = + module->entry_computation()->root_instruction()->operand(2); + EXPECT_NE(fusion_1, fusion_2); EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*fusion_1, *fusion_2)); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index 550f4662b55..75c9d93c63b 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -58,6 +58,12 @@ HeuristicLayoutAssignment(const HloInstruction* instr, std::make_tuple(DataLayout::kBatchYXDepth, FilterLayout::kOutputYXInput, DataLayout::kBatchYXDepth); + // Integer convolution must use NHWC. 
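+ // (kAllNHWC, built above, is the NHWC triple: DataLayout::kBatchYXDepth for + // the input/output and FilterLayout::kOutputYXInput for the filter.)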
+ if (primitive_util::IsIntegralType( + instr->operand(0)->shape().element_type())) { + return kAllNHWC; + } + const DebugOptions& debug_options = instr->GetModule()->config().debug_options(); diff --git a/tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist.cc b/tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist.cc index e9a6e8d14d9..bb85c509d18 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist.cc @@ -26,20 +26,7 @@ namespace gpu { // MSVC requires the extra const. Without, it reports an // "error C2131: expression did not evaluate to a constant". constexpr const absl::string_view kDefaultBlacklist = R"pb( - entries { - hlo: "(f16[256,112,112,64]{3,2,1,0}, u8[0]{0}) custom-call(f16[256,224,224,4]{3,2,1,0}, f16[7,7,4,64]{2,1,0,3}), window={size=7x7 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f, custom_call_target=\"__cudnn$convForward\", backend_config=\"{conv_result_scale:1}\"" - cc { major: 7 } - cudnn_version { major: 7 minor: 6 patch: 2 } - blas_version: "10201" - algos { id: 1 tensor_ops: true } - } - entries { - hlo: "(f16[7,7,4,64]{2,1,0,3}, u8[0]{0}) custom-call(f16[256,224,224,4]{3,2,1,0}, f16[256,112,112,64]{3,2,1,0}), window={size=7x7 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f, custom_call_target=\"__cudnn$convBackwardFilter\", backend_config=\"{conv_result_scale:1}\"" - cc { major: 7 } - cudnn_version { major: 7 minor: 6 patch: 2 } - blas_version: "10201" - algos { id: 1 tensor_ops: true } - })pb"; +)pb"; absl::Span GetBlacklistedConvAlgorithms(tensorflow::ComputeCapability cc, diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 1c617a07372..ea238a6db02 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -1857,28 +1857,42 @@ Status IrEmitterUnnested::EmitTargetElementLoop( namespace { std::tuple GetStartOffsetAndStepForX( - int64 tile_size_x, int64 num_threads_x, - const KernelMappingScheme& mapping_scheme, llvm::IRBuilder<>* builder, - llvm::Value* x, llvm::Type* index_ty) { + const KernelMappingScheme& mapping_scheme, llvm::IRBuilder<>* b, + llvm::Value* x, const IrEmitterUnnested::ConstantGenerator& constant) { llvm::Value* start_offset_x; int64 step_x; if (mapping_scheme.DilatedX()) { start_offset_x = x; - step_x = num_threads_x; + step_x = mapping_scheme.GetNumberOfThreadsForDimensionX(); } else { - start_offset_x = builder->CreateMul( - x, llvm::ConstantInt::get(index_ty, tile_size_x / num_threads_x)); + start_offset_x = b->CreateMul( + x, constant(mapping_scheme.GetTileSizeForDimensionX() / + mapping_scheme.GetNumberOfThreadsForDimensionX())); step_x = 1; } return std::make_tuple(start_offset_x, step_x); } // Emits code for writing into a tile which fits fully into the output buffer. 
+// +// Pseudocode: +// +// for (y_idx = 0; y_idx < tile_size_y; y_idx += num_threads_y) { +// for (j = 0; j < tile_size_x / num_threads_x; j++) { +// y_pos = y + y_idx; +// if (dilated) +// x_pos = x + j * num_threads_x +// else +// x_pos = x * (tile_size_x / num_threads_x) + j +// +// EmitElementary(y_pos, x_pos); +// } +// } void EmitFullElementalTile( const KernelMappingScheme& mapping_scheme, const IrArray::Index& tile_origin_index, const string& loop_name, - KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y, - llvm::Value* x, llvm::Type* index_ty, + KernelSupportLibrary* ksl, llvm::IRBuilder<>* b, llvm::Value* y, + llvm::Value* x, const IrEmitterUnnested::ConstantGenerator& constant, const IrEmitterUnnested::EmitElementFunction& emit_elem_function) { int64 num_threads_x = mapping_scheme.GetNumberOfThreadsForDimensionX(); int64 num_threads_y = mapping_scheme.GetNumberOfThreadsForDimensionY(); @@ -1887,26 +1901,23 @@ void EmitFullElementalTile( llvm::Value* start_offset_x; int64 step_x; - std::tie(start_offset_x, step_x) = GetStartOffsetAndStepForX( - tile_size_x, num_threads_x, mapping_scheme, builder, x, index_ty); + std::tie(start_offset_x, step_x) = + GetStartOffsetAndStepForX(mapping_scheme, b, x, constant); IrArray::Index source_idx = - tile_origin_index.AddOffsetToDim(y, KernelMappingScheme::DimY, builder) - .AddOffsetToDim(start_offset_x, KernelMappingScheme::DimX, builder); - ksl->For(loop_name + "_y", /*start=*/llvm::ConstantInt::get(index_ty, 0), - /*end=*/llvm::ConstantInt::get(index_ty, tile_size_y), - /*step=*/llvm::ConstantInt::get(index_ty, num_threads_y), - [&](llvm::Value* y_indvar) { + tile_origin_index.AddOffsetToDim(y, KernelMappingScheme::DimY, b) + .AddOffsetToDim(start_offset_x, KernelMappingScheme::DimX, b); + ksl->For(loop_name + "_y", /*start=*/constant(0), + /*end=*/constant(tile_size_y), + /*step=*/constant(num_threads_y), [&](llvm::Value* y_indvar) { IrArray::Index source_idx_y = source_idx.AddOffsetToDim( - y_indvar, KernelMappingScheme::DimY, builder); - llvm::Value* y_loc = builder->CreateAdd(y_indvar, y); + y_indvar, KernelMappingScheme::DimY, b); + llvm::Value* y_loc = b->CreateAdd(y_indvar, y); for (int64 j = 0; j < tile_size_x / num_threads_x; j++) { IrArray::Index source_idx_y_x = source_idx_y.AddOffsetToDim( - llvm::ConstantInt::get(index_ty, j * step_x), - KernelMappingScheme::DimX, builder); - llvm::Value* x_loc = builder->CreateAdd( - llvm::ConstantInt::get(index_ty, j * step_x), - start_offset_x); + constant(j * step_x), KernelMappingScheme::DimX, b); + llvm::Value* x_loc = + b->CreateAdd(constant(j * step_x), start_offset_x); emit_elem_function(source_idx_y_x, y_loc, x_loc, j); } }); @@ -1914,12 +1925,30 @@ void EmitFullElementalTile( // Emits code for writing into a tile which does not fit fully into the output // buffer. 
+// +// Pseudocode: +// +// for (j = 0; j < tile_size_x / num_threads_x; j++) { +// if (dilated) +// x_pos = x + j * num_threads_x +// else +// x_pos = x * (tile_size_x / num_threads_x) + j +// +// if (x_pos < tile_width) { +// for (y_indvar = 0; y_indvar < tile_height_bound; y_indvar += +// num_threads_y) { +// if (y_indvar < tile_height) { +// EmitElementary(y + y_indevar, x); +// } +// } +// } +// } void EmitPartialElementalTile( const KernelMappingScheme& mapping_scheme, const IrArray::Index& tile_origin_index, const string& loop_name, - KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y, + KernelSupportLibrary* ksl, llvm::IRBuilder<>* b, llvm::Value* y, llvm::Value* x, llvm::Value* tile_height, llvm::Value* tile_width, - llvm::Type* index_ty, + const IrEmitterUnnested::ConstantGenerator& constant, const IrEmitterUnnested::EmitElementFunction& emit_elem_function) { int64 num_threads_x = mapping_scheme.GetNumberOfThreadsForDimensionX(); int64 num_threads_y = mapping_scheme.GetNumberOfThreadsForDimensionY(); @@ -1927,45 +1956,36 @@ void EmitPartialElementalTile( llvm::Value* start_offset_x; int64 step_x; - std::tie(start_offset_x, step_x) = GetStartOffsetAndStepForX( - tile_size_x, num_threads_x, mapping_scheme, builder, x, index_ty); + std::tie(start_offset_x, step_x) = + GetStartOffsetAndStepForX(mapping_scheme, b, x, constant); IrArray::Index source_idx = - tile_origin_index.AddOffsetToDim(y, KernelMappingScheme::DimY, builder) - .AddOffsetToDim(start_offset_x, KernelMappingScheme::DimX, builder); - for (int64 j = 0; j < tile_size_x / num_threads_x; j++) { - IrArray::Index source_idx_x = - source_idx.AddOffsetToDim(llvm::ConstantInt::get(index_ty, j * step_x), - KernelMappingScheme::DimX, builder); - llvm::Value* x_loc = builder->CreateAdd( - llvm::ConstantInt::get(index_ty, j * step_x), start_offset_x); + tile_origin_index.AddOffsetToDim(y, KernelMappingScheme::DimY, b) + .AddOffsetToDim(start_offset_x, KernelMappingScheme::DimX, b); - ksl->If( - loop_name + "_x_in_tile", builder->CreateICmpULT(x_loc, tile_width), - [&] { - // tile_height_bound = - // ceil(tile_height / num_threads_y) * num_threads_y - llvm::Value* ceiling_of_ratio = builder->CreateUDiv( - builder->CreateAdd(tile_height, llvm::ConstantInt::get( - index_ty, num_threads_y - 1)), - llvm::ConstantInt::get(index_ty, num_threads_y)); - llvm::Value* tile_height_bound = builder->CreateMul( - ceiling_of_ratio, - llvm::ConstantInt::get(index_ty, num_threads_y)); - ksl->For( - loop_name, /*start=*/llvm::ConstantInt::get(index_ty, 0), - /*end=*/tile_height_bound, - /*step=*/llvm::ConstantInt::get(index_ty, num_threads_y), - [&](llvm::Value* y_indvar) { - llvm::Value* y_loc = builder->CreateAdd(y_indvar, y); - ksl->If(loop_name + "_y_in_tile", - builder->CreateICmpULT(y_loc, tile_height), [&] { - emit_elem_function( - source_idx_x.AddOffsetToDim( - y_indvar, KernelMappingScheme::DimY, builder), - y_loc, x_loc, j); - }); - }); - }); + for (int64 j = 0; j < tile_size_x / num_threads_x; j++) { + IrArray::Index source_idx_x = source_idx.AddOffsetToDim( + constant(j * step_x), KernelMappingScheme::DimX, b); + llvm::Value* x_loc = b->CreateAdd(constant(j * step_x), start_offset_x); + ksl->If(loop_name + "_x_in_tile", b->CreateICmpULT(x_loc, tile_width), [&] { + llvm::Value* ceiling_of_ratio = + b->CreateUDiv(b->CreateAdd(tile_height, constant(num_threads_y - 1)), + constant(num_threads_y)); + llvm::Value* tile_height_bound = + b->CreateMul(ceiling_of_ratio, constant(num_threads_y)); + ksl->For(loop_name, + 
/*start=*/constant(0), + /*end=*/tile_height_bound, + /*step=*/constant(num_threads_y), [&](llvm::Value* y_indvar) { + llvm::Value* y_loc = b->CreateAdd(y_indvar, y); + ksl->If(loop_name + "_y_in_tile", + b->CreateICmpULT(y_loc, tile_height), [&] { + emit_elem_function( + source_idx_x.AddOffsetToDim( + y_indvar, KernelMappingScheme::DimY, b), + y_loc, x_loc, j); + }); + }); + }); } } @@ -1977,31 +1997,39 @@ void EmitPartialElementalTile( // about tile_size_x/y and num_threads_x/y are stored in `mapping_scheme`. Emits // bounds check to ensure that each processed element is within the boundary // defined by `tile_width` and `tile_height`. +// +// Pseudocode: +// +// if (tile_size_x == tile_width && tile_size_y == tile_height) { +// EmitFullElementalTile(); +// } else { +// EmitPartialElementalTile(); +// } void EmitTiledElementalCodeWithBoundsCheck( const KernelMappingScheme& mapping_scheme, const IrArray::Index& tile_origin_index, const string& loop_name, - KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y, + KernelSupportLibrary* ksl, llvm::IRBuilder<>* b, llvm::Value* y, llvm::Value* x, llvm::Value* tile_height, llvm::Value* tile_width, const IrEmitterUnnested::EmitElementFunction& emit_elem_function) { int64 tile_size_x = mapping_scheme.GetTileSizeForDimensionX(); int64 tile_size_y = mapping_scheme.GetTileSizeForDimensionY(); llvm::Type* index_ty = tile_width->getType(); + auto constant = [&](int64 val) { + return llvm::ConstantInt::get(index_ty, val); + }; ksl->If( loop_name + "_full_tile", - builder->CreateAnd( - builder->CreateICmpEQ(llvm::ConstantInt::get(index_ty, tile_size_x), - tile_width), - builder->CreateICmpEQ(llvm::ConstantInt::get(index_ty, tile_size_y), - tile_height)), + b->CreateAnd(b->CreateICmpEQ(constant(tile_size_x), tile_width), + b->CreateICmpEQ(constant(tile_size_y), tile_height)), [&] { EmitFullElementalTile(mapping_scheme, tile_origin_index, loop_name, ksl, - builder, y, x, index_ty, emit_elem_function); + b, y, x, constant, emit_elem_function); }, [&] { EmitPartialElementalTile(mapping_scheme, tile_origin_index, loop_name, - ksl, builder, y, x, tile_height, tile_width, - index_ty, emit_elem_function); + ksl, b, y, x, tile_height, tile_width, + constant, emit_elem_function); }); } } // namespace @@ -2329,9 +2357,10 @@ void IrEmitterUnnested::EmitTileElementForReduction( : unnested_hlo; // Record the untransposed output linear address for the reduction. int partial_result_index = reduction_info.IsRowReduction() ? 0 : x_iter_num; - Store(GetUntransposedOutputLinearAddress(&b_, index, reduction_info), - InBoundsGEP(reduction_info.GetCurrentOutputLinearIndexAddress(), - {b_.getInt32(partial_result_index)})); + b_.CreateStore( + GetUntransposedOutputLinearAddress(&b_, index, reduction_info), + InBoundsGEP(reduction_info.GetCurrentOutputLinearIndexAddress(), + {b_.getInt32(partial_result_index)})); if (!reduction_info.IsRowReduction()) { llvm::Type* bool_ty = b_.getInt1Ty(); @@ -2381,23 +2410,23 @@ void IrEmitterUnnested::EmitTileElementForReduction( int num_partial_results = GetNumberOfPartialResults(reduction_info); auto index_without_linear = IrArray::Index( input_index.multidim(), reduction_operand_shape, input_index.GetType()); - absl::Span partial_reduction_result_addresses = - reduction_info.GetPartialResultAddresses(); - absl::Span reduction_input_addresses = - reduction_info.GetReductionInputAddresses(); + // Emit code to generate the input and perform the reduction computation for // each reduction instruction. 
for (int i = 0; i != reducers.size(); ++i) { + llvm::AllocaInst* input_address = + reduction_info.GetReductionInputAddresses()[i]; + llvm::AllocaInst* partial_reduction_result_address = + reduction_info.GetPartialResultAddresses()[i]; llvm::Value* const input_ir_value = input_gens[i](num_partial_results > 1 ? index_without_linear : input_index) .ValueOrDie(); - Store(input_ir_value, reduction_input_addresses[i]); - llvm::Value* partial_result_address = - InBoundsGEP(partial_reduction_result_addresses[i], - {b_.getInt32(partial_result_index)}); + Store(input_ir_value, input_address); + llvm::Value* partial_result_address = InBoundsGEP( + partial_reduction_result_address, {b_.getInt32(partial_result_index)}); TF_CHECK_OK(EmitCallToNestedComputation( - *reducers[i], {partial_result_address, reduction_input_addresses[i]}, + *reducers[i], {partial_result_address, input_address}, partial_result_address)); } @@ -2408,46 +2437,6 @@ void IrEmitterUnnested::EmitTileElementForReduction( /*use_linear_index=*/num_partial_results == 1, extra_output_gens)); } -// Emits tiles for a given dimension. -static void EmitTilesForBlockDim( - const KernelMappingScheme& mapping_scheme, KernelSupportLibrary* ksl, - llvm::Type* index_ty, const string& loop_name, - const IrArray::Index& starting_tile, int dim_id, llvm::IRBuilder<>* b_, - const std::function - emit_next_block_dim) { - absl::Span dims_in_tile = mapping_scheme.GetDimensionsInTiles(); - absl::Span dims_in_block = - mapping_scheme.GetDimensionsInBlocks(); - absl::Span block_sizes = mapping_scheme.GetBlockSizes(); - auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty, c); - }; - if (block_sizes[dim_id] == 1) { - emit_next_block_dim(starting_tile); - } else { - llvm::Value* starting_tile_index_for_dim = starting_tile[dim_id]; - llvm::Value* block_size_for_dim = index_typed_constant(block_sizes[dim_id]); - llvm::Value* block_id_for_dim = - b_->CreateUDiv(starting_tile_index_for_dim, block_size_for_dim); - llvm::Value* last_block_for_dim = - index_typed_constant(dims_in_block[dim_id] - 1); - llvm::Value* last_block_size_for_dim = - index_typed_constant(dims_in_tile[dim_id] - - (dims_in_block[dim_id] - 1) * block_sizes[dim_id]); - llvm::Value* num_tiles_in_block = - b_->CreateSelect(b_->CreateICmpEQ(last_block_for_dim, block_id_for_dim), - last_block_size_for_dim, block_size_for_dim); - ksl->For(loop_name, - /*start=*/index_typed_constant(0), - /*end=*/num_tiles_in_block, - /*step=*/1, [&](llvm::Value* block_dim_induction_var) { - IrArray::Index tile_index = starting_tile.AddOffsetToDim( - block_dim_induction_var, dim_id, b_); - emit_next_block_dim(tile_index); - }); - } -} - // Returns the index for the first element in the tile with the given tile // index. static IrArray::Index GetElementIndexForTileOrigin( @@ -2467,79 +2456,20 @@ static IrArray::Index GetElementIndexForTileOrigin( tile_index.GetType()); } -// Emits the tile with a given tile_index, by calculating the tight bounds for -// each dimension of the tile and then calling tile_generator. 
-static void EmitOneTileForTileIndex( - const IrArray::Index& tile_index, llvm::Type* index_ty, llvm::Value* y, - llvm::Value* x, const KernelMappingScheme& mapping_scheme, - KernelSupportLibrary* ksl, llvm::IRBuilder<>* b_, - IrEmitterUnnested::TileElementGenerator tile_generator) { +llvm::Value* IrEmitterUnnested::EmitTilingKernel( + const KernelMappingScheme& mapping_scheme, llvm::Type* index_ty, + TileElementGenerator tile_element_generator) { absl::Span dims_in_tile = mapping_scheme.GetDimensionsInTiles(); + absl::Span dims_in_block = + mapping_scheme.GetDimensionsInBlocks(); absl::Span dimensions_in_elements = mapping_scheme.GetDimensionsInElements(); - auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { + + auto constant = [&](uint64 c) -> llvm::Constant* { return llvm::ConstantInt::get(index_ty, c); }; - std::vector output_tile_bounds(3); - for (int i = KernelMappingScheme::DimY; i < KernelMappingScheme::DimTot; - ++i) { - int64 tile_size_for_dim = mapping_scheme.GetTileSizeForDimension(i); - // Only last row or column may not have full size. - llvm::Value* is_last_row = b_->CreateICmpEQ( - tile_index[i], index_typed_constant(dims_in_tile[i] - 1)); - int64 partial_row_size = - dimensions_in_elements[i] - (dims_in_tile[i] - 1) * tile_size_for_dim; - output_tile_bounds[i] = - b_->CreateSelect(is_last_row, index_typed_constant(partial_row_size), - index_typed_constant(tile_size_for_dim), "tile_bound"); - } - IrArray::Index tile_origin = - GetElementIndexForTileOrigin(tile_index, mapping_scheme, b_); - tile_generator(y, x, tile_origin, "output", output_tile_bounds[1], - output_tile_bounds[2], ksl); -} -static IrArray::Index GetStartingBlockIdx( - const KernelMappingScheme& mapping_scheme, llvm::Type* index_ty, - llvm::IRBuilder<>* b_) { - llvm::Value* block_id = gpu::EmitCallToTargetIntrinsic( - gpu::TargetIntrinsicID::kBlockIdx, {}, {}, b_); - llvm_ir::AddRangeMetadata(0, mapping_scheme.GetNumberOfBlocks(), - llvm::cast(block_id)); - llvm::Value* linear_block_id = - b_->CreateIntCast(block_id, index_ty, /*isSigned=*/true, "block.id.x"); - return IrArray::Index( - linear_block_id, - ShapeUtil::MakeShapeWithDescendingLayout( - PRED /*arbitrary*/, mapping_scheme.GetDimensionsInBlocks()), - b_); -} - -static IrArray::Index GetStartingBlockForDimZ( - const KernelMappingScheme& mapping_scheme, llvm::Type* index_ty, - llvm::IRBuilder<>* b_) { - const IrArray::Index starting_block = - GetStartingBlockIdx(mapping_scheme, index_ty, b_); - std::vector multidim; - multidim.reserve(3); - for (int i = 0; i < 3; ++i) { - multidim.push_back(b_->CreateMul( - starting_block[i], - llvm::ConstantInt::get( - starting_block[i]->getType(), - mapping_scheme.GetNumberOfTilesInOneBlockForDimension(i)), - "block_origin." + std::to_string(i))); - } - return IrArray::Index(multidim, mapping_scheme.GetDimensionsInTiles(), - starting_block.GetType()); -} - -void IrEmitterUnnested::EmitTilingKernel( - const KernelMappingScheme& mapping_scheme, llvm::Type* index_ty, - TileElementGenerator tile_element_generator, - KernelPrologueGenerator kernel_prologue_generator, - KernelEpilogueGenerator kernel_epilogue_generator) { - // Calculate (y, x) coordinate of the thread in the 2D view of thread block + // Calculate (y, x) coordinates respectively in the 2D view of thread block, // defined by (num_thread_y, num_thread_x) from thread_id. 
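  // In other words:
  //
  //   x = thread_id % num_threads_x
  //   y = thread_id / num_threads_x
  //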
llvm::CallInst* thread_id_raw = gpu::EmitCallToTargetIntrinsic( gpu::TargetIntrinsicID::kThreadIdx, {}, {}, &b_); @@ -2552,44 +2482,83 @@ void IrEmitterUnnested::EmitTilingKernel( index_ty, mapping_scheme.GetNumberOfThreadsForDimensionX()); llvm::Value* x = b_.CreateURem(thread_id_int, num_thread_x, "thread.x"); llvm::Value* y = b_.CreateUDiv(thread_id_int, num_thread_x, "thread.y"); - llvm::Value* lane_id = - mapping_scheme.GetNumberOfThreadsForDimensionX() == kWarpSize ? x - : nullptr; + KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll); - auto emit_one_tile_for_tile_index = [&](const IrArray::Index& tile_index) { - return EmitOneTileForTileIndex(tile_index, index_ty, y, x, mapping_scheme, - &ksl, &b_, tile_element_generator); + // Calculate the starting tile. + const IrArray::Index starting_tile = [&]() { + llvm::Value* block_id = gpu::EmitCallToTargetIntrinsic( + gpu::TargetIntrinsicID::kBlockIdx, {}, {}, &b_); + llvm_ir::AddRangeMetadata(0, mapping_scheme.GetNumberOfBlocks(), + llvm::cast(block_id)); + llvm::Value* linear_block_id = + b_.CreateIntCast(block_id, index_ty, /*isSigned=*/true, "block.id.x"); + IrArray::Index starting_block( + linear_block_id, + ShapeUtil::MakeShapeWithDescendingLayout( + PRED /*arbitrary*/, mapping_scheme.GetDimensionsInBlocks()), + &b_); + + std::vector multidim; + multidim.reserve(3); + for (int i = 0; i < 3; ++i) { + multidim.push_back( + b_.CreateMul(starting_block[i], + llvm::ConstantInt::get(starting_block[i]->getType(), + mapping_scheme.BlockSize(i)), + "block_origin." + std::to_string(i))); + } + return IrArray::Index(multidim, mapping_scheme.GetDimensionsInTiles(), + starting_block.GetType()); + }(); + + auto emit_tile = [&](const IrArray::Index& tile_index) { + std::vector output_tile_bounds(3); + for (int i = KernelMappingScheme::DimY; i < KernelMappingScheme::DimTot; + ++i) { + int64 tile_size_for_dim = mapping_scheme.GetTileSizeForDimension(i); + // Only last row or column may not have full size. 
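      // That is (sketch):
      //
      //   bound[i] = (tile_index[i] == dims_in_tile[i] - 1)
      //       ? dims_in_elems[i] - (dims_in_tile[i] - 1) * tile_size[i]
      //       : tile_size[i]
      //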
+ llvm::Value* is_last_row = + b_.CreateICmpEQ(tile_index[i], constant(dims_in_tile[i] - 1)); + int64 partial_row_size = + dimensions_in_elements[i] - (dims_in_tile[i] - 1) * tile_size_for_dim; + output_tile_bounds[i] = + b_.CreateSelect(is_last_row, constant(partial_row_size), + constant(tile_size_for_dim), "tile_bound"); + } + IrArray::Index tile_origin = + GetElementIndexForTileOrigin(tile_index, mapping_scheme, &b_); + tile_element_generator(y, x, tile_origin, "output", output_tile_bounds[1], + output_tile_bounds[2], &ksl); }; - const IrArray::Index starting_tile_for_dim_z = - GetStartingBlockForDimZ(mapping_scheme, index_ty, &b_); + int dim_z = KernelMappingScheme::DimZ; - auto emit_tiles_for_block_dim = - [&](const string& loop_name, const IrArray::Index& starting_tile, - int dim_id, - const std::function - emit_next_block_dim) { - EmitTilesForBlockDim(mapping_scheme, &ksl, index_ty, loop_name, - starting_tile, dim_id, &b_, emit_next_block_dim); - }; + if (mapping_scheme.BlockSize(dim_z) == 1) { + emit_tile(starting_tile); + } else { + llvm::Value* starting_tile_index_for_dim = starting_tile[dim_z]; + llvm::Value* block_size_for_dim = constant(mapping_scheme.BlockSize(dim_z)); + llvm::Value* block_id_for_dim = + b_.CreateUDiv(starting_tile_index_for_dim, block_size_for_dim); + llvm::Value* last_block_for_dim = constant(dims_in_block[dim_z] - 1); + llvm::Value* last_block_size_for_dim = + constant(dims_in_tile[dim_z] - + (dims_in_block[dim_z] - 1) * mapping_scheme.BlockSize(dim_z)); - kernel_prologue_generator(lane_id); - - // Emit the three dimensional block of tiles. - emit_tiles_for_block_dim( - "block_dim_z", starting_tile_for_dim_z, KernelMappingScheme::DimZ, - [&](const IrArray::Index& starting_tile_for_dim_y) { - emit_tiles_for_block_dim( - "block_dim_y", starting_tile_for_dim_y, KernelMappingScheme::DimY, - [&](const IrArray::Index& starting_tile_for_dim_x) { - emit_tiles_for_block_dim("block_dim_x", starting_tile_for_dim_x, - KernelMappingScheme::DimX, - emit_one_tile_for_tile_index); + llvm::Value* num_tiles_in_block = + b_.CreateSelect(b_.CreateICmpEQ(last_block_for_dim, block_id_for_dim), + last_block_size_for_dim, block_size_for_dim); + ksl.For("loop_z", + /*start=*/constant(0), + /*end=*/num_tiles_in_block, + /*step=*/1, [&](llvm::Value* block_dim_induction_var) { + IrArray::Index tile_index = starting_tile.AddOffsetToDim( + block_dim_induction_var, dim_z, &b_); + emit_tile(tile_index); }); - }); - - kernel_epilogue_generator(lane_id); + } + return x; } // Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose @@ -2744,23 +2713,20 @@ void IrEmitterUnnested::EmitHlo021Tile( } }; - KernelPrologueGenerator hlo021_prologue = [&](llvm::Value* /*lane_id*/) { - // For multioutput fusion, one thread needs to output a tuple - // with pointers to all the individual outputs. We could do this - // at any point in the kernel, but we do it at the beginning in - // the hopes of reducing register pressure, since we touch - // threadIdx.x and blockIdx.x at the beginning of the kernel - // *anyway*. 
- if (hlo->IsMultiOutputFusion()) { - KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] { - llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), - ConstructIrArrayForOutputs(*hlo), &b_); - }); - } - }; - KernelEpilogueGenerator epilogue_generator = [](llvm::Value* /*lane_id*/) {}; - EmitTilingKernel(mapping_scheme, index_type, tile_generator, hlo021_prologue, - epilogue_generator); + // For multioutput fusion, one thread needs to output a tuple + // with pointers to all the individual outputs. We could do this + // at any point in the kernel, but we do it at the beginning in + // the hopes of reducing register pressure, since we touch + // threadIdx.x and blockIdx.x at the beginning of the kernel + // *anyway*. + if (hlo->IsMultiOutputFusion()) { + KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] { + llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), + ConstructIrArrayForOutputs(*hlo), &b_); + }); + } + + EmitTilingKernel(mapping_scheme, index_type, tile_generator); UpdateLaunchDimensions(launch_dimensions, kernel_thunk, ir_emitter_context_->llvm_module()); } @@ -3093,14 +3059,13 @@ ReductionCodegenInfo IrEmitterUnnested::ComputeReductionCodegenInfo( int64 num_threads_y = 1; bool dilated_x = true; if (is_row_reduction) { + num_threads_x = kWarpSize; if (dims_in_elem[1] == 1) { // Scalar reduction is handled differently than the other kind of row // reduction. CHECK_EQ(dims_in_elem[0], 1); tile_size_x = kWarpSize * 16; - num_threads_x = kWarpSize; } else { - num_threads_x = kWarpSize; if (dims_in_elem[2] % (kWarpSize * 64) == 0) { tile_size_x = kWarpSize * 64; } else { @@ -3209,6 +3174,14 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions( ReductionCodegenInfo reduction_info = ComputeReductionCodegenInfo(unnested_hlo, first_reduce); + const KernelMappingScheme& mapping_scheme = + reduction_info.GetKernelMappingScheme(); + LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(), + mapping_scheme.GetThreadsPerBlock()); + llvm::Type* index_ty = GetIndexTypeForKernel( + unnested_hlo, launch_dimensions.launch_bound(), &b_); + EmitPrologueForReduction(unnested_hlo, &reduction_info, reduce_instructions, + index_ty); EmitElementFunction emit_reduction_tile = [&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc, llvm::Value* x_loc, int64 x_iter_num) { @@ -3217,15 +3190,7 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions( reducers, x_iter_num); }; - const auto& mapping_scheme = reduction_info.GetKernelMappingScheme(); - LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(), - mapping_scheme.GetThreadsPerBlock()); - llvm::Type* index_ty = - reduction_info.IsRowReduction() - ? 
GetIndexTypeForKernel(unnested_hlo, - launch_dimensions.launch_bound(), &b_) - : b_.getInt64Ty(); - EmitTilingKernel( + llvm::Value* lane_id = EmitTilingKernel( mapping_scheme, index_ty, /*tile_element_generator=*/ [&](llvm::Value* y, llvm::Value* x, const IrArray::Index& index, @@ -3234,18 +3199,10 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions( EmitTiledElementalCodeWithBoundsCheck( reduction_info.GetKernelMappingScheme(), index, loop_name, ksl, &b_, y, x, tile_height, tile_width, emit_reduction_tile); - }, - /*kernel_prologue_generator=*/ - [&](llvm::Value* /*lane_id*/) { - EmitPrologueForReduction(unnested_hlo, &reduction_info, - reduce_instructions, index_ty); - }, - /*kernel_epilogue_generator=*/ - [&](llvm::Value* lane_id) { - EmitEpilogueForReduction( - unnested_hlo, reduction_info, reduce_instructions, - reduction_output_shape_indices, reducers, lane_id); }); + EmitEpilogueForReduction(unnested_hlo, reduction_info, reduce_instructions, + reduction_output_shape_indices, reducers, lane_id); + UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(), ir_emitter_context_->llvm_module()); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index dbd6ce78bb3..3afcde86f28 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -52,12 +52,6 @@ namespace gpu { class IrEmitterUnnested : public IrEmitter, private ThunkEmitter::EmissionContext { public: - // A function object to prepare for the code generation for a tiling kernel. - using KernelPrologueGenerator = std::function; - - // A function object to finalize the code generation for a tiling kernel. - using KernelEpilogueGenerator = std::function; - // A function object to generate code to process one element in a tile. // // hlo: the instruction for which the code is generated for. @@ -71,6 +65,8 @@ class IrEmitterUnnested : public IrEmitter, const llvm_ir::IrArray::Index& index, llvm::Value* y_loc, llvm::Value* x_loc, int64 x_iter_num)>; + using ConstantGenerator = std::function; + // A function to generate the code to emit the entire tile. using TileElementGenerator = std::function output_instructions, @@ -243,6 +242,9 @@ class IrEmitterUnnested : public IrEmitter, absl::Span reducers, int64 x_iter_num); // Prepares for the code generation for a tile block of a reduction kernel. + // + // Create accumulator alloca's, populate them with initial values, and store + // inside reduction_info. void EmitPrologueForReduction( HloInstruction* unnested_hlo, ReductionCodegenInfo* reduction_info, absl::Span reduce_instructions, @@ -253,7 +255,8 @@ class IrEmitterUnnested : public IrEmitter, ReductionCodegenInfo* kernel_info, GpuElementalIrEmitter* elemental_emitter); - // Wraps up the code generation for a tile block of a reduction kernel. + // Wraps up the code generation for a tile block of a reduction kernel: write + // the calculated output into the output tensor. 
void EmitEpilogueForReduction( HloInstruction* unnested_hlo, const ReductionCodegenInfo& reduction_info, absl::Span reduce_instructions, diff --git a/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h b/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h index 6e5bb5a1ba7..e25f1b66862 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h +++ b/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h @@ -85,21 +85,18 @@ class KernelMappingScheme { dims_in_tiles_{dims_in_elems[0], CeilOfRatio(dims_in_elems[1], tile_size_y), CeilOfRatio(dims_in_elems[2], tile_size_x)}, - block_sizes_{block_size_z, 1, 1}, - dims_in_blocks_{CeilOfRatio(dims_in_elems[0], block_sizes_[0]), - dims_in_tiles_[1], dims_in_tiles_[2]}, + dims_in_blocks_{dims_in_elems[0] / block_size_z, dims_in_tiles_[1], + dims_in_tiles_[2]}, + block_size_z_{block_size_z}, num_threads_x_(num_threads_x), num_threads_y_(num_threads_y), dilated_x_(is_dilated_x) { CHECK_EQ(tile_size_y % num_threads_y_, 0); CHECK_EQ(tile_size_x % num_threads_x_, 0); CHECK_EQ((dims_in_elems[0] % block_size_z), 0); - VLOG(10) << "dims_in_elems_ = [" << absl::StrJoin(dims_in_elems_, ",") - << "]"; - VLOG(10) << "dims_in_tiles_ = [" << absl::StrJoin(dims_in_tiles_, ",") - << "]"; - VLOG(10) << "dims_in_blocks_ = [" << absl::StrJoin(dims_in_blocks_, ",") - << "]"; + VLOG(10) << "dims_in_elems_ = " << absl::StrJoin(dims_in_elems_, ","); + VLOG(10) << "dims_in_tiles_ = " << absl::StrJoin(dims_in_tiles_, ","); + VLOG(10) << "dims_in_blocks_ = " << absl::StrJoin(dims_in_blocks_, ","); if (!dilated_x_) { // dilated_x_=false is for the purpose of vectorization, which requires // GetTileSizeForDimension(DimX) to be a multiplier of num_threads_x_. @@ -127,13 +124,14 @@ class KernelMappingScheme { return absl::c_accumulate(dims_in_tiles_, 1LL, std::multiplies()); } - int64 GetNumberOfTilesInOneBlock() const { - return absl::c_accumulate(block_sizes_, 1, std::multiplies()); - } + int64 GetNumberOfTilesInOneBlock() const { return block_size_z_; } - int64 GetNumberOfTilesInOneBlockForDimension(int d) const { + int64 BlockSize(int d) const { DCHECK(d >= DimZ && d <= DimX); - return block_sizes_[d]; + if (d == DimZ) { + return block_size_z_; + } + return 1; } int64 GetNumberOfBlocks() const { @@ -148,7 +146,6 @@ class KernelMappingScheme { return GetTileSizeForDimension(DimY); } - absl::Span GetBlockSizes() const { return block_sizes_; } int64 GetTileBlockSizeForDimension(int d) const { return dims_in_blocks_.at(d); } @@ -165,31 +162,31 @@ class KernelMappingScheme { private: // The number of elements in each dimension. - std::array dims_in_elems_; + const std::array dims_in_elems_; // The number of elements for each dimension of a tile. - std::array tile_sizes_; + const std::array tile_sizes_; // The number of tiles in each dimension. It is computed from dims_in_elem_ // and tile_sizes_. - std::array dims_in_tiles_; + const std::array dims_in_tiles_; - // The number of tiles for each dimension of a tile block. - std::array block_sizes_; // The number of blocks in each dimension of a tile block. It is computed from // dims_in_tile_ and block_sizes_. - std::array dims_in_blocks_; + const std::array dims_in_blocks_; + + const int64 block_size_z_; // Number of threads used to process elements in the X direction of a tile. - int64 num_threads_x_; + const int64 num_threads_x_; // Number of threads used to process elements in the Y direction of a tile. 
- int64 num_threads_y_; + const int64 num_threads_y_; // When num_threads_x threads process a total of tile_size_x elements in the // X dimension of a tile, each threads process n=tile_size_x/num_threads_x // elements. When dilated_x=false, the n elements processed by a thread are // contiguous. On the other hand, when dilated_x=true the n elements are // dilated by a factor of num_threads_x. - bool dilated_x_; + const bool dilated_x_; }; // Information to support the code generation for a tiled reduction kernel. diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index fa903f7233a..e525b4b1de9 100755 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -23,10 +23,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/dump.h" #include "tensorflow/compiler/xla/service/gpu/cublas_gemm_pad_for_tensor_cores.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h" -#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_pad_for_convolutions.h" #include "tensorflow/compiler/xla/service/gpu/cusolver_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.h" #include "tensorflow/compiler/xla/service/gpu/gemm_rewriter.h" @@ -50,7 +50,7 @@ limitations under the License. #include "tensorflow/core/platform/tracing.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/stream_executor/cuda/cuda_diagnostics.h" -#include "tensorflow/stream_executor/cuda/ptxas_utils.h" +#include "tensorflow/stream_executor/gpu/asm_compiler.h" namespace xla { namespace gpu { @@ -115,12 +115,11 @@ Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization( pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); - if (IsVoltaOrLater(*stream_exec)) { - pipeline.AddPass(); - // CudnnConvPadForTensorCores leaves behind unnecessary - // tuple/get-tuple-element pairs that TupleSimplifier fixes. - pipeline.AddPass(); - } + pipeline.AddPass(IsVoltaOrLater(*stream_exec)); + // CudnnConvPadForIntegerConvolutions and CudnnConvPadForTensorCores leaves + // behind unnecessary tuple/get-tuple-element pairs that TupleSimplifier + // fixes. 
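  // (Concretely, a pattern like get-tuple-element(tuple(x, ...), 0) that
  // TupleSimplifier folds back to x.)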
+ pipeline.AddPass(); // tf2xla bridge, DepthwiseConvolutionConverter and CudnnConvRewriter // introduces reshapes and transposes that can be eliminated using @@ -362,18 +361,18 @@ NVPTXCompiler::CompileTargetBinary(const HloModule* module, DumpToFileInDirOrStdout(*module, "ptx", ptx); } - std::vector cubin = - CompilePtxOrGetCachedResult(stream_exec, ptx, compute_capability.first, - compute_capability.second, module->config()); + std::vector cubin = CompileGpuAsmOrGetCachedResult( + stream_exec, ptx, compute_capability.first, compute_capability.second, + module->config()); return std::pair>(std::move(ptx), std::move(cubin)); } -std::vector NVPTXCompiler::CompilePtxOrGetCachedResult( +std::vector NVPTXCompiler::CompileGpuAsmOrGetCachedResult( se::StreamExecutor* stream_exec, const string& ptx, int cc_major, int cc_minor, const HloModuleConfig& hlo_module_config) { - XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::CompilePtxOrGetCachedResult"); + XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::CompileGpuAsmOrGetCachedResult"); tensorflow::profiler::TraceMe activity( "PTX->CUBIN", tensorflow::profiler::TraceMeLevel::kInfo); bool inserted; @@ -401,9 +400,9 @@ std::vector NVPTXCompiler::CompilePtxOrGetCachedResult( if (inserted) { CHECK(!cache_value->compilation_done); if (!ptx.empty()) { - StatusOr> maybe_cubin = se::cuda::CompilePtx( - stream_exec->device_ordinal(), cache_ptx->c_str(), - PtxOptsFromConfig(hlo_module_config)); + StatusOr> maybe_cubin = + se::CompileGpuAsm(stream_exec->device_ordinal(), cache_ptx->c_str(), + PtxOptsFromConfig(hlo_module_config)); if (maybe_cubin.ok()) { cache_value->cubin_data = std::move(maybe_cubin).ValueOrDie(); VLOG(2) << "Compiled PTX size:" << ptx.size() diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h index a7b38afb8ec..3098d5af25f 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h @@ -67,7 +67,7 @@ class NVPTXCompiler : public GpuCompiler { // Tries to compile the given ptx string to cubin. Returns a vector with the // compiled cubin. If compilation was unsuccessful, returns an empty vector. 
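  //
  // Typical call (mirroring the call site in nvptx_compiler.cc):
  //
  //   std::vector<uint8> cubin = CompileGpuAsmOrGetCachedResult(
  //       stream_exec, ptx, compute_capability.first,
  //       compute_capability.second, module->config());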
- std::vector CompilePtxOrGetCachedResult( + std::vector CompileGpuAsmOrGetCachedResult( se::StreamExecutor* stream_exec, const string& ptx, int cc_major, int cc_minor, const HloModuleConfig& hlo_module_config); diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc index 117931e3398..8df21c3dfb1 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc @@ -220,10 +220,9 @@ Status ExecuteKernelOnStream(const se::KernelBase& kernel, *kernel_args); } -se::cuda::PtxCompilationOptions PtxOptsFromConfig( - const HloModuleConfig& hlo_module_config) { - return se::cuda::PtxCompilationOptions( - hlo_module_config.debug_options().xla_gpu_disable_ptxas_optimizations(), +se::GpuAsmOpts PtxOptsFromConfig(const HloModuleConfig& hlo_module_config) { + return se::GpuAsmOpts( + hlo_module_config.debug_options().xla_gpu_disable_gpuasm_optimizations(), hlo_module_config.debug_options().xla_gpu_cuda_data_dir()); } @@ -245,10 +244,6 @@ template static void InitializeTypedBuffer(se::Stream* stream, se::DeviceMemoryBase buffer, int64* rng_state) { - static_assert( - std::is_floating_point::value || std::is_same::value, - "Unimplemented for integers yet."); - // Accesses to static variables are not locked, since the caller is already // in a critical section. static std::vector* host_buffer = [] { @@ -257,13 +252,23 @@ static void InitializeTypedBuffer(se::Stream* stream, // Default-seeded random numbers. std::mt19937 gen; for (auto& element : *ret) { - using RandomType = + // Only double gets random values in double. Other data types get random + // values in float then cast them to the target data types. + using RandomFloatingPointType = typename std::conditional::value, float, T>::type; + using RandomType = + typename std::conditional::value, float, + RandomFloatingPointType>::type; // Scale down the values for fp16 to have less overflows. auto upper_bound = RandomType(std::is_same::value ? 0.1 : 1.0); - element = T(UniformDistribution(RandomType(0), upper_bound, &gen)); + auto rand_val = UniformDistribution(RandomType(0), upper_bound, &gen); + // For float or double, it is between [0,1]. + // For fp16, it ranges between [0, 0.1]. + // For integer types, element is either 0 or 1 for less overflows + // especially for int8. + element = T(std::is_integral::value ? 
rand_val + 0.5 : rand_val); } return ret; }(); @@ -289,8 +294,8 @@ static void InitializeTypedBuffer(se::Stream* stream, } } -void InitializeFloatBuffer(se::Stream* stream, PrimitiveType buffer_type, - int64* rng_state, se::DeviceMemoryBase buffer) { +void InitializeBuffer(se::Stream* stream, PrimitiveType buffer_type, + int64* rng_state, se::DeviceMemoryBase buffer) { switch (buffer_type) { case xla::F16: return InitializeTypedBuffer(stream, buffer, rng_state); @@ -300,6 +305,8 @@ void InitializeFloatBuffer(se::Stream* stream, PrimitiveType buffer_type, case xla::F64: case xla::C128: return InitializeTypedBuffer(stream, buffer, rng_state); + case xla::S8: + return InitializeTypedBuffer(stream, buffer, rng_state); default: LOG(FATAL) << "Unexpected type"; } diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h index 5da2931f049..3e2ae241a03 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" -#include "tensorflow/stream_executor/cuda/ptxas_utils.h" +#include "tensorflow/stream_executor/gpu/asm_compiler.h" #include "tensorflow/stream_executor/kernel_spec.h" // Helper functions for interacting with StreamExecutor. @@ -74,9 +74,8 @@ Status ExecuteKernelOnStream(const se::KernelBase& kernel, int64 threads_per_block, int64 block_count, se::Stream* stream); -// Create PtxCompilationOptions out of HloModuleConfig. -se::cuda::PtxCompilationOptions PtxOptsFromConfig( - const HloModuleConfig& hlo_module_config); +// Create GpuAsmOpts out of HloModuleConfig. +se::GpuAsmOpts PtxOptsFromConfig(const HloModuleConfig& hlo_module_config); // Initializes `buffer` with random data on `stream`. // `rng_state` is an inout parameter for the pseudorandom generator state. @@ -84,8 +83,8 @@ se::cuda::PtxCompilationOptions PtxOptsFromConfig( // // Precondition: `buffer_type` is a floating point type, `rng_state` needs to be // initalized to zero on the first use. 
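//
// Example (hypothetical stream and buffer):
//
//   int64 rng_state = 0;
//   InitializeBuffer(stream, xla::F32, &rng_state, buffer);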
-void InitializeFloatBuffer(se::Stream* stream, PrimitiveType buffer_type, - int64* rng_state, se::DeviceMemoryBase buffer); +void InitializeBuffer(se::Stream* stream, PrimitiveType buffer_type, + int64* rng_state, se::DeviceMemoryBase buffer); } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index 9ec22df1b47..2b6383b6e3e 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -185,6 +185,12 @@ HloInstruction* MakeBroadcastHlo(HloInstruction* operand, broadcast_shape, operand, broadcast_dimensions)); } +HloInstruction* MakeBroadcastHlo(HloInstruction* operand, + absl::Span broadcast_dimensions, + const Shape& shape) { + return MakeBroadcastHlo(operand, broadcast_dimensions, shape.dimensions()); +} + StatusOr MakeGetTupleElementHlo(HloInstruction* operand, int64 index) { HloComputation* computation = operand->parent(); @@ -224,6 +230,22 @@ HloInstruction* MakeConvertToHlo(HloInstruction* hlo, PrimitiveType type) { return hlo; } +HloInstruction* MakeBitcastConvertToHlo(HloInstruction* hlo, + PrimitiveType type) { + CHECK_NE(hlo->shape().element_type(), type); + Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), type); + hlo = hlo->parent()->AddInstruction( + HloInstruction::CreateBitcastConvert(shape, hlo)); + CHECK_EQ(hlo->shape().element_type(), type); + return hlo; +} + +HloInstruction* MakeIotaHlo(HloComputation* computation, const Shape& shape, + int64 iota_dimension) { + return computation->AddInstruction( + HloInstruction::CreateIota(shape, iota_dimension)); +} + StatusOr MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, const DotDimensionNumbers& dim_numbers, const PrecisionConfig& precision_config) { diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index e56199650cb..986bed79af9 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -91,6 +91,9 @@ StatusOr MakeDynamicUpdateSliceHlo( HloInstruction* MakeBroadcastHlo(HloInstruction* operand, absl::Span broadcast_dimensions, absl::Span result_shape_bounds); +HloInstruction* MakeBroadcastHlo(HloInstruction* operand, + absl::Span broadcast_dimensions, + const Shape& shape); // Creates a GetTupleElement HLO instruction and adds it to the computation // containing `operand`. @@ -107,6 +110,14 @@ StatusOr MakeConcatHlo( // the given primitive type. HloInstruction* MakeConvertToHlo(HloInstruction* hlo, PrimitiveType type); +// Creates a BitcastConvert HLO instruction. +HloInstruction* MakeBitcastConvertToHlo(HloInstruction* hlo, + PrimitiveType type); + +// Creates an Iota HLO instruction. +HloInstruction* MakeIotaHlo(HloComputation* computation, const Shape& shape, + int64 iota_dimension); + // Creates a Dot HLO instruction and adds it to the computation containing `lhs` // and `rhs` (both must be in the same computation). 
StatusOr MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc index 6025e6a7794..3c27366a8e6 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc @@ -41,6 +41,21 @@ class HloCreationUtilsTest : public HloTestBase { *param = (*entry_computation)->parameter_instruction(0); return module; } + + std::unique_ptr CreateModuleWithProgramShape( + PrimitiveType primitive_type, absl::Span input_shape_dims, + absl::Span output_shape_dims, HloInstruction** param, + HloComputation** entry_computation, PrimitiveType primitive_type_output) { + Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_shape_dims); + Shape output_shape = + ShapeUtil::MakeShape(primitive_type_output, output_shape_dims); + auto module = CreateNewVerifiedModule("test"); + *entry_computation = module->AddEntryComputation( + CreateComputationWithSignature({&input_shape}, output_shape, "entry") + .ValueOrDie()); + *param = (*entry_computation)->parameter_instruction(0); + return module; + } }; TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) { @@ -222,5 +237,85 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) { LiteralUtil::CreateR2({{0.0f, 0.0f}, {0.0f, 0.0f}})); } +TEST_F(HloCreationUtilsTest, MakeBitcastConvertToHlo_S32) { + HloInstruction* param; + HloComputation* entry_computation; + + auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2, 2}, + /*output_shape_dims=*/{2, 2}, + ¶m, &entry_computation, F32); + auto* input = module->entry_computation()->AddInstruction( + HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{0, 0}, {0, 0}}))); + + HloInstruction* output = MakeBitcastConvertToHlo(input, F32); + entry_computation->set_root_instruction(output); + + HloEvaluator evaluator; + TF_ASSERT_OK_AND_ASSIGN( + Literal result_literal, + evaluator.Evaluate(*module, + {LiteralUtil::CreateR2({{0, 0}, {0, 0}})})); + CHECK_EQ(result_literal, + LiteralUtil::CreateR2({{0.0f, 0.0f}, {0.0f, 0.0f}})); +} + +TEST_F(HloCreationUtilsTest, MakeIotaHlo_I32) { + HloInstruction* param; + HloComputation* entry_computation; + + auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{}, + /*output_shape_dims=*/{2, 2}, + ¶m, &entry_computation, F32); + HloInstruction* output = MakeIotaHlo(module->entry_computation(), + ShapeUtil::MakeShape(F32, {2, 2}), 0); + entry_computation->set_root_instruction(output); + + HloEvaluator evaluator; + TF_ASSERT_OK_AND_ASSIGN( + Literal result_literal, + evaluator.Evaluate(*module, {LiteralUtil::CreateR0(0.0)})); + CHECK_EQ(result_literal, + LiteralUtil::CreateR2({{0.0f, 0.0f}, {1.0f, 1.0f}})); +} + +TEST_F(HloCreationUtilsTest, MakeBroadcast_F32) { + HloInstruction* param; + HloComputation* entry_computation; + + auto module = CreateModuleWithProgramShape(F32, /*input_shape_dims=*/{}, + /*output_shape_dims=*/{2, 2}, + ¶m, &entry_computation); + auto* input = MakeR0ConstantHlo(module->entry_computation(), 0); + HloInstruction* output = MakeBroadcastHlo(input, {}, {2, 2}); + entry_computation->set_root_instruction(output); + + HloEvaluator evaluator; + TF_ASSERT_OK_AND_ASSIGN( + Literal result_literal, + evaluator.Evaluate(*module, {LiteralUtil::CreateR0(0.0f)})); + CHECK_EQ(result_literal, + LiteralUtil::CreateR2({{0.0f, 0.0f}, {0.0f, 0.0f}})); +} + +TEST_F(HloCreationUtilsTest, MakeBroadcast_Shape_I32) { + HloInstruction* param; + HloComputation* 
entry_computation; + + auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{}, + /*output_shape_dims=*/{2, 2}, + ¶m, &entry_computation); + auto* input = MakeR0ConstantHlo(module->entry_computation(), 0); + HloInstruction* output = + MakeBroadcastHlo(input, {}, ShapeUtil::MakeShape(S32, {2, 2})); + entry_computation->set_root_instruction(output); + + HloEvaluator evaluator; + TF_ASSERT_OK_AND_ASSIGN( + Literal result_literal, + evaluator.Evaluate(*module, {LiteralUtil::CreateR0(0.0)})); + CHECK_EQ(result_literal, LiteralUtil::CreateR2({{0, 0}, {0, 0}})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 47100e3fc58..188f196fd3c 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -189,8 +189,20 @@ bool HloDataflowAnalysis::Phi( for (const InstructionValueSet* input : inputs) { VLOG(5) << "input value set = " << input->ToString(); } - for (const InstructionValueSet* input : inputs) { - DCHECK(ShapeUtil::Compatible(instruction->shape(), input->shape())); + + if (bitcast_defines_value_) { + absl::c_for_each(inputs, [&](const InstructionValueSet* input) { + DCHECK(ShapeUtil::Compatible(instruction->shape(), input->shape())); + }); + } else { + const Shape& shape = instruction->shape(); + PrimitiveType ty = shape.element_type(); + bool is_array = shape.IsArray(); + absl::c_for_each(inputs, [&](const InstructionValueSet* input) { + DCHECK(ty == input->shape().element_type() && + (!is_array || ShapeUtil::ElementsIn(shape) == + ShapeUtil::ElementsIn(input->shape()))); + }); } bool changed = false; @@ -774,9 +786,9 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { std::forward_as_tuple(instruction), std::forward_as_tuple(instruction->shape())); - // Lambda to set the value set to define all values in the output of the - // instruction. - auto define_all_values = [this, &instruction](bool is_phi = false) { + // For each sub-shape of the instruction shape, add a new HloValue to its + // HloValueSet. + auto define_all_values = [this, &instruction]() { for (auto& pair : GetInstructionValueSet(instruction)) { const ShapeIndex& index = pair.first; HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false); @@ -784,16 +796,8 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { } }; - // Lambda to set the value set to define only the top-level buffer in the - // output of the instruction. Any other values flow from the operands of - // the instruction (or from cross-computation dataflow). - auto define_top_level_only = [this, &instruction]() { - HloValue* value = - NewHloValue(instruction, /*index=*/{}, /*is_phi=*/false); - GetValueSet(instruction, /*index=*/{}).AddValue(value); - }; - - // Lambda to set the value set at the given index of the output. + // Add a new HloValue to the HloValueSet corresponding to the given index + // of the instruction shape. auto define_value_at = [this, &instruction](const ShapeIndex& index) { HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false); GetValueSet(instruction, index).AddValue(value); @@ -840,7 +844,7 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { case HloOpcode::kTuple: // These instructions only define their top-level values. Any other // values flow from their operands. 
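          // For example, t = tuple(a, b) defines a new HloValue only at
          // ShapeIndex {}; the values at {0} and {1} flow from a and b.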
- define_top_level_only(); + define_value_at(/*index=*/{}); break; case HloOpcode::kCopyDone: // CopyDone produces an element. Its output aliases its input tuple diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index a4966f9e2ba..79cd11f033e 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -52,7 +52,7 @@ class HloDataflowAnalysis { const HloInstruction* instr, const HloInstruction* operand, const ShapeIndex& user_index)>; - // Run dataflow analysis on the given module. Parameters: + // Runs dataflow analysis on the given module. Parameters: // // ssa_form : If true then new values are defined at the merge points of // kWhile instructions. Abusing nomenclature somewhat, we call these "phi @@ -81,7 +81,7 @@ class HloDataflowAnalysis { bool ValueIsDefinedAt(const HloInstruction* instruction, const ShapeIndex& index = {}) const; - // Return the HloValue defined by 'instruction' at the given shape index of + // Returns the HloValue defined by 'instruction' at the given shape index of // its output. // // Precondition: ValueIsDefinedAt is true for this instruction and index. @@ -90,7 +90,7 @@ class HloDataflowAnalysis { HloValue& GetValueDefinedAt(const HloInstruction* instruction, const ShapeIndex& index = {}); - // Return the InstructionValueSet for the given instruction. + // Returns the InstructionValueSet for the given instruction. const InstructionValueSet& GetInstructionValueSet( const HloInstruction* instruction) const; InstructionValueSet& GetInstructionValueSet( @@ -100,7 +100,7 @@ class HloDataflowAnalysis { // a flattened set. HloValueSet GetFlattenedValueSet(const HloInstruction* instruction) const; - // Return the HloValueSet for the given instruction at the given index or the + // Returns the HloValueSet for the given instruction at the given index or the // given position. const HloValueSet& GetValueSet(const HloInstruction* instruction, const ShapeIndex& index = {}) const; @@ -109,7 +109,7 @@ class HloDataflowAnalysis { HloValueSet& GetValueSet(const HloInstruction* instruction, const ShapeIndex& index = {}); - // Return the unique value in the HloValueSet at the given instruction and + // Returns the unique value in the HloValueSet at the given instruction and // shape index. CHECKs if the value set does not contain a exactly one value. const HloValue& GetUniqueValueAt(const HloInstruction* instruction, const ShapeIndex& index = {}) const { @@ -120,17 +120,17 @@ class HloDataflowAnalysis { return GetValue(GetValueSet(instruction, index).GetUniqueValue().id()); } - // Return the HloValue with the given Id. + // Returns the HloValue with the given Id. const HloValue& GetValue(HloValue::Id value_id) const; HloValue& GetValue(HloValue::Id value_id); - // Return the total number of HloValues. + // Returns the total number of HloValues. int64 value_count() const { return values_.size(); } - // Return a vector of all HloValues stabily sorted by HloValue::Id. + // Returns a vector of all HloValues stabily sorted by HloValue::Id. const std::vector& values() const { return values_vector_; } - // Return the call graph used for computing the dataflow. + // Returns the call graph used for computing the dataflow. 
const CallGraph& call_graph() const { return *call_graph_; } string ToString() const; @@ -164,10 +164,10 @@ class HloDataflowAnalysis { HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index, bool is_phi = false); - // Mark the HloValue with the given ID for deletion. + // Marks the HloValue with the given ID for deletion. void MarkValueForDeletion(HloValue::Id value_id); - // Delete all HloValues marked for deletion. Should be called after + // Deletes all HloValues marked for deletion. Should be called after // propagation is complete. void DeleteMarkedValues(); @@ -197,12 +197,13 @@ class HloDataflowAnalysis { bool UpdateWhileValueSet(HloInstruction* xla_while); bool UpdateAddDependencyValueSet(HloInstruction* add_dependency); - // Propagate the dataflow through the module. + // Propagates the dataflow through the module. In particular, it propagates + // the HloValueSet from its defining instruction to the users of the + // instructions. void Propagate(); - // Return the result of the SSA Phi function applied to the given inputs at - // the given instruction. If skip_top_level is true, then the top level of the - // value set of 'instruction' is not modified. + // Returns the result of the SSA Phi function applied to the given inputs at + // the given instruction. bool Phi(HloInstruction* instruction, absl::Span inputs); @@ -217,7 +218,7 @@ class HloDataflowAnalysis { HloInstruction* instruction, const InstructionValueSet& new_value_set, const InstructionValueSet* prev_value_set = nullptr); - // Verify various invariants of the dataflow analysis. + // Verifies various invariants of the dataflow analysis. Status Verify() const; const HloModule& module_; diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 211e5d830f5..bae803bdaa0 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -410,9 +410,9 @@ Status HloEvaluator::HandleGetDimensionSize( } const Shape& shape = get_dimension_size->operand(0)->shape(); - Literal output(ShapeUtil::MakeShape(U32, {})); + Literal output(ShapeUtil::MakeShape(S32, {})); output.PopulateWithValue( - static_cast(shape.dimensions(get_dimension_size->dimension()))); + static_cast(shape.dimensions(get_dimension_size->dimension()))); evaluated_[get_dimension_size] = std::move(output); return Status::OK(); } @@ -1719,6 +1719,10 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) { /*output_shape=*/shape); const Shape& operand_shape = operand.shape(); + if (ShapeUtil::IsZeroElementArray(operand_shape)) { + evaluated_[gather] = std::move(result); + return Status::OK(); + } auto gather_inner_loop_body = [&](absl::Span output_window_index, diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 888434774bb..eff012065dc 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -4154,13 +4154,13 @@ TEST_F(HloEvaluatorTest, GetDimensionSize) { HloModule Test ENTRY main { - size = u32[] parameter(0) + size = s32[] parameter(0) data = s32[4] parameter(1) sum = s32[4] add(data, data) - ROOT dynamic_size = u32[] get-dimension-size(sum), dimensions={0} + ROOT dynamic_size = s32[] get-dimension-size(sum), dimensions={0} } )"; TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); @@ -4174,12 +4174,12 @@ ENTRY main { DynamicDimensionInference::Run(m_.get())); 
evaluator_.set_dynamic_dimension_inference(&dynamic_dimension_inference); - Literal size_arg = LiteralUtil::CreateR0(3); + Literal size_arg = LiteralUtil::CreateR0(3); Literal data_arg = LiteralUtil::CreateR1({1, 2, 3, 4}); TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&size_arg, &data_arg})); - EXPECT_EQ(actual.GetFirstElement(), static_cast(3)); + EXPECT_EQ(actual.GetFirstElement(), static_cast(3)); } // Check that we get a useful error if we pass inputs of the wrong shape. diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 9487d955f31..6fa3f9fb34b 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -1389,9 +1389,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { *(accumulate_index_locations[i].second) = accumulate_index[i]; } + ElementwiseT lhs_val(lhs_literal.Get(lhs_index)); + ElementwiseT rhs_val(rhs_literal.Get(rhs_index)); result_val += - static_cast(lhs_literal.Get(lhs_index)) * - static_cast(rhs_literal.Get(rhs_index)); + ToArithmeticSafeType(lhs_val) * ToArithmeticSafeType(rhs_val); // If there are no contracting dimension accumulate_index_sizes is // empty, do not try to count down from -1 to 0 since it is and diff --git a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc index 862b2029718..937c535e550 100644 --- a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc +++ b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc @@ -37,8 +37,10 @@ StatusOr ReplaceGetSize( TF_ASSIGN_OR_RETURN(auto legal_shape, ShapeInference::InferGetDimensionSizeShape( instr->operand(0)->shape(), instr->dimension())); - TF_RET_CHECK(ShapeUtil::Equal(instr->shape(), legal_shape)); - TF_RET_CHECK(ShapeUtil::HasPrimitiveType(instr->shape(), U32)); + TF_RET_CHECK(ShapeUtil::Equal(instr->shape(), legal_shape)) + << "instr->shape() " << instr->shape().ToString() << " , " + << "legal_shape " << legal_shape.ToString(); + TF_RET_CHECK(ShapeUtil::HasPrimitiveType(instr->shape(), S32)); HloInstruction* operand = instr->mutable_operand(0); int64 dim = instr->dimension(); HloInstruction* dynamic_size = @@ -46,9 +48,9 @@ StatusOr ReplaceGetSize( if (dynamic_size != nullptr) { TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(dynamic_size)); } else { - uint32 size = instr->operand(0)->shape().dimensions(dim); + int32 size = instr->operand(0)->shape().dimensions(dim); HloInstruction* new_instr = computation->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(size))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(size))); TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr)); } return true; diff --git a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter_test.cc b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter_test.cc index bc240435be8..a0a06d53ea2 100644 --- a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter_test.cc @@ -44,9 +44,9 @@ TEST_F(HloGetDimensionSizeRewriterTest, Ok) { HloModule _ ENTRY gds { p = s32[3,4] parameter(0) - size0 = u32[] get-dimension-size(p), dimensions={0} - size1 = u32[] get-dimension-size(p), dimensions={1} - ROOT mul = u32[] multiply(size0, size1) + size0 = s32[] get-dimension-size(p), dimensions={0} + size1 = s32[] 
get-dimension-size(p), dimensions={1} + ROOT mul = s32[] multiply(size0, size1) })") .ValueOrDie(); HloGetDimensionSizeRewriter pass; @@ -72,7 +72,7 @@ TEST_F(HloGetDimensionSizeRewriterTest, IllegalDimension) { HloModule _ ENTRY gds { p = f32[2,5] parameter(0) - ROOT gds = u32[] get-dimension-size(p), dimensions={2} + ROOT gds = s32[] get-dimension-size(p), dimensions={2} })") .ValueOrDie(); HloGetDimensionSizeRewriter pass; diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 82f3b245590..c93f0106075 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -1627,14 +1627,8 @@ bool HloFusionInstruction::IdenticalSlowPath( other.fused_instructions_computation()); } -static uint64 HashOperandRecursive(const HloInstruction* hlo) { - return hlo->Hash(HashOperandRecursive); -} - uint64 HloFusionInstruction::InnerHash() const { - // Use HashOperandRecursive to recursively compute hash on inner operands. - return fused_instructions_computation()->root_instruction()->Hash( - HashOperandRecursive); + return fused_instructions_computation()->root_instruction()->Hash(); } std::unique_ptr HloFusionInstruction::CloneWithNewOperandsImpl( diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc index 78c48e036d6..acc077ab12d 100644 --- a/tensorflow/compiler/xla/service/hlo_value.cc +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -77,7 +78,7 @@ HloValue::HloValue(HloValue::Id id, HloInstruction* instruction, bool HloValue::operator==(const HloValue& other) const { bool equal = defining_instruction() == other.defining_instruction() && defining_index() == other.defining_index(); - // If the values are equal they most both be phi (or non phi). + // If the values are equal they must both be phi (or non phi). CHECK(!(equal && is_phi() != other.is_phi())); return equal; } @@ -87,17 +88,17 @@ bool HloValue::operator!=(const HloValue& other) const { } string HloValue::ToShortString() const { - string index_str = defining_instruction()->shape().IsTuple() - ? defining_index().ToString() - : ""; - return StrCat(id(), " ", is_phi_ ? "PHI " : "", - defining_instruction()->name(), index_str, " @", - (has_color() ? color().value() : -1)); + return absl::StrFormat( + "<%d %s%s%s%s>", id(), instruction()->name(), + instruction()->shape().IsTuple() ? index().ToString() : "", + is_phi() ? " (phi)" : "", + has_color() ? 
StrCat(" @", color().value()) : ""); } string HloValue::ToString(int indent) const { string indentation(indent, ' '); - string out = StrCat(indentation, ToShortString(), ", positions:\n"); + string out = + StrCat(indentation, ToShortString(), "\n", indentation, " positions:\n"); for (const HloPosition& position : positions()) { StrAppend(&out, indentation, " ", position.ToString(), "\n"); } diff --git a/tensorflow/compiler/xla/service/interpreter/executor.cc b/tensorflow/compiler/xla/service/interpreter/executor.cc index b1a26b3b586..2606e2e4bf7 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.cc +++ b/tensorflow/compiler/xla/service/interpreter/executor.cc @@ -33,7 +33,9 @@ XlaInterpreterExecutor::XlaInterpreterExecutor( XlaInterpreterExecutor::~XlaInterpreterExecutor() {} -void *XlaInterpreterExecutor::Allocate(uint64 size) { return new char[size]; } +DeviceMemoryBase XlaInterpreterExecutor::Allocate(uint64 size) { + return DeviceMemoryBase(new char[size], size); +} void *XlaInterpreterExecutor::GetSubBuffer(DeviceMemoryBase *parent, uint64 offset_bytes, diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h index 43493b6e154..ce94dbe7a6f 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.h +++ b/tensorflow/compiler/xla/service/interpreter/executor.h @@ -68,7 +68,7 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { return port::UnimplementedError("Not Implemented"); } - void *Allocate(uint64 size) override; + DeviceMemoryBase Allocate(uint64 size) override; void *GetSubBuffer(DeviceMemoryBase *parent, uint64 offset_bytes, uint64 size_bytes) override; void Deallocate(DeviceMemoryBase *mem) override; diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index bf1df58f0b8..2fe3f9aa03e 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -36,6 +36,7 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" @@ -1964,6 +1965,41 @@ Status LayoutAssignment::ConstrainChannelLayouts( return Status::OK(); } +Status LayoutAssignment::PropagateMemorySpace(HloModule* module) { + TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module)); + for (auto buffer : alias_analysis->buffers()) { + // First go through values to collect the memory spaces. 
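    // (Sketch: a buffer's memory space is the unique non-default space among
    // its values' defining shapes; two distinct non-default spaces are an
    // error.)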
+ int64 buffer_memory_space = Layout::kDefaultMemorySpace; + for (auto value : buffer.values()) { + const Shape& defining_shape = value->defining_position().shape(); + int64 memory_space = defining_shape.layout().memory_space(); + if (memory_space != Layout::kDefaultMemorySpace) { + if (buffer_memory_space != Layout::kDefaultMemorySpace && + memory_space != buffer_memory_space) { + return InternalError( + "Buffer %d (%s) has conflicting memory spaces: %d and %d.", + buffer.id(), value->ToShortString(), buffer_memory_space, + memory_space); + } + buffer_memory_space = memory_space; + } + } + + // If we encounter a memory space other than the default, then propagate all + // the positions with the buffer's memory space. + if (buffer_memory_space != Layout::kDefaultMemorySpace) { + for (auto value : buffer.values()) { + for (auto& position : value->positions()) { + Shape* shape = ShapeUtil::GetMutableSubshape( + position.instruction->mutable_shape(), position.index); + shape->mutable_layout()->set_memory_space(buffer_memory_space); + } + } + } + } + return Status::OK(); +} + Status LayoutAssignment::PropagateComputationLayouts( HloComputation* computation, ComputationLayout* computation_layout) { ComputationLayout computed_computation_layout( @@ -2076,6 +2112,9 @@ StatusOr LayoutAssignment::Run(HloModule* module) { } TF_RETURN_IF_ERROR(PropagateComputationLayouts(module->entry_computation(), entry_computation_layout_)); + + TF_RETURN_IF_ERROR(PropagateMemorySpace(module)); + TF_RETURN_IF_ERROR(CheckLayouts(module)); // All layouts are reset then reassigned by this pass. diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index 6a202837e14..a0f61fc416d 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -356,6 +356,10 @@ class LayoutAssignment : public HloModulePass { const HloInstruction* instruction, LayoutConstraints* constraints); + // Propagates the memory space defined in the entry computation to the called + // computations. + Status PropagateMemorySpace(HloModule* module); + // Chooses a layout of operand `operand_no` of `instruction` that minimizes // the cost of `instruction`. `output_layout` is the layout of `instruction`. // Returns null if it can't decide the best layout. diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc index c23b343d902..fa9a606568f 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc @@ -188,6 +188,8 @@ StatusOr LhloDialectEmitter::CreateFunction( FuncOp::create(builder_.getUnknownLoc(), instr.name(), function_type); mlir_module_.push_back(function); function.addEntryBlock(); + OpBuilder op_builder(function.getBody()); + op_builder.create<::mlir::ReturnOp>(builder_.getUnknownLoc()); instruction_to_mlir_func_[&instr] = function; return function; } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc index eef17132efa..22f60374ee9 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc @@ -191,11 +191,9 @@ StatusOr> MlirCompiler::RunBackend( } // TODO(b/137624192): Add profiling support. 
- - return static_cast<StatusOr<std::unique_ptr<Executable>>>( - absl::make_unique<GpuExecutable>( - ptx, cubin, GetGpuVersion(stream_exec), std::move(thunk_schedule), - std::move(module), std::move(buffer_assignment), nullptr, nullptr)); + return {absl::make_unique<GpuExecutable>( + ptx, cubin, GetGpuVersion(stream_exec), std::move(thunk_schedule), + std::move(module), std::move(buffer_assignment), nullptr, nullptr)}; } StatusOr<std::vector<std::unique_ptr<Executable>>> MlirCompiler::Compile( diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.cc index dfa3af8c39f..c1d47fabbcd 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.cc @@ -64,8 +64,7 @@ void MlirIrGenTestBase::CompileAndVerifyIr(const string& hlo_text, MlirCompiler* MlirIrGenTestBase::GetMLIRCompiler() { // TODO(b/137624192): Remove failover once no longer in place. - FailoverCompiler* failover = - static_cast<FailoverCompiler*>(backend().compiler()); + auto* failover = static_cast<FailoverCompiler*>(backend().compiler()); return static_cast<MlirCompiler*>(failover->GetPrimary()); } diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 7c4605125cf..b93c4358e6c 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -652,9 +652,7 @@ Status ValidateDotDimensionNumbers( const int64 rhs_contracting_dimension = dimension_numbers.rhs_contracting_dimensions(i); if (lhs.dimensions(lhs_contracting_dimension) != - rhs.dimensions(rhs_contracting_dimension) || - lhs.is_dynamic_dimension(lhs_contracting_dimension) != - rhs.is_dynamic_dimension(rhs_contracting_dimension)) { + rhs.dimensions(rhs_contracting_dimension)) { return fail("Contracting dimension sizes do not match."); } } @@ -668,10 +666,7 @@ Status ValidateDotDimensionNumbers( // Check that batch dimension numbers and sizes match. for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) { if (lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) != - rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i)) || - lhs.is_dynamic_dimension(dimension_numbers.lhs_batch_dimensions(i)) != - rhs.is_dynamic_dimension( - dimension_numbers.rhs_batch_dimensions(i))) { + rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i))) { return fail("Batch dimension sizes must match for lhs/rhs."); } } @@ -726,13 +721,10 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, for (int64 i = 0; i < lhs.rank(); ++i) { if (lhs.dimensions(i) == rhs.dimensions(i)) { output_dimensions[i] = lhs.dimensions(i); - output_dimensions_is_dynamic[i] = lhs.is_dynamic_dimension(i); } else if (lhs.dimensions(i) == 1) { output_dimensions[i] = rhs.dimensions(i); - output_dimensions_is_dynamic[i] = rhs.is_dynamic_dimension(i); } else if (rhs.dimensions(i) == 1) { output_dimensions[i] = lhs.dimensions(i); - output_dimensions_is_dynamic[i] = lhs.is_dynamic_dimension(i); } else { return InvalidArgument( "Binary op %s with incompatible shapes: %s and %s.", @@ -740,6 +732,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(rhs)); } } + + // Merge dynamic dimensions from two shapes.
+ for (int64 i = 0; i < rhs.rank(); ++i) { + if (rhs.is_dynamic_dimension(i) || lhs.is_dynamic_dimension(i)) { + output_dimensions_is_dynamic[i] = true; + } + } + return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs), output_dimensions, output_dimensions_is_dynamic); } @@ -888,11 +888,18 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (ShapeUtil::CompatibleIgnoringFpPrecision(lhs, rhs)) { // If the shapes are the same other than layout, the output shape is the // same (elementwise op). - return ShapeUtil::ChangeElementType( + Shape result = ShapeUtil::ChangeElementType( lhs, ShapeUtil::HigherPrecisionElementType(lhs, rhs)); - } - if (lhs.rank() == rhs.rank()) { + for (int64 i = 0; i < rhs.rank(); ++i) { + if (rhs.is_dynamic_dimension(i)) { + result.set_dynamic_dimension(i, true); + } + } + + return result; + + } else if (lhs.rank() == rhs.rank()) { return InferDegenerateDimensionBroadcastShape(operation, lhs, rhs); } else { // Ranks do not match, so perform InDim broadcasting using @@ -2201,14 +2208,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, // TODO(b/119580730): Remove this restriction when very large dimension size // is needed. - if (shape.dimensions(dimension) > std::numeric_limits::max()) { + if (shape.dimensions(dimension) > std::numeric_limits::max()) { return InvalidArgument( "GetDimensionSize's input shape is %s, the %dth dimension exceeds the " - "UINT_MAX limit.", + "INT_MAX limit.", ShapeUtil::HumanString(shape), dimension); } - return ShapeUtil::MakeShape(U32, {}); + return ShapeUtil::MakeShape(S32, {}); } /* static */ StatusOr ShapeInference::InferWindowFromDimensions( @@ -2324,7 +2331,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, sizes.push_back((limit_index - start_index + stride - 1) / stride); } - return ShapeUtil::MakeShape(arg.element_type(), sizes); + std::vector is_dynamic(arg.rank()); + for (int64 i = 0; i < arg.dimensions_size(); ++i) { + is_dynamic[i] = arg.is_dynamic_dimension(i); + } + + return ShapeUtil::MakeShape(arg.element_type(), sizes, is_dynamic); } /* static */ StatusOr ShapeInference::InferDynamicSliceShape( @@ -3061,10 +3073,10 @@ static Status ValidateGatherDimensionNumbers( } for (int i = 0; i < gather_dim_numbers.collapsed_slice_dims_size(); i++) { - if (slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)] != 1) { + if (slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)] > 1) { return InvalidArgument( - "Gather op can only collapse slice dims with bound 1, but bound is " - "%d for index %d at position %d.", + "Gather op can only collapse slice dims with bound 1 or 0, but bound " + "is %d for index %d at position %d.", slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)], gather_dim_numbers.collapsed_slice_dims(i), i); } diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index c241a4ac2ce..7ccdb869a91 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -2368,9 +2368,10 @@ TEST_F(ScatterGatherShapeInferenceTest, /*index_vector_dim=*/4), /*slice_sizes=*/{30, 29, 28, 26, 20}); ASSERT_FALSE(statusor.ok()); - EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Gather op can only collapse slice dims with bound 1, " - "but bound is 29 for index 1 at position 0.")) + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Gather op can only collapse 
slice dims with bound 1 or 0, " + "but bound is 29 for index 1 at position 0.")) << statusor.status(); } diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 5f11fbf03be..c47145d076d 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -128,6 +128,17 @@ StatusOr MakeShapeWithLayoutInternal( return equal; } +/* static */ bool ShapeUtil::EqualIgnoringElementType(const Shape& lhs, + const Shape& rhs) { + bool equal = Shape::Equal().IgnoreElementType()(lhs, rhs); + if (!equal && VLOG_IS_ON(3)) { + VLOG(3) << "ShapeUtil::EqualIgnoringElementType differ: lhs = " + << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString(); + } + + return equal; +} + /* static */ bool ShapeUtil::EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs) { bool equal = Shape::Equal().IgnoreFpPrecision()(lhs, rhs); @@ -507,17 +518,23 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) { - return Shape::Equal().IgnoreLayout()(lhs, rhs); + return Shape::Equal().IgnoreDynamicDimension().IgnoreLayout()(lhs, rhs); } /* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs) { - return Shape::Equal().IgnoreElementType().IgnoreLayout()(lhs, rhs); + return Shape::Equal() + .IgnoreDynamicDimension() + .IgnoreElementType() + .IgnoreLayout()(lhs, rhs); } /* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs) { - return Shape::Equal().IgnoreFpPrecision().IgnoreLayout()(lhs, rhs); + return Shape::Equal() + .IgnoreDynamicDimension() + .IgnoreFpPrecision() + .IgnoreLayout()(lhs, rhs); } /* static */ int64 ShapeUtil::GetDimension(const Shape& shape, diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 44994b26ac1..dffabf75a9a 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -280,7 +280,6 @@ class ShapeUtil { if (SameElementType(a, b)) { return a.element_type(); } - CHECK(SameElementTypeIgnoringFpPrecision(a, b)); return primitive_util::BitWidth(a.element_type()) < primitive_util::BitWidth(b.element_type()) ? b.element_type() @@ -304,6 +303,9 @@ class ShapeUtil { // Returns whether the lhs and rhs shapes are identical. static bool Equal(const Shape& lhs, const Shape& rhs); + // As Equal, but does not compare the element type. + static bool EqualIgnoringElementType(const Shape& lhs, const Shape& rhs); + // As Equal, but allow one of lhs and rhs to be F16 while the other is F32. static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs); diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 020b062f6b1..4a59fe794c7 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include + #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/layout_util.h" @@ -176,6 +177,27 @@ TEST(ShapeUtilTest, UnequalIgnoringFpPrecision) { ShapeUtil::MakeShapeWithLayout(PRED, {4, 3}, {0, 1}))); } +TEST(ShapeUtilTest, EqualIgnoringElementType) { + EXPECT_TRUE(ShapeUtil::EqualIgnoringElementType( + ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}), + ShapeUtil::MakeShapeWithLayout(F16, {4, 3}, {0, 1}))); + EXPECT_TRUE(ShapeUtil::EqualIgnoringElementType( + ShapeUtil::MakeShapeWithLayout(S32, {4, 3}, {0, 1}), + ShapeUtil::MakeShapeWithLayout(F16, {4, 3}, {0, 1}))); + EXPECT_TRUE(ShapeUtil::EqualIgnoringElementType( + ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}), + ShapeUtil::MakeShapeWithLayout(PRED, {4, 3}, {0, 1}))); +} + +TEST(ShapeUtilTest, UnequalIgnoringElementType) { + EXPECT_FALSE(ShapeUtil::EqualIgnoringElementType( + ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}), + ShapeUtil::MakeShapeWithLayout(F16, {3, 4}, {0, 1}))); + EXPECT_FALSE(ShapeUtil::EqualIgnoringElementType( + ShapeUtil::MakeShapeWithLayout(F32, {3, 4}, {0, 1}), + ShapeUtil::MakeShapeWithLayout(F16, {3, 4}, {1, 0}))); +} + TEST(ShapeUtilTest, EqualDynamicShapes) { EXPECT_TRUE( ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {4, 3}, {true, false}), @@ -195,7 +217,7 @@ TEST(ShapeUtilTest, CompatibleDynamicShapes) { EXPECT_TRUE(ShapeUtil::Compatible(shape_a, shape_a)); EXPECT_TRUE(ShapeUtil::Compatible(shape_a, shape_b)); - EXPECT_FALSE(ShapeUtil::Compatible(shape_a, shape_c)); + EXPECT_TRUE(ShapeUtil::Compatible(shape_a, shape_c)); } TEST(ShapeUtilTest, CompatibleTuples) { diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index ae0d70610be..ee823ce6364 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -745,7 +745,9 @@ xla_test( "no_oss", ], deps = [ + ":client_library_test_base", ":exhaustive_op_test_utils", + "//tensorflow/compiler/xla:util", ], ) @@ -765,7 +767,9 @@ xla_test( "no_oss", ], deps = [ + ":client_library_test_base", ":exhaustive_op_test_utils", + "//tensorflow/compiler/xla:util", ], ) @@ -785,7 +789,9 @@ xla_test( "no_oss", ], deps = [ + ":client_library_test_base", ":exhaustive_op_test_utils", + "//tensorflow/compiler/xla:util", ], ) @@ -1281,16 +1287,17 @@ xla_test( srcs = ["slice_test.cc"], shard_count = 40, deps = [ + ":client_library_test_base", + ":literal_test_util", ":test_macros_header", + ":xla_internal_test_main", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "//tensorflow/core/platform:types", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index a5e27cd67a7..916bbed252d 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -1550,8 +1550,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) { XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) { 
XlaBuilder b(TestName()); - std::vector<float> values0 = {1.0f, 2.0f, 3.2f, 4.0f, 0.5f, 5.7f}; - std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; + std::vector<float> values0 = {1.0f, -10.0f, -2.0f, 2.0f, + 3.2f, 4.0f, 0.5f, 5.7f}; + std::vector<float> values1 = {0.0f, 10.0f, -4.0f, 1.0f, + 2.0f, 0.5f, -1.0f, -0.5f}; Literal literal0 = LiteralUtil::CreateR1<float>(values0); std::unique_ptr<GlobalData> data0 = diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index ef800b8ef62..3e1b9508346 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -160,7 +160,7 @@ TEST_F(ComputeConstantTest, GetDimensionSize) { auto get_dimension_size = GetDimensionSize(add, 0); EXPECT_TRUE(IsConstant(get_dimension_size, &b)); - TF_ASSERT_OK_AND_ASSIGN(auto value, ComputeConstantScalar<uint32>( + TF_ASSERT_OK_AND_ASSIGN(auto value, ComputeConstantScalar<int32>( client, get_dimension_size, &b)); EXPECT_EQ(value, 1); } @@ -178,7 +178,7 @@ TEST_F(ComputeConstantTest, MultipleGetDimensionSize) { EXPECT_TRUE(IsConstant(add_2, &b)); TF_ASSERT_OK_AND_ASSIGN(auto value, - ComputeConstantScalar<uint32>(client, add_2, &b)); + ComputeConstantScalar<int32>(client, add_2, &b)); EXPECT_EQ(value, 2); } } diff --git a/tensorflow/compiler/xla/tests/conv_depthwise_test.cc b/tensorflow/compiler/xla/tests/conv_depthwise_test.cc index 80566be9085..6d8ddc199e2 100644 --- a/tensorflow/compiler/xla/tests/conv_depthwise_test.cc +++ b/tensorflow/compiler/xla/tests/conv_depthwise_test.cc @@ -39,7 +39,8 @@ static std::vector GetConv2DTestCases() { std::vector<std::vector<int64>> config_options = { {128, 6, 3, 64}, {256, 5, 3, 256}, {256, 5, 2, 144}, {144, 5, 3, 64}, {144, 5, 2, 256}, {8, 48, 17, 8}, {128, 20, 6, 64}, {64, 14, 12, 172}, - {16, 9, 4, 16}, {128, 1, 2, 144}, {256, 1, 2, 64}}; + {16, 9, 4, 16}, {128, 1, 2, 144}, {256, 1, 2, 64}, {256, 1, 2, 2}, + {144, 5, 3, 3}, {8, 48, 17, 1}, {16, 9, 5, 4}}; for (auto option : config_options) { int64 feature = option[0]; diff --git a/tensorflow/compiler/xla/tests/exhaustive_binary_test.cc b/tensorflow/compiler/xla/tests/exhaustive_binary_test.cc index c0f8a0dc626..64372788be4 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_binary_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_binary_test.cc @@ -88,30 +88,41 @@ inline std::function AddEmptyBroadcastDimension( }; } -#define XLA_TEST_16BIT(test_name, ...) \ - XLA_TEST_P(ExhaustiveF16BinaryTest, test_name) \ - __VA_ARGS__ \ +#if defined(BINARY_TEST_TARGET_F16) && defined(BINARY_TEST_TARGET_BF16) +#error "Can't define both BINARY_TEST_TARGET_F16 and BINARY_TEST_TARGET_BF16" +#endif + +#if defined(BINARY_TEST_TARGET_F16) && \ + !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) +#define BINARY_TEST_16BIT(test_name, ...) \ + XLA_TEST_P(ExhaustiveF16BinaryTest, test_name) \ + __VA_ARGS__ +#elif defined(BINARY_TEST_TARGET_BF16) && defined(XLA_BACKEND_SUPPORTS_BFLOAT16) +#define BINARY_TEST_16BIT(test_name, ...) \ XLA_TEST_P(ExhaustiveBF16BinaryTest, test_name) \ __VA_ARGS__ +#else +#define BINARY_TEST_16BIT(test_name, ...) +#endif -XLA_TEST_16BIT(Add, { +BINARY_TEST_16BIT(Add, { auto host_add = [](float x, float y) { return x + y; }; Run(AddEmptyBroadcastDimension(Add), host_add); }) -XLA_TEST_16BIT(Sub, { +BINARY_TEST_16BIT(Sub, { auto host_sub = [](float x, float y) { return x - y; }; Run(AddEmptyBroadcastDimension(Sub), host_sub); }) // TODO(bixia): Mul fails with bfloat16 on CPU.
-XLA_TEST_16BIT(DISABLED_ON_CPU(Mul), { +BINARY_TEST_16BIT(DISABLED_ON_CPU(Mul), { auto host_mul = [](float x, float y) { return x * y; }; Run(AddEmptyBroadcastDimension(Mul), host_mul); }) // TODO(bixia): Div fails with bfloat16 on CPU. -XLA_TEST_16BIT(DISABLED_ON_CPU(Div), { +BINARY_TEST_16BIT(DISABLED_ON_CPU(Div), { auto host_div = [](float x, float y) { return x / y; }; Run(AddEmptyBroadcastDimension(Div), host_div); }) @@ -146,19 +157,21 @@ T ReferenceMin(T x, T y) { return std::min(x, y); } -XLA_TEST_16BIT(Max, - { Run(AddEmptyBroadcastDimension(Max), ReferenceMax); }) +BINARY_TEST_16BIT(Max, { + Run(AddEmptyBroadcastDimension(Max), ReferenceMax); +}) -XLA_TEST_16BIT(Min, - { Run(AddEmptyBroadcastDimension(Min), ReferenceMin); }) +BINARY_TEST_16BIT(Min, { + Run(AddEmptyBroadcastDimension(Min), ReferenceMin); +}) // TODO(bixia): Pow fails with bfloat16 on CPU. -XLA_TEST_16BIT(DISABLED_ON_CPU(Pow), - { Run(AddEmptyBroadcastDimension(Pow), std::powf); }) +BINARY_TEST_16BIT(DISABLED_ON_CPU(Pow), + { Run(AddEmptyBroadcastDimension(Pow), std::powf); }) // TODO(bixia): Atan2 fails with bfloat16 on CPU. -XLA_TEST_16BIT(DISABLED_ON_CPU(Atan2), - { Run(AddEmptyBroadcastDimension(Atan2), std::atan2f); }) +BINARY_TEST_16BIT(DISABLED_ON_CPU(Atan2), + { Run(AddEmptyBroadcastDimension(Atan2), std::atan2f); }) #if defined(BINARY_TEST_TARGET_F16) #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) @@ -224,35 +237,43 @@ class Exhaustive32BitOrMoreBinaryTest using ExhaustiveF32BinaryTest = Exhaustive32BitOrMoreBinaryTest; using ExhaustiveF64BinaryTest = Exhaustive32BitOrMoreBinaryTest; -XLA_TEST_P(ExhaustiveF32BinaryTest, Add) { +#if defined(BINARY_TEST_TARGET_F32) +#define BINARY_TEST_FLOAT_32(test_name, ...) \ + XLA_TEST_P(ExhaustiveF32BinaryTest, test_name) \ + __VA_ARGS__ +#else +#define BINARY_TEST_FLOAT_32(test_name, ...) +#endif + +BINARY_TEST_FLOAT_32(Add, { auto host_add = [](float x, float y) { return x + y; }; Run(AddEmptyBroadcastDimension(Add), host_add); -} +}) -XLA_TEST_P(ExhaustiveF32BinaryTest, Sub) { +BINARY_TEST_FLOAT_32(Sub, { auto host_sub = [](float x, float y) { return x - y; }; Run(AddEmptyBroadcastDimension(Sub), host_sub); -} +}) // TODO(bixia): Need to investigate the failure on CPU and file bugs. -XLA_TEST_P(ExhaustiveF32BinaryTest, DISABLED_ON_CPU(Mul)) { +BINARY_TEST_FLOAT_32(DISABLED_ON_CPU(Mul), { auto host_mul = [](float x, float y) { return x * y; }; Run(AddEmptyBroadcastDimension(Mul), host_mul); -} +}) // TODO(bixia): Need to investigate the failure on CPU and file bugs. -XLA_TEST_P(ExhaustiveF32BinaryTest, DISABLED_ON_CPU(Div)) { +BINARY_TEST_FLOAT_32(DISABLED_ON_CPU(Div), { auto host_div = [](float x, float y) { return x / y; }; Run(AddEmptyBroadcastDimension(Div), host_div); -} +}) -XLA_TEST_P(ExhaustiveF32BinaryTest, Max) { +BINARY_TEST_FLOAT_32(Max, { Run(AddEmptyBroadcastDimension(Max), ReferenceMax); -} +}) -XLA_TEST_P(ExhaustiveF32BinaryTest, Min) { +BINARY_TEST_FLOAT_32(Min, { Run(AddEmptyBroadcastDimension(Min), ReferenceMin); -} +}) // It is more convenient to implement Abs(complex) as a binary op than a unary // op, as the operations we currently support all have the same data type for @@ -261,16 +282,14 @@ XLA_TEST_P(ExhaustiveF32BinaryTest, Min) { // implement Abs(complex) as unary conveniently. // // TODO(bixia): Need to investigate the failure on CPU and file bugs. 
-XLA_TEST_P(ExhaustiveF32BinaryTest, DISABLED_ON_CPU(AbsComplex)) { +BINARY_TEST_FLOAT_32(DISABLED_ON_CPU(AbsComplex), { auto host_abs_complex = [](float x, float y) { return std::abs(std::complex(x, y)); }; auto device_abs_complex = [](XlaOp x, XlaOp y) { return Abs(Complex(x, y)); }; Run(device_abs_complex, host_abs_complex); -} - -#if defined(BINARY_TEST_TARGET_F32) +}) INSTANTIATE_TEST_SUITE_P( SpecialValues, ExhaustiveF32BinaryTest, @@ -307,51 +326,55 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn( GetFpValuesForMagnitudeExtremeNormals(40000, 2000)))); +#if defined(BINARY_TEST_TARGET_F64) && \ + !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) +#define BINARY_TEST_FLOAT_64(test_name, ...) \ + XLA_TEST_P(ExhaustiveF64BinaryTest, test_name) \ + __VA_ARGS__ +#else +#define BINARY_TEST_FLOAT_64(test_name, ...) #endif -XLA_TEST_P(ExhaustiveF64BinaryTest, Add) { +BINARY_TEST_FLOAT_64(Add, { auto host_add = [](double x, double y) { return x + y; }; Run(AddEmptyBroadcastDimension(Add), host_add); -} +}) -XLA_TEST_P(ExhaustiveF64BinaryTest, Sub) { +BINARY_TEST_FLOAT_64(Sub, { auto host_sub = [](double x, double y) { return x - y; }; Run(AddEmptyBroadcastDimension(Sub), host_sub); -} +}) // TODO(bixia): Need to investigate the failure on CPU and file bugs. -XLA_TEST_P(ExhaustiveF64BinaryTest, DISABLED_ON_CPU(Mul)) { +BINARY_TEST_FLOAT_64(DISABLED_ON_CPU(Mul), { auto host_mul = [](double x, double y) { return x * y; }; Run(AddEmptyBroadcastDimension(Mul), host_mul); -} +}) // TODO(bixia): Need to investigate the failure on CPU and file bugs. -XLA_TEST_P(ExhaustiveF64BinaryTest, DISABLED_ON_CPU(Div)) { +BINARY_TEST_FLOAT_64(DISABLED_ON_CPU(Div), { auto host_div = [](double x, double y) { return x / y; }; Run(AddEmptyBroadcastDimension(Div), host_div); -} +}) -XLA_TEST_P(ExhaustiveF64BinaryTest, Max) { +BINARY_TEST_FLOAT_64(Max, { Run(AddEmptyBroadcastDimension(Max), ReferenceMax); -} +}) -XLA_TEST_P(ExhaustiveF64BinaryTest, Min) { +BINARY_TEST_FLOAT_64(Min, { Run(AddEmptyBroadcastDimension(Min), ReferenceMin); -} +}) // TODO(bixia): Need to investigate the failure on CPU and file bugs. -XLA_TEST_P(ExhaustiveF64BinaryTest, DISABLED_ON_CPU(AbsComplex)) { +BINARY_TEST_FLOAT_64(DISABLED_ON_CPU(AbsComplex), { auto host_abs_complex = [](double x, double y) { return std::abs(std::complex(x, y)); }; auto device_abs_complex = [](XlaOp x, XlaOp y) { return Abs(Complex(x, y)); }; Run(device_abs_complex, host_abs_complex); -} +}) -#if defined(BINARY_TEST_TARGET_F64) - -#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) INSTANTIATE_TEST_SUITE_P( SpecialValues, ExhaustiveF64BinaryTest, ::testing::Combine( @@ -385,8 +408,6 @@ INSTANTIATE_TEST_SUITE_P( GetFpValuesForMagnitudeExtremeNormals(40000, 2000)), ::testing::ValuesIn( GetFpValuesForMagnitudeExtremeNormals(40000, 2000)))); -#endif -#endif } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc b/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc index 3a14bb2d4cc..0a8fd82dd0c 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc @@ -13,7 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h" +#include "tensorflow/compiler/xla/util.h" #ifdef __FAST_MATH__ #error "Can't be compiled with fast math on" @@ -211,15 +213,54 @@ typedef Exhaustive32BitOrLessUnaryTest ExhaustiveF32UnaryTest; typedef Exhaustive32BitOrLessUnaryTest ExhaustiveF16UnaryTest; typedef Exhaustive32BitOrLessUnaryTest ExhaustiveBF16UnaryTest; -#define XLA_TEST_FLOAT_32_BITS_OR_LESS(test_name, ...) \ - XLA_TEST_P(ExhaustiveF32UnaryTest, test_name) \ - __VA_ARGS__ \ - XLA_TEST_P(ExhaustiveF16UnaryTest, test_name) \ - __VA_ARGS__ \ - XLA_TEST_P(ExhaustiveBF16UnaryTest, test_name) \ - __VA_ARGS__ +#if defined(UNARY_TEST_TARGET_F32_OR_SMALLER) +#define NEED_UNARY_F32 true +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) +#define NEED_UNARY_F16 true +#else +#define NEED_UNARY_F16 false +#endif +#if defined(XLA_BACKEND_SUPPORTS_BFLOAT16) +#define NEED_UNARY_BF16 true +#else +#define NEED_UNARY_BF16 false +#endif +#else +#define NEED_UNARY_F32 false +#define NEED_UNARY_F16 false +#define NEED_UNARY_BF16 false +#endif -XLA_TEST_FLOAT_32_BITS_OR_LESS(Log, { +#if NEED_UNARY_F32 +#define UNARY_TEST_F32(test_name, ...) \ + XLA_TEST_P(ExhaustiveF32UnaryTest, test_name) \ + __VA_ARGS__ +#else +#define UNARY_TEST_F32(test_name, ...) +#endif + +#if NEED_UNARY_F16 +#define UNARY_TEST_F16(test_name, ...) \ + XLA_TEST_P(ExhaustiveF16UnaryTest, test_name) \ + __VA_ARGS__ +#else +#define UNARY_TEST_F16(test_name, ...) +#endif + +#if NEED_UNARY_BF16 +#define UNARY_TEST_BF16(test_name, ...) \ + XLA_TEST_P(ExhaustiveBF16UnaryTest, test_name) \ + __VA_ARGS__ +#else +#define UNARY_TEST_BF16(test_name, ...) +#endif + +#define UNARY_TEST_FLOAT_32_BITS_OR_LESS(test_name, ...) \ + UNARY_TEST_F32(test_name, __VA_ARGS__) \ + UNARY_TEST_F16(test_name, __VA_ARGS__) \ + UNARY_TEST_BF16(test_name, __VA_ARGS__) + +UNARY_TEST_FLOAT_32_BITS_OR_LESS(Log, { ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator(); if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) { error_spec_gen = +[](NativeT x) { return ErrorSpec{0.001, 0.001}; }; @@ -227,7 +268,7 @@ XLA_TEST_FLOAT_32_BITS_OR_LESS(Log, { Run(Log, std::log, error_spec_gen); }) -XLA_TEST_FLOAT_32_BITS_OR_LESS(Log1p, { +UNARY_TEST_FLOAT_32_BITS_OR_LESS(Log1p, { ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator(); if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) { error_spec_gen = +[](NativeT x) { return ErrorSpec{0.001, 0.001}; }; @@ -235,7 +276,7 @@ XLA_TEST_FLOAT_32_BITS_OR_LESS(Log1p, { Run(Log1p, std::log1p, error_spec_gen); }) -XLA_TEST_FLOAT_32_BITS_OR_LESS(Exp, { +UNARY_TEST_FLOAT_32_BITS_OR_LESS(Exp, { // When x < -105, the true value of exp(x) is smaller than the smallest F32, // so exp(x) should return exactly 0. 
We want our implementation of exp to // return exactly 0 as well, as not doing so implies either that our @@ -266,7 +307,7 @@ XLA_TEST_FLOAT_32_BITS_OR_LESS(Exp, { } }) -XLA_TEST_FLOAT_32_BITS_OR_LESS(Expm1, { +UNARY_TEST_FLOAT_32_BITS_OR_LESS(Expm1, { ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator(); if (ty_ == F32) { error_spec_gen = +[](NativeT x) { return ErrorSpec{0, 0.00015}; }; @@ -292,7 +333,7 @@ XLA_TEST_FLOAT_32_BITS_OR_LESS(Expm1, { // It feels a little overkill to exhaustively test sqrt and pow(x, 0.5), but // this *did* find a bug, namely that some backends were assuming sqrt(x) == // pow(x, 0.5), but this is not true for x == -inf. -XLA_TEST_FLOAT_32_BITS_OR_LESS(PowOneHalf, { +UNARY_TEST_FLOAT_32_BITS_OR_LESS(PowOneHalf, { EvaluateOp fn = +[](float x) { return std::pow(x, 0.5f); }; // TODO(b/123837116): Enable the test for all values after fixing the bug. if (platform_ != "Host" && platform_ != "CUDA") { @@ -306,12 +347,12 @@ XLA_TEST_FLOAT_32_BITS_OR_LESS(PowOneHalf, { Run([](XlaOp x) { return Pow(x, ScalarLike(x, 0.5)); }, fn); }) -XLA_TEST_FLOAT_32_BITS_OR_LESS(Rsqrt, { +UNARY_TEST_FLOAT_32_BITS_OR_LESS(Rsqrt, { Run( Rsqrt, +[](float x) { return 1 / std::sqrt(x); }); }) -XLA_TEST_FLOAT_32_BITS_OR_LESS(Sqrt, { +UNARY_TEST_FLOAT_32_BITS_OR_LESS(Sqrt, { ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator(); if (platform_ == "Host" || platform_ == "CUDA") { error_spec_gen = +[](NativeT x) { @@ -349,11 +390,11 @@ XLA_TEST_P(ExhaustiveF32UnaryTest, Asinh) { XLA_TEST_P(ExhaustiveF16UnaryTest, Asinh) { Run(Asinh, std::asinh); } XLA_TEST_P(ExhaustiveBF16UnaryTest, Asinh) { Run(Asinh, std::asinh); } -XLA_TEST_FLOAT_32_BITS_OR_LESS(Atanh, { Run(Atanh, std::atanh); }) -XLA_TEST_FLOAT_32_BITS_OR_LESS(Acos, { Run(Acos, std::acos); }) -XLA_TEST_FLOAT_32_BITS_OR_LESS(Asin, { Run(Asin, std::asin); }) +UNARY_TEST_FLOAT_32_BITS_OR_LESS(Atanh, { Run(Atanh, std::atanh); }) +UNARY_TEST_FLOAT_32_BITS_OR_LESS(Acos, { Run(Acos, std::acos); }) +UNARY_TEST_FLOAT_32_BITS_OR_LESS(Asin, { Run(Asin, std::asin); }) -XLA_TEST_FLOAT_32_BITS_OR_LESS(Cosh, { +UNARY_TEST_FLOAT_32_BITS_OR_LESS(Cosh, { // Our cosh implementation incorrectly overflows to inf for +/-89.4159851. // The correct answer of 3.40281961e+38 (0x7f7fffec) is very close to // max-float, so we deem this acceptable. @@ -374,7 +415,7 @@ XLA_TEST_FLOAT_32_BITS_OR_LESS(Cosh, { Run(Cosh, host_cosh); }) -XLA_TEST_FLOAT_32_BITS_OR_LESS(Sinh, { +UNARY_TEST_FLOAT_32_BITS_OR_LESS(Sinh, { // Our sinh implementation incorrectly overflows to +/-inf for +/-89.4159851. // The correct answer of 3.40281961e+38 (0x7f7fffec) is very close to // max-float, so we deem this acceptable. 
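// Illustrative note, not part of this change: the NEED_UNARY_* / UNARY_TEST_*
// macros introduced above gate each exhaustive test on the build target and on
// backend support, so every test either expands to an ordinary parameterized
// test or to nothing at all. A minimal sketch of the expansion, using only
// names defined earlier in this file:
//
//   UNARY_TEST_FLOAT_32_BITS_OR_LESS(Round, { Run(Round, std::round); })
//   // With NEED_UNARY_F32 true, the F32 case becomes:
//   //   XLA_TEST_P(ExhaustiveF32UnaryTest, Round) { Run(Round, std::round); }
//   // With NEED_UNARY_F32 false, it expands to nothing, so the test is
//   //   neither compiled nor registered for that type.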
@@ -395,7 +436,7 @@ XLA_TEST_FLOAT_32_BITS_OR_LESS(Sinh, { Run(Sinh, host_sinh); }) -XLA_TEST_FLOAT_32_BITS_OR_LESS(Tanh, { +UNARY_TEST_FLOAT_32_BITS_OR_LESS(Tanh, { ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator(); if (platform_ == "CUDA") { error_spec_gen = +[](NativeT x) { @@ -429,62 +470,68 @@ void Exhaustive32BitOrLessUnaryTest::SetParamsForSinCosTan() { } } -XLA_TEST_P(ExhaustiveF32UnaryTest, Cos) { +UNARY_TEST_F32(Cos, { SetParamsForSinCosTan(); Run( Cos, std::cos, +[](NativeT) { return ErrorSpec{0.001, 0.001}; }); -} -XLA_TEST_P(ExhaustiveF16UnaryTest, Cos) { - SetParamsForSinCosTan(); - Run(Cos, std::cos); -} -XLA_TEST_P(ExhaustiveBF16UnaryTest, Cos) { - SetParamsForSinCosTan(); - Run(Cos, std::cos); -} +}) -XLA_TEST_P(ExhaustiveF32UnaryTest, Sin) { +UNARY_TEST_F16(Cos, { + SetParamsForSinCosTan(); + Run(Cos, std::cos); +}) + +UNARY_TEST_BF16(Cos, { + SetParamsForSinCosTan(); + Run(Cos, std::cos); +}) + +UNARY_TEST_F32(Sin, { SetParamsForSinCosTan(); Run( Sin, std::sin, +[](NativeT) { return ErrorSpec{0.001, 0.001}; }); -} -XLA_TEST_P(ExhaustiveF16UnaryTest, Sin) { - SetParamsForSinCosTan(); - Run(Sin, std::sin); -} -XLA_TEST_P(ExhaustiveBF16UnaryTest, Sin) { - SetParamsForSinCosTan(); - Run(Sin, std::sin); -} +}) -XLA_TEST_P(ExhaustiveF32UnaryTest, Tan) { +UNARY_TEST_F16(Sin, { + SetParamsForSinCosTan(); + Run(Sin, std::sin); +}) + +UNARY_TEST_BF16(Sin, { + SetParamsForSinCosTan(); + Run(Sin, std::sin); +}) + +UNARY_TEST_F32(Tan, { SetParamsForSinCosTan(); Run( Tan, std::tan, +[](NativeT) { return ErrorSpec{0.001, 0.001}; }); -} -XLA_TEST_P(ExhaustiveF16UnaryTest, Tan) { +}) + +UNARY_TEST_F16(Tan, { SetParamsForSinCosTan(); Run(Tan, std::tan); -} -XLA_TEST_P(ExhaustiveBF16UnaryTest, Tan) { +}) + +UNARY_TEST_BF16(Tan, { SetParamsForSinCosTan(); Run(Tan, std::tan); -} +}) // TODO(jlebar): Enable these. -// XLA_TEST_FLOAT_32_BITS_OR_LESS(Atan) { Run(Atan, std::atan); } -// XLA_TEST_FLOAT_32_BITS_OR_LESS(Atan2) { Run(Atan2, std::atan2); } +// UNARY_TEST_FLOAT_32_BITS_OR_LESS(Atan) { Run(Atan, std::atan); } +// UNARY_TEST_FLOAT_32_BITS_OR_LESS(Atan2) { Run(Atan2, std::atan2); } -XLA_TEST_FLOAT_32_BITS_OR_LESS(Erf, { Run(Erf, std::erf); }) -XLA_TEST_FLOAT_32_BITS_OR_LESS(Erfc, { Run(Erfc, std::erfc); }) -XLA_TEST_FLOAT_32_BITS_OR_LESS(ErfInv, { Run(ErfInv, HostErfInv); }) -XLA_TEST_FLOAT_32_BITS_OR_LESS(Digamma, { +UNARY_TEST_FLOAT_32_BITS_OR_LESS(Erf, { Run(Erf, std::erf); }) +UNARY_TEST_FLOAT_32_BITS_OR_LESS(Erfc, { Run(Erfc, std::erfc); }) +UNARY_TEST_FLOAT_32_BITS_OR_LESS(ErfInv, { Run(ErfInv, HostErfInv); }) +UNARY_TEST_FLOAT_32_BITS_OR_LESS(Digamma, { ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator(); if (platform_ != "Host" && platform_ != "CUDA") { // TODO(b/123956399): This is a fairly high error, significantly higher than @@ -514,7 +561,7 @@ XLA_TEST_FLOAT_32_BITS_OR_LESS(Digamma, { } }) -XLA_TEST_FLOAT_32_BITS_OR_LESS(Lgamma, { +UNARY_TEST_FLOAT_32_BITS_OR_LESS(Lgamma, { // Our implementation gets within 0.0001 rel error except for ~20 denormal // inputs on GPU. Anyway 0.001 rel error should be good enough for lgamma. 
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator(); @@ -545,9 +592,7 @@ XLA_TEST_FLOAT_32_BITS_OR_LESS(Lgamma, { Run(Lgamma, host_lgamma, error_spec_gen); }) -XLA_TEST_FLOAT_32_BITS_OR_LESS(Round, { Run(Round, std::round); }) - -#if defined(UNARY_TEST_TARGET_F32_OR_SMALLER) +UNARY_TEST_FLOAT_32_BITS_OR_LESS(Round, { Run(Round, std::round); }) INSTANTIATE_TEST_SUITE_P(F32, ExhaustiveF32UnaryTest, ::testing::ValuesIn(CreateExhaustiveF32Ranges())); @@ -562,8 +607,6 @@ INSTANTIATE_TEST_SUITE_P(BF16, ExhaustiveBF16UnaryTest, ::testing::Values(std::make_pair(0, 1 << 16))); #endif -#endif - // Exhaustive test for unary operations for double. // // Test parameter is a tuple containing @@ -594,42 +637,51 @@ class ExhaustiveF64UnaryTest : public ExhaustiveUnaryTest, } }; -XLA_TEST_P(ExhaustiveF64UnaryTest, Log) { Run(Log, std::log); } +#if defined(UNARY_TEST_TARGET_F64) && \ + !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) +#define UNARY_TEST_FLOAT_64(test_name, ...) \ + XLA_TEST_P(ExhaustiveF64UnaryTest, test_name) \ + __VA_ARGS__ +#else +#define UNARY_TEST_FLOAT_64(test_name, ...) +#endif -XLA_TEST_P(ExhaustiveF64UnaryTest, Log1p) { Run(Log1p, std::log1p); } +UNARY_TEST_FLOAT_64(Log, { Run(Log, std::log); }) -XLA_TEST_P(ExhaustiveF64UnaryTest, Exp) { Run(Exp, std::exp); } +UNARY_TEST_FLOAT_64(Log1p, { Run(Log1p, std::log1p); }) -XLA_TEST_P(ExhaustiveF64UnaryTest, Expm1) { Run(Expm1, std::expm1); } +UNARY_TEST_FLOAT_64(Exp, { Run(Exp, std::exp); }) + +UNARY_TEST_FLOAT_64(Expm1, { Run(Expm1, std::expm1); }) // TODO(b/138385863): Turn on the test for GPU after fixing the bug. -XLA_TEST_P(ExhaustiveF64UnaryTest, DISABLED_ON_GPU(PowOneHalf)) { +UNARY_TEST_FLOAT_64(DISABLED_ON_GPU(PowOneHalf), { Run([](XlaOp x) { return Pow(x, ScalarLike(x, 0.5)); }, +[](double x) { return std::pow(x, 0.5); }); -} +}) -XLA_TEST_P(ExhaustiveF64UnaryTest, Rsqrt) { +UNARY_TEST_FLOAT_64(Rsqrt, { Run( Rsqrt, +[](double x) { return 1 / std::sqrt(x); }); -} +}) -XLA_TEST_P(ExhaustiveF64UnaryTest, Sqrt) { Run(Sqrt, std::sqrt); } +UNARY_TEST_FLOAT_64(Sqrt, { Run(Sqrt, std::sqrt); }) -XLA_TEST_P(ExhaustiveF64UnaryTest, Acosh) { Run(Acosh, std::acosh); } +UNARY_TEST_FLOAT_64(Acosh, { Run(Acosh, std::acosh); }) -XLA_TEST_P(ExhaustiveF64UnaryTest, Asinh) { Run(Asinh, std::asinh); } +UNARY_TEST_FLOAT_64(Asinh, { Run(Asinh, std::asinh); }) -XLA_TEST_P(ExhaustiveF64UnaryTest, Atanh) { Run(Atanh, std::atanh); } +UNARY_TEST_FLOAT_64(Atanh, { Run(Atanh, std::atanh); }) -XLA_TEST_P(ExhaustiveF64UnaryTest, Acos) { Run(Acos, std::acos); } +UNARY_TEST_FLOAT_64(Acos, { Run(Acos, std::acos); }) -XLA_TEST_P(ExhaustiveF64UnaryTest, Asin) { Run(Asin, std::asin); } +UNARY_TEST_FLOAT_64(Asin, { Run(Asin, std::asin); }) -XLA_TEST_P(ExhaustiveF64UnaryTest, Cosh) { Run(Cosh, std::cosh); } +UNARY_TEST_FLOAT_64(Cosh, { Run(Cosh, std::cosh); }) -XLA_TEST_P(ExhaustiveF64UnaryTest, Sinh) { Run(Sinh, std::sinh); } +UNARY_TEST_FLOAT_64(Sinh, { Run(Sinh, std::sinh); }) -XLA_TEST_P(ExhaustiveF64UnaryTest, Tanh) { +UNARY_TEST_FLOAT_64(Tanh, { ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator(); if (platform_ == "CUDA") { error_spec_gen = +[](NativeT x) { @@ -639,26 +691,24 @@ XLA_TEST_P(ExhaustiveF64UnaryTest, Tanh) { }; } Run(Tanh, std::tanh, error_spec_gen); -} +}) -XLA_TEST_P(ExhaustiveF64UnaryTest, Cos) { Run(Cos, std::cos); } +UNARY_TEST_FLOAT_64(Cos, { Run(Cos, std::cos); }) -XLA_TEST_P(ExhaustiveF64UnaryTest, Sin) { Run(Sin, std::sin); } +UNARY_TEST_FLOAT_64(Sin, { Run(Sin, std::sin); }) -XLA_TEST_P(ExhaustiveF64UnaryTest, Tan) { Run(Tan, 
std::tan); } +UNARY_TEST_FLOAT_64(Tan, { Run(Tan, std::tan); }) -XLA_TEST_P(ExhaustiveF64UnaryTest, Round) { Run(Round, std::round); } +UNARY_TEST_FLOAT_64(Round, { Run(Round, std::round); }) -XLA_TEST_P(ExhaustiveF64UnaryTest, Erf) { +UNARY_TEST_FLOAT_64(Erf, { Run(Erf, std::erf, [](NativeT x) { return ErrorSpec{1e-20, 1e-20}; }); -} +}) -XLA_TEST_P(ExhaustiveF64UnaryTest, Erfc) { +UNARY_TEST_FLOAT_64(Erfc, { Run(Erfc, std::erfc, [](NativeT x) { return ErrorSpec{1e-20, 1e-20}; }); -} +}) -#if defined(UNARY_TEST_TARGET_F64) -#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) INSTANTIATE_TEST_SUITE_P( SpecialValues, ExhaustiveF64UnaryTest, ::testing::ValuesIn(CreateFpValuesForBoundaryTest())); @@ -672,8 +722,6 @@ INSTANTIATE_TEST_SUITE_P( LargeAndSmallMagnituedNormalValues, ExhaustiveF64UnaryTest, ::testing::ValuesIn(GetFpValuesForMagnitudeExtremeNormals( 4000000000ull, 16000000))); -#endif -#endif // T is the Primitive Type of the complex number // Test parameter is a tuple containing @@ -741,30 +789,38 @@ class ExhaustiveComplexUnaryTestBase typedef ExhaustiveComplexUnaryTestBase ExhaustiveC64UnaryTest; typedef ExhaustiveComplexUnaryTestBase ExhaustiveC128UnaryTest; -// TODO(b/138578594): Enable the test for the CPU backend after fixing the bug. -XLA_TEST_P(ExhaustiveC64UnaryTest, DISABLED_ON_CPU(Log)) { - Run(Log, [](complex64 x) { return std::log(x); }); -} +#if defined(UNARY_TEST_TARGET_COMPLEX) +#define UNARY_TEST_COMPLEX_64(test_name, ...) \ + XLA_TEST_P(ExhaustiveC64UnaryTest, test_name) \ + __VA_ARGS__ +#else +#define UNARY_TEST_COMPLEX_64(test_name, ...) +#endif -XLA_TEST_P(ExhaustiveC64UnaryTest, Sqrt) { +// TODO(b/138578594): Enable the test for the CPU backend after fixing the bug. +UNARY_TEST_COMPLEX_64(DISABLED_ON_CPU(Log), { + Run(Log, [](complex64 x) { return std::log(x); }); +}) + +UNARY_TEST_COMPLEX_64(Sqrt, { Run(Sqrt, [](complex64 x) { return static_cast( std::sqrt(static_cast(x))); }); -} +}) -XLA_TEST_P(ExhaustiveC64UnaryTest, Rsqrt) { +UNARY_TEST_COMPLEX_64(Rsqrt, { Run(Rsqrt, [](complex64 x) { return static_cast( complex128(1, 0) / std::sqrt(static_cast(x))); }); -} +}) // The current libc++ implementation of the complex tanh function provides // less accurate results when the denomenator of a complex tanh is small, due // to floating point precision loss. To avoid this issue for complex64 numbers, // we cast it to and from a complex128 when computing tanh. -XLA_TEST_P(ExhaustiveC64UnaryTest, Tanh) { +UNARY_TEST_COMPLEX_64(Tanh, { SetParamsForTanh(); ErrorSpecGen error_spec_gen = +[](complex64 x) { // This implementation of Tanh becomes less accurate when the denominator @@ -781,9 +837,8 @@ XLA_TEST_P(ExhaustiveC64UnaryTest, Tanh) { return static_cast(std::tanh(static_cast(x))); }, error_spec_gen); -} +}) -#if defined(UNARY_TEST_TARGET_COMPLEX) INSTANTIATE_TEST_SUITE_P( F32SpecialValues, ExhaustiveC64UnaryTest, ::testing::Combine( @@ -816,10 +871,17 @@ INSTANTIATE_TEST_SUITE_P( 4000)), ::testing::ValuesIn( GetFpValuesForMagnitudeExtremeNormals(40000, 4000)))); + +#if defined(UNARY_TEST_TARGET_COMPLEX) && \ + !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) +#define UNARY_TEST_COMPLEX_128(test_name, ...) \ + XLA_TEST_P(ExhaustiveC128UnaryTest, test_name) \ + __VA_ARGS__ +#else +#define UNARY_TEST_COMPLEX_128(test_name, ...) #endif - -XLA_TEST_P(ExhaustiveC128UnaryTest, Log) { +UNARY_TEST_COMPLEX_128(Log, { // TODO(b/138578313): Enable the test for all values after fixing the bug. 
known_incorrect_fn_ = [&](int64 v) { double f = this->ConvertValue(v); @@ -827,18 +889,18 @@ XLA_TEST_P(ExhaustiveC128UnaryTest, Log) { std::abs(f) < 1.0e-300; }; Run(Log, [](complex128 x) { return std::log(x); }); -} +}) -XLA_TEST_P(ExhaustiveC128UnaryTest, Sqrt) { +UNARY_TEST_COMPLEX_128(Sqrt, { // Similar to the Tanh bug. known_incorrect_fn_ = [&](int64 v) { double f = this->ConvertValue(v); return std::abs(f) > std::numeric_limits::max() / 2; }; Run(Sqrt, [](complex128 x) { return std::sqrt(x); }); -} +}) -XLA_TEST_P(ExhaustiveC128UnaryTest, Rsqrt) { +UNARY_TEST_COMPLEX_128(Rsqrt, { ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator(); if (platform_ == "CUDA") { // Edge case on CUDA backend where the Log of a complex number made up of @@ -856,16 +918,14 @@ XLA_TEST_P(ExhaustiveC128UnaryTest, Rsqrt) { Rsqrt, [](complex128 x) { return complex128(1, 0) / std::sqrt(x); }, error_spec_gen); -} +}) -XLA_TEST_P(ExhaustiveC128UnaryTest, Tanh) { +UNARY_TEST_COMPLEX_128(Tanh, { SetParamsForTanh(); Run( Tanh, +[](complex128 x) { return std::tanh(x); }); -} +}) -#if defined(UNARY_TEST_TARGET_COMPLEX) -#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) INSTANTIATE_TEST_SUITE_P( SpecialValues, ExhaustiveC128UnaryTest, ::testing::Combine( @@ -898,7 +958,5 @@ INSTANTIATE_TEST_SUITE_P( GetFpValuesForMagnitudeExtremeNormals(40000, 2000)), ::testing::ValuesIn( GetFpValuesForMagnitudeExtremeNormals(40000, 2000)))); -#endif -#endif } // namespace xla diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index 3fb69419e73..c04c4ec3e9d 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -259,17 +259,31 @@ XLA_TEST_P(SliceR1Test, DoIt_U64) { Run(GetParam()); } XLA_TEST_P(SliceR1Test, DoIt_S64) { Run(GetParam()); } -XLA_TEST_P(SliceR1LargeTest, DoIt_F32) { Run(GetParam()); } +// TODO(b/69425338): The following tests are disable on GPU because they use +// too much GPU memory. +XLA_TEST_P(SliceR1LargeTest, DISABLED_ON_GPU(DoIt_F32)) { + Run(GetParam()); +} -XLA_TEST_P(SliceR1LargeTest, DoIt_F64) { Run(GetParam()); } +XLA_TEST_P(SliceR1LargeTest, DISABLED_ON_GPU(DoIt_F64)) { + Run(GetParam()); +} -XLA_TEST_P(SliceR1LargeTest, DoIt_U32) { Run(GetParam()); } +XLA_TEST_P(SliceR1LargeTest, DISABLED_ON_GPU(DoIt_U32)) { + Run(GetParam()); +} -XLA_TEST_P(SliceR1LargeTest, DoIt_S32) { Run(GetParam()); } +XLA_TEST_P(SliceR1LargeTest, DISABLED_ON_GPU(DoIt_S32)) { + Run(GetParam()); +} -XLA_TEST_P(SliceR1LargeTest, DoIt_U64) { Run(GetParam()); } +XLA_TEST_P(SliceR1LargeTest, DISABLED_ON_GPU(DoIt_U64)) { + Run(GetParam()); +} -XLA_TEST_P(SliceR1LargeTest, DoIt_S64) { Run(GetParam()); } +XLA_TEST_P(SliceR1LargeTest, DISABLED_ON_GPU(DoIt_S64)) { + Run(GetParam()); +} XLA_TEST_P(SliceR1Test, DoIt_PRED) { Run(GetParam()); } @@ -315,8 +329,6 @@ INSTANTIATE_TEST_CASE_P( SliceR1TestDataToString ); -// TODO(b/69425338): This uses too much memory on GPU. 
-#ifndef XLA_TEST_BACKEND_GPU INSTANTIATE_TEST_CASE_P( SliceR1TestBigSlicesInstantiation, SliceR1LargeTest, @@ -330,7 +342,6 @@ INSTANTIATE_TEST_CASE_P( ), SliceR1TestDataToString ); -#endif INSTANTIATE_TEST_CASE_P( SliceStridedR1TestInstantiation, diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 8c11077d549..a3bc092ac83 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -211,7 +211,7 @@ message DebugOptions { int32 xla_force_host_platform_device_count = 102; // If set to true XLA:GPU invokes `ptxas` with -O0 (default is -O3). - bool xla_gpu_disable_ptxas_optimizations = 103; + bool xla_gpu_disable_gpuasm_optimizations = 103; // Enable fast math with eigen in the HLO evaluator. bool xla_hlo_evaluator_use_fast_path = 106; diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 87c920efa2b..c6ebd8594e9 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -507,7 +507,6 @@ cuda_py_test( "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", - "//tensorflow/python:spectral_ops_test_util", "//tensorflow/python:math_ops", "//tensorflow/python:nn_ops", "//tensorflow/python:platform_test", diff --git a/tensorflow/contrib/distributions/python/kernel_tests/sample_stats_test.py b/tensorflow/contrib/distributions/python/kernel_tests/sample_stats_test.py index d6020e78667..f2be3bdb656 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/sample_stats_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/sample_stats_test.py @@ -24,7 +24,6 @@ from tensorflow.contrib.distributions.python.ops import sample_stats from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import spectral_ops_test_util from tensorflow.python.platform import test rng = np.random.RandomState(0) @@ -46,17 +45,16 @@ class _AutoCorrelationTest(object): x_ph = array_ops.placeholder_with_default( input=x_, shape=x_.shape if self.use_static_shape else None) - with spectral_ops_test_util.fft_kernel_label_map(): - with self.cached_session() as sess: - # Setting normalize = True means we divide by zero. - auto_corr = sample_stats.auto_correlation( - x_ph, axis=1, center=False, normalize=False) - if self.use_static_shape: - self.assertEqual((2, 3), auto_corr.shape) - auto_corr_ = sess.run(auto_corr) - self.assertAllClose( - [[0., 0., 0.], - [1., 1., 1.]], auto_corr_) + with self.cached_session() as sess: + # Setting normalize = True means we divide by zero. + auto_corr = sample_stats.auto_correlation( + x_ph, axis=1, center=False, normalize=False) + if self.use_static_shape: + self.assertEqual((2, 3), auto_corr.shape) + auto_corr_ = sess.run(auto_corr) + self.assertAllClose( + [[0., 0., 0.], + [1., 1., 1.]], auto_corr_) def test_constant_sequence_axis_0_max_lags_none_center_true(self): x_ = np.array([[0., 0., 0.], @@ -64,17 +62,16 @@ class _AutoCorrelationTest(object): x_ph = array_ops.placeholder_with_default( input=x_, shape=x_.shape if self.use_static_shape else None) - with spectral_ops_test_util.fft_kernel_label_map(): - with self.cached_session() as sess: - # Setting normalize = True means we divide by zero. 
- auto_corr = sample_stats.auto_correlation( - x_ph, axis=1, normalize=False, center=True) - if self.use_static_shape: - self.assertEqual((2, 3), auto_corr.shape) - auto_corr_ = sess.run(auto_corr) - self.assertAllClose( - [[0., 0., 0.], - [0., 0., 0.]], auto_corr_) + with self.cached_session() as sess: + # Setting normalize = True means we divide by zero. + auto_corr = sample_stats.auto_correlation( + x_ph, axis=1, normalize=False, center=True) + if self.use_static_shape: + self.assertEqual((2, 3), auto_corr.shape) + auto_corr_ = sess.run(auto_corr) + self.assertAllClose( + [[0., 0., 0.], + [0., 0., 0.]], auto_corr_) def check_results_versus_brute_force( self, x, axis, max_lags, center, normalize): @@ -99,16 +96,15 @@ class _AutoCorrelationTest(object): x_ph = array_ops.placeholder_with_default( x, shape=x.shape if self.use_static_shape else None) - with spectral_ops_test_util.fft_kernel_label_map(): - with self.cached_session(): - auto_corr = sample_stats.auto_correlation( - x_ph, axis=axis, max_lags=max_lags, center=center, - normalize=normalize) - if self.use_static_shape: - output_shape = list(x.shape) - output_shape[axis] = max_lags + 1 - self.assertAllEqual(output_shape, auto_corr.shape) - self.assertAllClose(rxx, auto_corr.eval(), rtol=1e-5, atol=1e-5) + with self.cached_session(): + auto_corr = sample_stats.auto_correlation( + x_ph, axis=axis, max_lags=max_lags, center=center, + normalize=normalize) + if self.use_static_shape: + output_shape = list(x.shape) + output_shape[axis] = max_lags + 1 + self.assertAllEqual(output_shape, auto_corr.shape) + self.assertAllClose(rxx, auto_corr.eval(), rtol=1e-5, atol=1e-5) def test_axis_n1_center_false_max_lags_none(self): x = rng.randn(2, 3, 4).astype(self.dtype) @@ -166,20 +162,18 @@ class _AutoCorrelationTest(object): x = rng.randn(l).astype(self.dtype) x_ph = array_ops.placeholder_with_default( x, shape=(l,) if self.use_static_shape else None) - with spectral_ops_test_util.fft_kernel_label_map(): - with self.cached_session(): - rxx = sample_stats.auto_correlation( - x_ph, max_lags=l // 2, center=True, normalize=False) - if self.use_static_shape: - self.assertAllEqual((l // 2 + 1,), rxx.shape) - rxx_ = rxx.eval() - # OSS CPU FFT has some accuracy issues is not the most accurate. - # So this tolerance is a bit bad. - self.assertAllClose(1., rxx_[0], rtol=0.05) - # The maximal error in the rest of the sequence is not great. - self.assertAllClose(np.zeros(l // 2), rxx_[1:], atol=0.1) - # The mean error in the rest is ok, actually 0.008 when I tested it. - self.assertLess(np.abs(rxx_[1:]).mean(), 0.02) + with self.cached_session(): + rxx = sample_stats.auto_correlation( + x_ph, max_lags=l // 2, center=True, normalize=False) + if self.use_static_shape: + self.assertAllEqual((l // 2 + 1,), rxx.shape) + rxx_ = rxx.eval() + # OSS CPU FFT has some accuracy issues, so this tolerance is a bit bad. + self.assertAllClose(1., rxx_[0], rtol=0.05) + # The maximal error in the rest of the sequence is not great. + self.assertAllClose(np.zeros(l // 2), rxx_[1:], atol=0.1) + # The mean error in the rest is ok, actually 0.008 when I tested it. + self.assertLess(np.abs(rxx_[1:]).mean(), 0.02) def test_step_function_sequence(self): # x jumps to new random value every 10 steps. So correlation length = 10. 
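# Illustrative note, not part of this change: a rough NumPy sketch of the
# statistic these tests exercise. With center=True and normalize=False,
# sample_stats.auto_correlation(x)[m] behaves like a lag-m autocovariance
# estimate; the exact scaling is defined by the op itself, so this brute-force
# version is only an approximation for intuition.

import numpy as np

def approx_auto_correlation(x, max_lags):
  # Average lag-m products of the centered sequence.
  xc = x - x.mean()
  n = len(xc)
  return np.array([(xc[:n - m] * xc[m:]).mean() for m in range(max_lags + 1)])

# For x drawn i.i.d. from N(0, 1), lag 0 comes out near 1 and the higher lags
# near 0, which is what the long random-sequence test above asserts.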
@@ -187,43 +181,40 @@ class _AutoCorrelationTest(object): * np.ones((1, 10))).ravel().astype(self.dtype) x_ph = array_ops.placeholder_with_default( x, shape=(1000 * 10,) if self.use_static_shape else None) - with spectral_ops_test_util.fft_kernel_label_map(): - with self.cached_session(): - rxx = sample_stats.auto_correlation( - x_ph, max_lags=1000 * 10 // 2, center=True, normalize=False) - if self.use_static_shape: - self.assertAllEqual((1000 * 10 // 2 + 1,), rxx.shape) - rxx_ = rxx.eval() - rxx_ /= rxx_[0] - # Expect positive correlation for the first 10 lags, then significantly - # smaller negative. - self.assertGreater(rxx_[:10].min(), 0) - self.assertGreater(rxx_[9], 5 * rxx_[10:20].mean()) - # RXX should be decreasing for the first 10 lags. - diff = np.diff(rxx_) - self.assertLess(diff[:10].max(), 0) + with self.cached_session(): + rxx = sample_stats.auto_correlation( + x_ph, max_lags=1000 * 10 // 2, center=True, normalize=False) + if self.use_static_shape: + self.assertAllEqual((1000 * 10 // 2 + 1,), rxx.shape) + rxx_ = rxx.eval() + rxx_ /= rxx_[0] + # Expect positive correlation for the first 10 lags, then significantly + # smaller negative. + self.assertGreater(rxx_[:10].min(), 0) + self.assertGreater(rxx_[9], 5 * rxx_[10:20].mean()) + # RXX should be decreasing for the first 10 lags. + diff = np.diff(rxx_) + self.assertLess(diff[:10].max(), 0) def test_normalization(self): l = 10000 x = 3 * rng.randn(l).astype(self.dtype) x_ph = array_ops.placeholder_with_default( x, shape=(l,) if self.use_static_shape else None) - with spectral_ops_test_util.fft_kernel_label_map(): - with self.cached_session(): - rxx = sample_stats.auto_correlation( - x_ph, max_lags=l // 2, center=True, normalize=True) - if self.use_static_shape: - self.assertAllEqual((l // 2 + 1,), rxx.shape) - rxx_ = rxx.eval() - # Note that RXX[0] = 1, despite the fact that E[X^2] = 9, and this is - # due to normalize=True. - # OSS CPU FFT has some accuracy issues is not the most accurate. - # So this tolerance is a bit bad. - self.assertAllClose(1., rxx_[0], rtol=0.05) - # The maximal error in the rest of the sequence is not great. - self.assertAllClose(np.zeros(l // 2), rxx_[1:], atol=0.1) - # The mean error in the rest is ok, actually 0.008 when I tested it. - self.assertLess(np.abs(rxx_[1:]).mean(), 0.02) + with self.cached_session(): + rxx = sample_stats.auto_correlation( + x_ph, max_lags=l // 2, center=True, normalize=True) + if self.use_static_shape: + self.assertAllEqual((l // 2 + 1,), rxx.shape) + rxx_ = rxx.eval() + # Note that RXX[0] = 1, despite the fact that E[X^2] = 9, and this is + # due to normalize=True. + # OSS CPU FFT has some accuracy issues, so this tolerance is a bit bad. + self.assertAllClose(1., rxx_[0], rtol=0.05) + # The maximal error in the rest of the sequence is not great. + self.assertAllClose(np.zeros(l // 2), rxx_[1:], atol=0.1) + # The mean error in the rest is ok, actually 0.008 when I tested it. 
+ self.assertLess(np.abs(rxx_[1:]).mean(), 0.02) class AutoCorrelationTestStaticShapeFloat32(test.TestCase, diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index 88bf3792c55..ef8c760ad96 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -29,7 +29,6 @@ py_library( "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", "//tensorflow/python/eager:def_function", - "//tensorflow/python/eager:execution_callbacks", "//tensorflow/python/eager:function", "//tensorflow/python/eager:remote", ], diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index 8080d954eb7..7a7cf712543 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -39,15 +39,6 @@ To use, at program startup, call `tf.compat.v1.enable_eager_execution()`. @@custom_gradient -@@add_execution_callback -@@clear_execution_callbacks -@@errstate -@@ExecutionCallback -@@inf_callback -@@inf_nan_callback -@@nan_callback -@@seterr - @@Iterator @@Saver @@restore_variables_on_create @@ -117,14 +108,6 @@ from tensorflow.python.eager.context import ASYNC from tensorflow.python.eager.context import num_gpus from tensorflow.python.eager.context import set_server_def from tensorflow.python.eager.def_function import function -from tensorflow.python.eager.execution_callbacks import add_execution_callback -from tensorflow.python.eager.execution_callbacks import clear_execution_callbacks -from tensorflow.python.eager.execution_callbacks import errstate -from tensorflow.python.eager.execution_callbacks import ExecutionCallback -from tensorflow.python.eager.execution_callbacks import inf_callback -from tensorflow.python.eager.execution_callbacks import inf_nan_callback -from tensorflow.python.eager.execution_callbacks import nan_callback -from tensorflow.python.eager.execution_callbacks import seterr from tensorflow.python.eager.remote import connect_to_remote_host from tensorflow.python.framework.tensor_spec import TensorSpec from tensorflow.python.framework.ops import enable_eager_execution diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_integration_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_integration_test.py index bcc25b8de89..d4f4d657975 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_integration_test.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_integration_test.py @@ -87,7 +87,7 @@ class SequenceFeatureColumnIntegrationTest(test.TestCase): ds = ds.batch(20) # Test on a single batch - features = ds.make_one_shot_iterator().get_next() + features = dataset_ops.make_one_shot_iterator(ds).get_next() # Tile the context features across the sequence features seq_layer, _ = sfc.sequence_input_layer(features, seq_cols) diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index d48edc027a2..e4ed2c7841a 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -126,7 +126,7 @@ py_test( py_test( name = "estimators_test", - size = "small", + size = "medium", srcs = ["python/learn/estimators/estimators_test.py"], python_version = "PY2", srcs_version = "PY2AND3", diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 9132b2209bc..c762227b20b 100644 --- 
a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -1088,12 +1088,22 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, chief_only_hooks=chief_hooks + model_fn_ops.training_chief_hooks, save_checkpoint_secs=0, # Saving is handled by a hook. save_summaries_steps=self._config.save_summary_steps, + max_wait_secs=self._config.session_creation_timeout_secs, config=self._session_config) as mon_sess: loss = None while not mon_sess.should_stop(): _, loss = mon_sess.run([model_fn_ops.train_op, model_fn_ops.loss]) return loss + def latest_checkpoint(self): + """Finds the filename of the latest saved checkpoint file in `model_dir`. + + Returns: + The full path to the latest checkpoint or `None` if no checkpoint was + found. + """ + return checkpoint_management.latest_checkpoint(self.model_dir) + def _identity_feature_engineering_fn(features, labels): return features, labels diff --git a/tensorflow/contrib/learn/python/learn/estimators/run_config.py b/tensorflow/contrib/learn/python/learn/estimators/run_config.py index b51ea30959e..e435fd65702 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/run_config.py +++ b/tensorflow/contrib/learn/python/learn/estimators/run_config.py @@ -243,7 +243,8 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig): protocol=None, evaluation_master='', model_dir=None, - session_config=None): + session_config=None, + session_creation_timeout_secs=7200): """Constructor. The superclass `ClusterConfig` may set properties like `cluster_spec`, @@ -282,6 +283,8 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig): the feature. log_step_count_steps: The frequency, in number of global steps, that the global step/sec will be logged during training. + protocol: An optional argument which specifies the protocol used when + starting server. None means default to grpc. evaluation_master: the master on which to perform evaluation. model_dir: directory where model parameters, graph etc are saved. If `None`, will use `model_dir` property in `TF_CONFIG` environment @@ -290,8 +293,11 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig): session_config: a ConfigProto used to set session parameters, or None. Note - using this argument, it is easy to provide settings which break otherwise perfectly good models. Use with care. - protocol: An optional argument which specifies the protocol used when - starting server. None means default to grpc. + session_creation_timeout_secs: Max time workers should wait for a session + to become available (on initialization or when recovering a session) + with MonitoredTrainingSession. Defaults to 7200 seconds, but users may + want to set a lower value to detect problems with variable / session + (re)-initialization more quickly. """ # Neither parent class calls super().__init__(), so here we have to # manually call their __init__() methods. 
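# Illustrative sketch, not part of this change: how the new knob is used. The
# import path matches the module patched here; the 600-second value is only an
# example.

from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig

# Give up after 10 minutes instead of the 7200-second default if the chief
# never creates (or re-creates) the shared session.
config = RunConfig(session_creation_timeout_secs=600)

# The timeout reaches MonitoredTrainingSession through the max_wait_secs
# argument added to BaseEstimator._train_model above, and the new
# BaseEstimator.latest_checkpoint() helper returns the newest checkpoint path
# in model_dir.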
@@ -332,6 +338,7 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig): self._keep_checkpoint_max = keep_checkpoint_max self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours self._model_dir = _get_model_dir(model_dir) + self._session_creation_timeout_secs = session_creation_timeout_secs @experimental def uid(self, whitelist=None): @@ -408,6 +415,10 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig): def log_step_count_steps(self): return self._log_step_count_steps + @property + def session_creation_timeout_secs(self): + return self._session_creation_timeout_secs + def _count_ps(cluster_spec): """Counts the number of parameter servers in cluster_spec.""" diff --git a/tensorflow/contrib/makefile/proto_text_cc_files.txt b/tensorflow/contrib/makefile/proto_text_cc_files.txt index b53f0588b56..72d489c3514 100644 --- a/tensorflow/contrib/makefile/proto_text_cc_files.txt +++ b/tensorflow/contrib/makefile/proto_text_cc_files.txt @@ -34,19 +34,19 @@ tensorflow/core/lib/strings/ordered_code.cc tensorflow/core/lib/strings/proto_text_util.cc tensorflow/core/lib/wav/wav_io.cc tensorflow/core/platform/cpu_info.cc +tensorflow/core/platform/default/env_time.cc +tensorflow/core/platform/default/load_library.cc tensorflow/core/platform/default/logging.cc tensorflow/core/platform/default/mutex.cc +tensorflow/core/platform/default/port.cc tensorflow/core/platform/default/tracing.cc tensorflow/core/platform/denormal.cc tensorflow/core/platform/env.cc +tensorflow/core/platform/error.cc tensorflow/core/platform/file_system.cc tensorflow/core/platform/file_system_helper.cc tensorflow/core/platform/numbers.cc tensorflow/core/platform/posix/env.cc -tensorflow/core/platform/posix/env_time.cc -tensorflow/core/platform/posix/error.cc -tensorflow/core/platform/posix/load_library.cc -tensorflow/core/platform/posix/port.cc tensorflow/core/platform/posix/posix_file_system.cc tensorflow/core/platform/protobuf.cc tensorflow/core/platform/protobuf_util.cc diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py b/tensorflow/contrib/seq2seq/python/ops/decoder.py index ab124959001..6543f09e40d 100644 --- a/tensorflow/contrib/seq2seq/python/ops/decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/decoder.py @@ -410,14 +410,27 @@ def dynamic_decode(decoder, """ (next_outputs, decoder_state, next_inputs, decoder_finished) = decoder.step(time, inputs, state) + decoder_state_sequence_lengths = False if decoder.tracks_own_finished: next_finished = decoder_finished + lengths = getattr(decoder_state, "lengths", None) + if lengths is not None: + # sequence lengths are provided by decoder_state.lengths; overwrite + # our sequence lengths. + decoder_state_sequence_lengths = True + sequence_lengths = math_ops.cast(lengths, dtypes.int32) else: next_finished = math_ops.logical_or(decoder_finished, finished) - next_sequence_lengths = array_ops.where( - math_ops.logical_not(finished), - array_ops.fill(array_ops.shape(sequence_lengths), time + 1), - sequence_lengths) + + if decoder_state_sequence_lengths: + # Just pass something through the loop; at the next iteration we'll pull + # the sequence lengths from the decoder_state again. 
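The new `RunConfig` property ultimately feeds the `max_wait_secs` argument of `MonitoredTrainingSession` (see the estimator hunk above). A stand-alone sketch of what that argument controls, assuming a single-process graph-mode program; the checkpoint directory is illustrative:

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()
step = tf.compat.v1.train.get_or_create_global_step()
train_op = tf.compat.v1.assign_add(step, 1)

# On a worker (is_chief=False) this call waits up to max_wait_secs for the
# chief to initialize a session before giving up, instead of the 7200 s
# default. With is_chief=True it simply creates the session locally.
with tf.compat.v1.train.MonitoredTrainingSession(
        is_chief=True,
        checkpoint_dir='/tmp/my_model',   # hypothetical directory
        max_wait_secs=600) as sess:
    sess.run(train_op)
```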
+ next_sequence_lengths = sequence_lengths + else: + next_sequence_lengths = array_ops.where( + math_ops.logical_not(finished), + array_ops.fill(array_ops.shape(sequence_lengths), time + 1), + sequence_lengths) nest.assert_same_structure(state, decoder_state) nest.assert_same_structure(outputs_ta, next_outputs) diff --git a/tensorflow/contrib/session_bundle/BUILD b/tensorflow/contrib/session_bundle/BUILD index 737d6866283..edfc6639b0f 100644 --- a/tensorflow/contrib/session_bundle/BUILD +++ b/tensorflow/contrib/session_bundle/BUILD @@ -303,8 +303,7 @@ cc_library( hdrs = ["signature.h"], deprecation = "No longer supported. Switch to SavedModel immediately.", visibility = ["//visibility:public"], - deps = [ - ] + if_not_mobile([ + deps = if_not_mobile([ ":manifest_proto_cc", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD index 017d08f5f60..1161f52cfbc 100644 --- a/tensorflow/contrib/training/BUILD +++ b/tensorflow/contrib/training/BUILD @@ -1,5 +1,6 @@ # Description: # contains parts of TensorFlow that are experimental or unstable and which are not supported. +# Placeholder for Google-internal load statements. load("//tensorflow/core/platform:default/build_config.bzl", "tf_proto_library") load("//tensorflow:tensorflow.bzl", "py_test") diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 75f6d942dba..2dd6240cdf4 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -106,11 +106,7 @@ load("//tensorflow:tensorflow.bzl", "tf_version_info_genrule") load( "//tensorflow/core/platform:default/build_config.bzl", "tf_additional_all_protos", - "tf_additional_cloud_kernel_deps", - "tf_additional_cloud_op_deps", "tf_additional_core_deps", - "tf_additional_device_tracer_cuda_deps", - "tf_additional_device_tracer_test_flags", "tf_additional_human_readable_json_deps", "tf_additional_lib_defines", "tf_additional_lib_deps", @@ -384,6 +380,7 @@ cc_library( name = "util_port", srcs = ["util/port.cc"], hdrs = ["util/port.h"], + copts = tf_copts(), visibility = [ "//tensorflow/core:__pkg__", "//tensorflow/python:__pkg__", @@ -726,20 +723,7 @@ cc_library( cc_library( name = "lib", hdrs = [ - "lib/hash/crc32c.h", - "lib/hash/hash.h", "lib/histogram/histogram.h", - "lib/io/buffered_inputstream.h", - "lib/io/compression.h", - "lib/io/inputstream_interface.h", - "lib/io/path.h", - "lib/io/proto_encode_helper.h", - "lib/io/random_inputstream.h", - "lib/io/record_reader.h", - "lib/io/record_writer.h", - "lib/io/table.h", - "lib/io/table_builder.h", - "lib/io/table_options.h", "lib/monitoring/collected_metrics.h", "lib/monitoring/collection_registry.h", "lib/monitoring/counter.h", @@ -755,6 +739,8 @@ cc_library( "//tensorflow/core/lib/bfloat16:bfloat16.h", "//tensorflow/core/lib/core:legacy_lib_core_headers", "//tensorflow/core/lib/gtl:legacy_lib_gtl_headers", + "//tensorflow/core/lib/hash:legacy_lib_hash_all_headers", + "//tensorflow/core/lib/io:legacy_lib_io_headers", "//tensorflow/core/lib/math:math_util.h", "//tensorflow/core/lib/random:legacy_lib_random_headers", "//tensorflow/core/lib/strings:legacy_lib_string_headers", @@ -1445,7 +1431,7 @@ cc_library( ]) + if_tensorrt([ "//tensorflow/compiler/tf2tensorrt:trt_engine_resource_ops_op_lib", "//tensorflow/compiler/tf2tensorrt:trt_op_libs", - ]) + tf_additional_cloud_op_deps(), + ]), alwayslink = 1, ) @@ -1609,7 +1595,7 @@ cc_library( "//tensorflow/core/kernels:training_ops", "//tensorflow/core/kernels:word2vec_kernels", 
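The `dynamic_decode` change keeps per-batch sequence lengths coming from `decoder_state.lengths` when the decoder tracks its own finished state, and otherwise falls back to the original `where`/`fill` update. A small, self-contained sketch of that fallback step, assuming eager execution; the tensors below are made-up example values, not library internals:

```python
import tensorflow as tf

finished = tf.constant([True, False, False])          # batch entries done so far
sequence_lengths = tf.constant([2, 0, 0], tf.int32)   # lengths recorded so far
time = tf.constant(2, tf.int32)                       # current decode step

# Unfinished entries advance to time + 1; finished ones keep their length.
next_sequence_lengths = tf.where(
    tf.logical_not(finished),
    tf.fill(tf.shape(sequence_lengths), time + 1),
    sequence_lengths)

print(next_sequence_lengths)  # -> [2 3 3]
```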
"//tensorflow/core/kernels/sparse:kernels", - ] + tf_additional_cloud_kernel_deps() + if_not_windows([ + ] + if_not_windows([ "//tensorflow/core/kernels:fact_op", "//tensorflow/core/kernels:array_not_windows", "//tensorflow/core/kernels:math_not_windows", @@ -1801,6 +1787,10 @@ filegroup( "//tensorflow/core/lib/core:legacy_lib_core_all_headers", "//tensorflow/core/lib/core:legacy_lib_core_all_srcs", "//tensorflow/core/lib/gtl:legacy_lib_gtl_all_headers", + "//tensorflow/core/lib/hash:legacy_lib_hash_all_headers", + "//tensorflow/core/lib/hash:legacy_lib_hash_all_srcs", + "//tensorflow/core/lib/io:legacy_lib_io_all_headers", + "//tensorflow/core/lib/io:legacy_lib_io_all_srcs", "//tensorflow/core/lib/random:legacy_lib_random_all_headers", "//tensorflow/core/lib/random:legacy_lib_random_all_srcs", "//tensorflow/core/lib/strings:legacy_lib_strings_all_headers", @@ -2336,7 +2326,7 @@ tf_proto_library_cc( name = "worker_proto", srcs = ["protobuf/worker.proto"], cc_api_version = 2, - protodeps = tf_additional_all_protos() + [], + protodeps = tf_additional_all_protos(), visibility = ["//visibility:public"], ) @@ -2393,6 +2383,8 @@ LIB_INTERNAL_PRIVATE_HEADERS = [ "//tensorflow/core/lib/bfloat16:bfloat16.h", "//tensorflow/core/lib/core:legacy_lib_core_all_headers", "//tensorflow/core/lib/gtl:legacy_lib_gtl_all_headers", + "//tensorflow/core/lib/hash:legacy_lib_hash_all_headers", + "//tensorflow/core/lib/io:legacy_lib_io_all_headers", "//tensorflow/core/lib/random:legacy_lib_random_all_headers", "//tensorflow/core/lib/strings:legacy_lib_strings_all_headers", "//tensorflow/core/lib/math:math_util.h", @@ -2411,14 +2403,8 @@ LIB_INTERNAL_PRIVATE_HEADERS = [ LIB_INTERNAL_PUBLIC_HEADERS = [ "//tensorflow/core/lib/core:legacy_lib_internal_core_headers", "//tensorflow/core/lib/gtl:legacy_lib_internal_public_gtl_headers", - "lib/hash/hash.h", - "lib/io/inputbuffer.h", - "lib/io/iterator.h", - "lib/io/snappy/snappy_inputbuffer.h", - "lib/io/snappy/snappy_outputbuffer.h", - "lib/io/zlib_compression_options.h", - "lib/io/zlib_inputstream.h", - "lib/io/zlib_outputbuffer.h", + "//tensorflow/core/lib/hash:legacy_lib_internal_public_headers", + "//tensorflow/core/lib/io:legacy_lib_internal_public_headers", "lib/monitoring/mobile_counter.h", "lib/monitoring/mobile_gauge.h", "lib/monitoring/mobile_sampler.h", @@ -2483,7 +2469,6 @@ cc_library( exclude = [ "**/*test*", "framework/variant.cc", - "lib/hash/crc32c_accelerate.cc", "lib/gif/**/*", "lib/jpeg/**/*", "lib/png/**/*", @@ -2493,6 +2478,8 @@ cc_library( "//tensorflow/core/platform:legacy_platform_lib_srcs", "//tensorflow/core/platform:legacy_lib_internal_srcs", "//tensorflow/core/lib/core:legacy_lib_core_all_srcs", + "//tensorflow/core/lib/hash:legacy_lib_internal_impl_srcs", + "//tensorflow/core/lib/io:legacy_lib_io_all_srcs", "//tensorflow/core/lib/random:legacy_lib_random_all_srcs", "//tensorflow/core/lib/strings:legacy_lib_strings_all_srcs", ], @@ -2529,7 +2516,9 @@ cc_library( # File compiled with extra flags to get cpu-specific acceleration. cc_library( name = "lib_hash_crc32c_accelerate_internal", - srcs = ["lib/hash/crc32c_accelerate.cc"], + srcs = [ + "//tensorflow/core/lib/hash:legacy_crc32_accelerate_srcs", + ], # -msse4.2 enables the use of crc32c compiler builtins. 
copts = tf_copts() + if_linux_x86_64(["-msse4.2"]), ) @@ -3412,28 +3401,6 @@ cc_library( alwayslink = 1, ) -tf_cuda_library( - name = "device_tracer", - srcs = [ - "//tensorflow/core/platform:legacy_device_tracer_srcs", - ], - copts = tf_copts(), - cuda_deps = tf_additional_device_tracer_cuda_deps(), - visibility = [ - "//tensorflow:internal", - ], - deps = [ - ":core_cpu_internal", - ":lib", - ":protos_all_cc", - "//tensorflow/core/profiler/internal:parse_annotation", - "//tensorflow/core/profiler/internal:profiler_interface", - "//tensorflow/core/profiler/lib:traceme", - "@com_google_absl//absl/flags:flag", - ], - alwayslink = True, -) - tf_proto_library_cc( name = "replay_log_proto", srcs = ["protobuf/replay_log.proto"], @@ -3649,10 +3616,8 @@ cc_library( name = "lib_test_internal", testonly = 1, hdrs = [ - "lib/io/block.h", - "lib/io/block_builder.h", - "lib/io/format.h", "//tensorflow/core/lib/gtl:legacy_lib_test_internal_headers", + "//tensorflow/core/lib/io:legacy_lib_test_internal_headers", "//tensorflow/core/lib/random:legacy_lib_test_internal_headers", ], deps = [ @@ -3732,19 +3697,7 @@ tf_cc_tests( name = "low_level_library_tests", size = "small", srcs = [ - "lib/hash/crc32c_test.cc", - "lib/hash/hash_test.cc", "lib/histogram/histogram_test.cc", - "lib/io/buffered_inputstream_test.cc", - "lib/io/inputbuffer_test.cc", - "lib/io/inputstream_interface_test.cc", - "lib/io/path_test.cc", - "lib/io/random_inputstream_test.cc", - "lib/io/record_reader_writer_test.cc", - "lib/io/recordio_test.cc", - "lib/io/snappy/snappy_buffers_test.cc", - "lib/io/table_test.cc", - "lib/io/zlib_buffers_test.cc", "lib/monitoring/collection_registry_test.cc", "lib/monitoring/counter_test.cc", "lib/monitoring/gauge_test.cc", @@ -3753,6 +3706,8 @@ tf_cc_tests( "lib/wav/wav_io_test.cc", "//tensorflow/core/lib/core:legacy_lib_core_all_tests", "//tensorflow/core/lib/gtl:legacy_lib_gtl_tests", + "//tensorflow/core/lib/hash:legacy_lib_hash_all_tests", + "//tensorflow/core/lib/io:legacy_lib_io_all_tests", "//tensorflow/core/lib/math:math_util_test.cc", "//tensorflow/core/lib/random:legacy_lib_random_tests", "//tensorflow/core/lib/strings:legacy_low_level_library_tests", @@ -3789,6 +3744,7 @@ tf_cc_tests( "//third_party/eigen3", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:optional", "@zlib_archive//:zlib", ], ) @@ -5259,36 +5215,6 @@ tf_cc_test( ], ) -tf_cc_test_gpu( - name = "device_tracer_test", - size = "small", - srcs = ["//tensorflow/core/platform:device_tracer_test.cc"], - args = - ["--heap_check=local"] + tf_additional_device_tracer_test_flags(), - linkstatic = tf_kernel_tests_linkstatic(), - tags = tf_cuda_tests_tags() + ["nomac"], - deps = [ - ":all_kernels", - ":core_cpu", - ":core_cpu_internal", - ":device_tracer", - ":direct_session", - ":direct_session_internal", - ":framework", - ":framework_internal", - ":gpu_runtime", - ":lib", - ":lib_internal", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/cc:cc_ops", - "//tensorflow/core/kernels:ops_util", - "//tensorflow/core/profiler/internal:profiler_interface", - ], -) - tf_cc_tests( name = "common_runtime_input_colocation_exemption_registry_test", size = "small", diff --git a/tensorflow/core/api_def/base_api/api_def_Bitcast.pbtxt b/tensorflow/core/api_def/base_api/api_def_Bitcast.pbtxt index f65ce366587..8d6b80094c7 100644 --- a/tensorflow/core/api_def/base_api/api_def_Bitcast.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_Bitcast.pbtxt @@ -18,21 +18,23 @@ 
gives module error. For example, Example 1: -```python + >>> a = [1., 2., 3.] ->>> equality_bitcast = tf.bitcast(a,tf.complex128) -tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot bitcast from float to complex128: shape [3] [Op:Bitcast] ->>> equality_cast = tf.cast(a,tf.complex128) +>>> equality_bitcast = tf.bitcast(a, tf.complex128) +Traceback (most recent call last): +... +InvalidArgumentError: Cannot bitcast from 1 to 18 [Op:Bitcast] +>>> equality_cast = tf.cast(a, tf.complex128) >>> print(equality_cast) tf.Tensor([1.+0.j 2.+0.j 3.+0.j], shape=(3,), dtype=complex128) -``` + Example 2: -```python + >>> tf.bitcast(tf.constant(0xffffffff, dtype=tf.uint32), tf.uint8) -``` + Example 3: -```python + >>> x = [1., 2., 3.] >>> y = [0., 2., 3.] >>> equality= tf.equal(x,y) @@ -44,10 +46,9 @@ tf.Tensor([False True True], shape=(3,), dtype=bool) tf.Tensor([0. 1. 1.], shape=(3,), dtype=float32) >>> print(equality_bitcast) tf.Tensor( -[[ 0 0 0 0] - [ 0 0 128 63] - [ 0 0 128 63]], shape=(3, 4), dtype=uint8) -``` + [[ 0 0 0 0] + [ 0 0 128 63] + [ 0 0 128 63]], shape=(3, 4), dtype=uint8) *NOTE*: Bitcast is implemented as a low-level cast, so machines with different endian orderings will give different results. diff --git a/tensorflow/core/api_def/base_api/api_def_BroadcastTo.pbtxt b/tensorflow/core/api_def/base_api/api_def_BroadcastTo.pbtxt index 669223df862..2af0ea31c62 100644 --- a/tensorflow/core/api_def/base_api/api_def_BroadcastTo.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_BroadcastTo.pbtxt @@ -28,14 +28,13 @@ and works its way forward. For example, -```python >>> x = tf.constant([1, 2, 3]) >>> y = tf.broadcast_to(x, [3, 3]) ->>> sess.run(y) -array([[1, 2, 3], - [1, 2, 3], - [1, 2, 3]], dtype=int32) -``` +>>> print(y) +tf.Tensor( + [[1 2 3] + [1 2 3] + [1 2 3]], shape=(3, 3), dtype=int32) In the above example, the input Tensor with the shape of `[1, 3]` is broadcasted to output Tensor with shape of `[3, 3]`. diff --git a/tensorflow/core/api_def/base_api/api_def_FFT3D.pbtxt b/tensorflow/core/api_def/base_api/api_def_FFT3D.pbtxt index abd2e67bceb..33de5f424c9 100644 --- a/tensorflow/core/api_def/base_api/api_def_FFT3D.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_FFT3D.pbtxt @@ -3,13 +3,13 @@ op { in_arg { name: "input" description: <