diff --git a/.bazelrc b/.bazelrc index a543ebbcd75..8be3dadaf4e 100644 --- a/.bazelrc +++ b/.bazelrc @@ -46,7 +46,6 @@ # sycl_asan: # sycl_trisycl: # mkl: Enable full mkl support. -# mkl_open_source_only: Enable MKL support only using open source MKL libraries. # tensorrt: Enable Tensorrt support. # ngraph: Enable ngraph support. # numa: Enable numa using hwloc. @@ -140,13 +139,6 @@ build:mkl --define=tensorflow_mkldnn_contraction_kernel=0 build:mkl --define=build_with_mkl_dnn_v1_only=true build:mkl -c opt -# This config option is used to enable MKL-DNN open source library only, -# without depending on MKL binary version. -build:mkl_open_source_only --define=build_with_mkl_dnn_only=true -build:mkl_open_source_only --define=build_with_mkl_dnn_v1_only=true -build:mkl_open_source_only --define=build_with_mkl=true --define=enable_mkl=true -build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=0 - # This config refers to building with CUDA available. It does not necessarily # mean that we build CUDA op kernels. build:using_cuda --define=using_cuda=true @@ -248,6 +240,7 @@ build:windows --copt=/w # Tensorflow uses M_* math constants that only get defined by MSVC headers if # _USE_MATH_DEFINES is defined. build:windows --copt=/D_USE_MATH_DEFINES +build:windows --host_copt=/D_USE_MATH_DEFINES # Default paths for TF_SYSTEM_LIBS build:linux --define=PREFIX=/usr diff --git a/.gitignore b/.gitignore index eab8a64c63d..72cb418fe11 100644 --- a/.gitignore +++ b/.gitignore @@ -38,7 +38,9 @@ gradleBuild *.pbxproj *.xcworkspace /*.podspec -/tensorflow/lite/**/[ios|objc|swift]*/BUILD +/tensorflow/lite/**/ios/BUILD +/tensorflow/lite/**/objc/BUILD +/tensorflow/lite/**/swift/BUILD /tensorflow/lite/examples/ios/simple/data/*.tflite /tensorflow/lite/examples/ios/simple/data/*.txt Podfile.lock diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 0b31ca33d20..c5574793b74 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -154,7 +154,10 @@ tf_cuda_library( "c_api.h", ], copts = tf_copts(), - visibility = ["//tensorflow/c:__subpackages__"], + visibility = [ + "//tensorflow/c:__subpackages__", + "//third_party/llvm/llvm-project:__subpackages__", + ], deps = [ ":c_api_internal", ":tf_attrtype", @@ -698,4 +701,5 @@ tf_cuda_library( # TODO(b/74620627): remove when _USE_C_SHAPES is removed "//tensorflow/python:cpp_shape_inference_proto_cc", ], + alwayslink = 1, ) diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 06a6bc64e74..bc1fbd3fcf5 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -774,7 +774,7 @@ extern "C" { static TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph, const char* op_type, const char* oper_name) - EXCLUSIVE_LOCKS_REQUIRED(graph->mu) { + TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) { return new TF_OperationDescription(graph, op_type, oper_name); } @@ -1032,7 +1032,7 @@ void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name, static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc, TF_Status* status) - EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) { + TF_EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) { Node* ret = nullptr; if (desc->graph->name_map.count(desc->node_builder.node_name())) { @@ -1706,7 +1706,7 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def, const TF_ImportGraphDefOptions* opts, TF_ImportGraphDefResults* tf_results, TF_Status* status) - EXCLUSIVE_LOCKS_REQUIRED(graph->mu) { + TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) { const int last_node_id = 
graph->graph.num_node_ids(); tensorflow::ImportGraphDefResults results; status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph, diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index a908fd131c1..a235ea0cf5a 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -51,7 +51,7 @@ Status ProcessInputs( const TF_Graph* fn_body, const char* fn_name, int ninputs, const TF_Output* inputs, std::vector<OutputTensor>* input_tensors, std::unordered_map<const Node*, std::vector<int>>* input_nodes) - EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) { + TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) { input_tensors->reserve(ninputs); for (int i = 0; i < ninputs; ++i) { Node* node = &inputs[i].oper->node; @@ -87,7 +87,7 @@ Status ProcessInputs( Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name, int noutputs, const TF_Output* outputs, std::vector<OutputTensor>* output_tensors) - EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) { + TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) { output_tensors->reserve(noutputs); for (int i = 0; i < noutputs; ++i) { Node* node = &outputs[i].oper->node; @@ -111,7 +111,7 @@ Status ComputeBodyNodes( const TF_Operation* const* opers, const std::unordered_map<const Node*, std::vector<int>>& input_nodes, std::vector<const Node*>* body_nodes) - EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) { + TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) { if (num_opers == -1) { for (const Node* node : fn_body->graph.op_nodes()) { const auto& iter = input_nodes.find(node); diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index 9e1b54f0029..32880378c2b 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -71,14 +71,14 @@ struct TF_Graph { TF_Graph(); tensorflow::mutex mu; - tensorflow::Graph graph GUARDED_BY(mu); + tensorflow::Graph graph TF_GUARDED_BY(mu); // Runs shape inference. - tensorflow::ShapeRefiner refiner GUARDED_BY(mu); + tensorflow::ShapeRefiner refiner TF_GUARDED_BY(mu); // Maps from name of an operation to the Node* in 'graph'. std::unordered_map<tensorflow::string, tensorflow::Node*> name_map - GUARDED_BY(mu); + TF_GUARDED_BY(mu); // The keys of this map are all the active sessions using this graph. Each // value records whether the graph has been mutated since the corresponding @@ -94,8 +94,8 @@ struct TF_Graph { // TODO(b/74949947): mutations currently trigger a warning instead of a bad // status, this should be reverted when possible. tensorflow::gtl::FlatMap<TF_Session*, tensorflow::string> sessions - GUARDED_BY(mu); - bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph + TF_GUARDED_BY(mu); + bool delete_requested TF_GUARDED_BY(mu); // set true by TF_DeleteGraph // Used to link graphs contained in TF_WhileParams to the parent graph that // will eventually contain the full while loop.
@@ -123,7 +123,7 @@ struct TF_Session { tensorflow::Session* session; TF_Graph* const graph; - tensorflow::mutex mu ACQUIRED_AFTER(TF_Graph::mu); + tensorflow::mutex mu TF_ACQUIRED_AFTER(TF_Graph::mu); int last_num_graph_nodes; // If true, TF_SessionRun and similar methods will call @@ -169,9 +169,9 @@ struct TF_ApiDefMap { } #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) - tensorflow::ApiDefMap api_def_map GUARDED_BY(lock); + tensorflow::ApiDefMap api_def_map TF_GUARDED_BY(lock); #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) - bool update_docs_called GUARDED_BY(lock); + bool update_docs_called TF_GUARDED_BY(lock); tensorflow::mutex lock; }; @@ -210,10 +210,10 @@ void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output, void RecordMutation(TF_Graph* graph, const TF_Operation& op, const char* mutation_type) - EXCLUSIVE_LOCKS_REQUIRED(graph->mu); + TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu); bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) - LOCKS_EXCLUDED(session->graph->mu, session->mu); + TF_LOCKS_EXCLUDED(session->graph->mu, session->mu); std::string getTF_OutputDebugString(TF_Output node); diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 2749724d039..c25cb264ce7 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -354,6 +354,7 @@ cc_library( "//tensorflow/core:lib", "@dlpack", ], + alwayslink = 1, ) # TODO(karllessard): only used by //tensorflow/core:mobile_srcs_only_runtime diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 65f37f3021f..96dc288f213 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -1710,8 +1710,9 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, namespace { class CustomDeviceAPI : public tensorflow::CustomDevice { public: - CustomDeviceAPI(TFE_CustomDevice device, void* info, string name) - : device_(device), info_(info), name_(name) {} + CustomDeviceAPI(TFE_Context* context, TFE_CustomDevice device, void* info, + string name) + : context_(context), device_(device), info_(info), name_(name) {} ~CustomDeviceAPI() override { device_.delete_device(info_); } @@ -1725,7 +1726,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice { std::make_unique(tensor)}; TF_Status status; TFE_TensorHandle* result_handle = - device_.copy_tensor_to_device(&tensor_handle, &status, info_); + device_.copy_tensor_to_device(context_, &tensor_handle, &status, info_); if (!status.status.ok()) return status.status; *result = tensorflow::down_cast( result_handle->handle.get()) @@ -1744,7 +1745,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice { TFE_TensorHandle tensor_handle{ std::make_unique(tensor)}; TFE_TensorHandle* result_handle = device_.copy_tensor_from_device( - &tensor_handle, target_device_name.c_str(), &status, info_); + context_, &tensor_handle, target_device_name.c_str(), &status, info_); if (!status.status.ok()) return status.status; *result = tensorflow::down_cast( result_handle->handle.get()) @@ -1768,7 +1769,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice { std::vector outputs(*num_retvals); TF_Status status; TFE_OpAttrs attributes(&op->Attrs(), op->Name().c_str()); - device_.execute(inputs.size(), inputs.data(), op->Name().c_str(), + device_.execute(context_, inputs.size(), inputs.data(), op->Name().c_str(), &attributes, num_retvals, outputs.data(), &status, info_); if (status.status.ok()) { for (int i = 0; i < *num_retvals; ++i) { @@ -1787,6 +1788,7 @@ class CustomDeviceAPI : 
public tensorflow::CustomDevice { } private: + TFE_Context* context_; TFE_CustomDevice device_; void* info_; string name_; @@ -1794,8 +1796,10 @@ class CustomDeviceAPI : public tensorflow::CustomDevice { } // namespace void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device, - const char* device_name, void* device_info) { + const char* device_name, void* device_info, + TF_Status* status) { auto custom_device = - std::make_unique<CustomDeviceAPI>(device, device_info, device_name); - ctx->context->RegisterCustomDevice(device_name, std::move(custom_device)); + std::make_unique<CustomDeviceAPI>(ctx, device, device_info, device_name); + status->status = + ctx->context->RegisterCustomDevice(device_name, std::move(custom_device)); } diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h index b0f0da5acef..c24735963d6 100644 --- a/tensorflow/c/eager/c_api_experimental.h +++ b/tensorflow/c/eager/c_api_experimental.h @@ -458,27 +458,29 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrValueProto(const TFE_Op* op, size_t proto_len, TF_Status* status); -#define TFE_CUSTOM_DEVICE_VERSION 1 +#define TFE_CUSTOM_DEVICE_VERSION 2 // Struct to be filled in typedef struct TFE_CustomDevice { int version = TFE_CUSTOM_DEVICE_VERSION; // Method to copy a tensor to the custom device. - TFE_TensorHandle* (*copy_tensor_to_device)(TFE_TensorHandle* tensor, + TFE_TensorHandle* (*copy_tensor_to_device)(TFE_Context* context, + TFE_TensorHandle* tensor, TF_Status* status, void* device_info) = nullptr; // Method to copy a tensor from the custom device to a target device. - TFE_TensorHandle* (*copy_tensor_from_device)(TFE_TensorHandle* tensor, + TFE_TensorHandle* (*copy_tensor_from_device)(TFE_Context* context, + TFE_TensorHandle* tensor, const char* target_device_name, TF_Status* status, void* device_info); // Method to execute an operation. - void (*execute)(int num_inputs, TFE_TensorHandle** inputs, - const char* operation_name, const TFE_OpAttrs* attributes, - int* num_outputs, TFE_TensorHandle** outputs, TF_Status* s, - void* device_info); + void (*execute)(TFE_Context* context, int num_inputs, + TFE_TensorHandle** inputs, const char* operation_name, + const TFE_OpAttrs* attributes, int* num_outputs, + TFE_TensorHandle** outputs, TF_Status* s, void* device_info); // Method to delete a device. void (*delete_device)(void* device_info); @@ -503,11 +505,21 @@ typedef struct TFE_CustomDevice { // devices, so executing tf.functions which contain operations placed on custom // devices will fail. // +// `device_name` must not name an existing physical or custom device. It must +// follow the format: +// +// /job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num> +// +// If the device is successfully registered, `status` is set to TF_OK. Otherwise +// the device is not usable. In case of a bad status, `device.delete_device` is +// still called on `device_info` (i.e. the caller does not retain ownership). +// // This API is highly experimental, and in particular is expected to change when // it starts supporting operations with attributes and when tf.function support // is added.
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device, - const char* device_name, void* device_info); + const char* device_name, void* device_info, + TF_Status* status); TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name, diff --git a/tensorflow/c/eager/custom_device_test.cc b/tensorflow/c/eager/custom_device_test.cc index 742844c3f75..b6e6369bb43 100644 --- a/tensorflow/c/eager/custom_device_test.cc +++ b/tensorflow/c/eager/custom_device_test.cc @@ -27,7 +27,6 @@ limitations under the License. namespace { struct LoggingDevice { - TFE_Context* ctx; tensorflow::string device_name; tensorflow::string underlying_device; // Set to true whenever a TensorHandle is copied onto the device @@ -48,7 +47,7 @@ void LoggedTensorDeallocator(void* data, size_t len, void* arg) { } TFE_TensorHandle* MakeLoggedTensorHandle( - TFE_Context* ctx, const tensorflow::string& logging_device_name, + TFE_Context* context, const tensorflow::string& logging_device_name, std::unique_ptr t, TF_Status* status) { std::vector shape(TFE_TensorHandleNumDims(t->tensor, status)); if (TF_GetCode(status) != TF_OK) return nullptr; @@ -58,23 +57,25 @@ TFE_TensorHandle* MakeLoggedTensorHandle( } auto dtype = TFE_TensorHandleDataType(t->tensor); return TFE_NewTensorHandleFromDeviceMemory( - ctx, logging_device_name.c_str(), dtype, shape.data(), shape.size(), + context, logging_device_name.c_str(), dtype, shape.data(), shape.size(), t.release(), 1, &LoggedTensorDeallocator, nullptr, status); } -TFE_TensorHandle* CopyToLoggingDevice(TFE_TensorHandle* tensor, +TFE_TensorHandle* CopyToLoggingDevice(TFE_Context* context, + TFE_TensorHandle* tensor, TF_Status* status, void* device_info) { LoggingDevice* dev = reinterpret_cast(device_info); TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice( - tensor, dev->ctx, dev->underlying_device.c_str(), status); + tensor, context, dev->underlying_device.c_str(), status); if (TF_GetCode(status) != TF_OK) return nullptr; auto dst = std::make_unique(t); *(dev->arrived_flag) = true; - return MakeLoggedTensorHandle(dev->ctx, dev->device_name, std::move(dst), + return MakeLoggedTensorHandle(context, dev->device_name, std::move(dst), status); } -TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_TensorHandle* tensor, +TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context, + TFE_TensorHandle* tensor, const char* target_device_name, TF_Status* status, void* device_info) { @@ -83,13 +84,13 @@ TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_TensorHandle* tensor, return nullptr; } -void LoggingDeviceExecute(int num_inputs, TFE_TensorHandle** inputs, - const char* operation_name, +void LoggingDeviceExecute(TFE_Context* context, int num_inputs, + TFE_TensorHandle** inputs, const char* operation_name, const TFE_OpAttrs* attributes, int* num_outputs, TFE_TensorHandle** outputs, TF_Status* s, void* device_info) { LoggingDevice* dev = reinterpret_cast(device_info); - TFE_Op* op(TFE_NewOp(dev->ctx, operation_name, s)); + TFE_Op* op(TFE_NewOp(context, operation_name, s)); if (TF_GetCode(s) != TF_OK) return; TFE_OpAddAttrs(op, attributes); TFE_OpSetDevice(op, dev->underlying_device.c_str(), s); @@ -117,7 +118,7 @@ void LoggingDeviceExecute(int num_inputs, TFE_TensorHandle** inputs, } for (int i = 0; i < *num_outputs; ++i) { auto logged_tensor = std::make_unique(unwrapped_outputs[i]); - outputs[i] = MakeLoggedTensorHandle(dev->ctx, dev->device_name, + outputs[i] = MakeLoggedTensorHandle(context, dev->device_name, std::move(logged_tensor), s); } 
*(dev->executed_flag) = true; @@ -128,19 +129,19 @@ void DeleteLoggingDevice(void* device_info) { } void RegisterLoggingDevice(TFE_Context* context, const char* name, - bool* arrived_flag, bool* executed_flag) { + bool* arrived_flag, bool* executed_flag, + TF_Status* status) { TFE_CustomDevice custom_device; custom_device.copy_tensor_to_device = &CopyToLoggingDevice; custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice; custom_device.delete_device = &DeleteLoggingDevice; custom_device.execute = &LoggingDeviceExecute; LoggingDevice* device = new LoggingDevice; - device->ctx = context; device->arrived_flag = arrived_flag; device->executed_flag = executed_flag; device->device_name = name; device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0"; - TFE_RegisterCustomDevice(context, custom_device, name, device); + TFE_RegisterCustomDevice(context, custom_device, name, device, status); } TEST(CUSTOM_DEVICE, RegisterSimpleDevice) { @@ -153,7 +154,8 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) { bool arrived = false; bool executed = false; const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; - RegisterLoggingDevice(context, name, &arrived, &executed); + RegisterLoggingDevice(context, name, &arrived, &executed, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); TFE_TensorHandle* hcpu = TestMatrixTensorHandle(); ASSERT_FALSE(arrived); TFE_TensorHandle* hdevice = @@ -189,7 +191,9 @@ TEST(CUSTOM_DEVICE, ResetOperation) { bool executed = false; const char* custom_device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; - RegisterLoggingDevice(context.get(), custom_device_name, &arrived, &executed); + RegisterLoggingDevice(context.get(), custom_device_name, &arrived, &executed, + status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); std::unique_ptr reused_op( TFE_NewOp(context.get(), "Identity", status.get()), TFE_DeleteOp); @@ -217,7 +221,8 @@ TEST(CUSTOM_DEVICE, MakeVariable) { bool arrived = false; bool executed = false; const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; - RegisterLoggingDevice(context.get(), name, &arrived, &executed); + RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); // Create a variable handle placed on the custom device. std::unique_ptr op( @@ -291,4 +296,103 @@ TEST(CUSTOM_DEVICE, MakeVariable) { ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); } +TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + std::unique_ptr opts( + TFE_NewContextOptions(), TFE_DeleteContextOptions); + std::unique_ptr context( + TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + bool arrived = false; + bool executed = false; + const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; + RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + // Create a variable handle placed on the custom device. 
+ std::unique_ptr op( + TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT); + TFE_OpSetAttrShape(op.get(), "shape", {}, 0, status.get()); + TFE_OpSetAttrString(op.get(), "container", "", 0); + TFE_OpSetAttrString(op.get(), "shared_name", "", 0); + TFE_OpSetDevice(op.get(), name, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_TensorHandle* var_handle = nullptr; + int num_retvals = 1; + executed = false; + TFE_Execute(op.get(), &var_handle, &num_retvals, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_TRUE(executed); + auto handle_cleaner = tensorflow::gtl::MakeCleanup( + [var_handle]() { TFE_DeleteTensorHandle(var_handle); }); + + // Assign to the variable, copying to the custom device. + std::unique_ptr one( + TestScalarTensorHandle(111.f), TFE_DeleteTensorHandle); + op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get())); + TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT); + TFE_OpAddInput(op.get(), var_handle, status.get()); + TFE_OpAddInput(op.get(), one.get(), status.get()); + TFE_OpSetDevice(op.get(), name, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + executed = false; + num_retvals = 0; + TFE_Execute(op.get(), nullptr, &num_retvals, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_TRUE(executed); + + // Read the variable's value. + op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get())); + TFE_OpAddInput(op.get(), var_handle, status.get()); + TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + executed = false; + num_retvals = 1; + TFE_TensorHandle* var_value = nullptr; + TFE_Execute(op.get(), &var_value, &num_retvals, status.get()); + EXPECT_FALSE(TF_GetCode(status.get()) == TF_OK) + << "Execution should fail because the variable is being used on the " + "wrong device."; + // Free the backing buffer for the variable. 
+ op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get())); + TFE_OpAddInput(op.get(), var_handle, status.get()); + TFE_OpSetDevice(op.get(), name, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + num_retvals = 0; + TFE_Execute(op.get(), nullptr, &num_retvals, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); +} + +TEST(CUSTOM_DEVICE, InvalidRegistrationError) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + std::unique_ptr opts( + TFE_NewContextOptions(), TFE_DeleteContextOptions); + std::unique_ptr context( + TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + bool arrived = false; + bool executed = false; + RegisterLoggingDevice(context.get(), "/device:CUSTOM:0", &arrived, &executed, + status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT) + << TF_Message(status.get()); + + const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; + RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS) + << TF_Message(status.get()); + + RegisterLoggingDevice(context.get(), + "/job:localhost/replica:0/task:0/device:CPU:0", + &arrived, &executed, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS) + << TF_Message(status.get()); +} + } // namespace diff --git a/tensorflow/c/kernels/bitcast_op_test.cc b/tensorflow/c/kernels/bitcast_op_test.cc index 7da27e99d1f..33028ea6bd9 100644 --- a/tensorflow/c/kernels/bitcast_op_test.cc +++ b/tensorflow/c/kernels/bitcast_op_test.cc @@ -27,14 +27,10 @@ namespace { class DummyDevice : public DeviceBase { public: - DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {} - bool RequiresRecordingAccessedTensors() const override { return save_; } + explicit DummyDevice(Env* env) : DeviceBase(env) {} Allocator* GetAllocator(AllocatorAttributes /*attr*/) override { return cpu_allocator(); } - - private: - bool save_; }; void TestBitcastOp(Tensor* input_tensor, DataType out_type, @@ -61,7 +57,7 @@ void TestBitcastOp(Tensor* input_tensor, DataType out_type, ASSERT_TRUE(status.ok()) << status.ToString(); OpKernelContext::Params params; - DummyDevice dummy_device(nullptr, false); + DummyDevice dummy_device(nullptr); params.device = &dummy_device; params.op_kernel = kernel.get(); gtl::InlinedVector inputs; diff --git a/tensorflow/c/kernels_test.cc b/tensorflow/c/kernels_test.cc index 80e90e7cdf9..423302741de 100644 --- a/tensorflow/c/kernels_test.cc +++ b/tensorflow/c/kernels_test.cc @@ -155,14 +155,10 @@ TEST(TestKernel, TestRegisterKernelBuilder) { class DummyDevice : public DeviceBase { public: - DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {} - bool RequiresRecordingAccessedTensors() const override { return save_; } + explicit DummyDevice(Env* env) : DeviceBase(env) {} Allocator* GetAllocator(AllocatorAttributes /*attr*/) override { return cpu_allocator(); } - - private: - bool save_; }; TEST(TestKernel, TestInputAndOutputCount) { @@ -223,7 +219,7 @@ TEST(TestKernel, TestInputAndOutputCount) { { OpKernelContext::Params p; - DummyDevice dummy_device(nullptr, false); + DummyDevice dummy_device(nullptr); p.device = &dummy_device; p.step_id = 43; 
diff --git a/tensorflow/c/tf_tensor.h b/tensorflow/c/tf_tensor.h index 462fdc8b497..7ed4a9f754e 100644 --- a/tensorflow/c/tf_tensor.h +++ b/tensorflow/c/tf_tensor.h @@ -58,9 +58,9 @@ extern "C" { // start_offset: array[uint64] // data: byte[...] // -// The string length (as a varint), followed by the contents of the string -// is encoded at data[start_offset[i]]]. TF_StringEncode and TF_StringDecode -// facilitate this encoding. +// The string length (as a varint, start_offset[i + 1] - start_offset[i]), +// followed by the contents of the string is encoded at data[start_offset[i]]. +// TF_StringEncode and TF_StringDecode facilitate this encoding. typedef struct TF_Tensor TF_Tensor; diff --git a/tensorflow/cc/client/client_session.cc b/tensorflow/cc/client/client_session.cc index c4add1589e7..da2e12a4a06 100644 --- a/tensorflow/cc/client/client_session.cc +++ b/tensorflow/cc/client/client_session.cc @@ -41,7 +41,7 @@ class ClientSession::Impl { std::shared_ptr graph_; mutable mutex mu_; - mutable int last_num_graph_nodes_ GUARDED_BY(mu_) = 0; + mutable int last_num_graph_nodes_ TF_GUARDED_BY(mu_) = 0; }; ClientSession::ClientSession(const Scope& scope, const string& target) diff --git a/tensorflow/cc/training/coordinator.h b/tensorflow/cc/training/coordinator.h index 6d92d05803d..ca2b5f956bf 100644 --- a/tensorflow/cc/training/coordinator.h +++ b/tensorflow/cc/training/coordinator.h @@ -114,14 +114,14 @@ class Coordinator { condition_variable wait_for_stop_; mutex mu_; - bool should_stop_ GUARDED_BY(mu_); + bool should_stop_ TF_GUARDED_BY(mu_); mutex status_lock_; - Status status_ GUARDED_BY(status_lock_); + Status status_ TF_GUARDED_BY(status_lock_); mutable mutex runners_lock_; std::vector> runners_ - GUARDED_BY(runners_lock_); + TF_GUARDED_BY(runners_lock_); TF_DISALLOW_COPY_AND_ASSIGN(Coordinator); }; diff --git a/tensorflow/cc/training/queue_runner.h b/tensorflow/cc/training/queue_runner.h index d9ecd221493..4a748bfc924 100644 --- a/tensorflow/cc/training/queue_runner.h +++ b/tensorflow/cc/training/queue_runner.h @@ -119,8 +119,8 @@ class QueueRunner : public RunnerInterface { std::unique_ptr thread_pool_; mutex mu_; int runs_ = 0; - Status status_ GUARDED_BY(mu_); - Status enqueue_status_ GUARDED_BY(mu_); + Status status_ TF_GUARDED_BY(mu_); + Status enqueue_status_ TF_GUARDED_BY(mu_); std::unique_ptr counter_; Coordinator* coord_; @@ -131,7 +131,7 @@ class QueueRunner : public RunnerInterface { std::vector> callbacks_; mutable std::unique_ptr cg_mu_; - std::unique_ptr cost_graph_ GUARDED_BY(cg_mu_); + std::unique_ptr cost_graph_ TF_GUARDED_BY(cg_mu_); RunOptions run_options_; }; diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index dfbea9c49eb..7f1590ff75d 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -37,6 +37,7 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", + "//tensorflow/compiler/mlir/lite/quantization/xla:quantize", "//tensorflow/compiler/tf2xla", "//tensorflow/compiler/tf2xla:mlir_tf2xla", "//tensorflow/compiler/tf2xla:tf2xla_proto_cc", @@ -64,6 +65,7 @@ cc_library( "@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep "@llvm-project//llvm:target", "@llvm-project//llvm:x86_code_gen", # fixdeps: keep + "//tensorflow/core:regexp_internal", ] + if_llvm_aarch64_available([ "//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep ]), diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index 
53150e991cc..4a4fec5a386 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/strings/str_split.h" #include "absl/types/span.h" #include "tensorflow/compiler/aot/embedded_protocol_buffers.h" +#include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/cpu_function_runtime.h" #include "tensorflow/compiler/xla/service/compiler.h" @@ -288,8 +289,8 @@ Status GenVariableMethods(const tf2xla::Config& config, } // Generates code implementing {Arg,Result}Names(), where T is one of -// tf2xla::{Feed,Fetch}. Each feed or fetch name results in a C-style string -// literal in the array, with nullptr terminating the array. +// tf2xla::{Feed,Fetch,Variable}. Each feed or fetch name results in a C-style +// string literal in the array, with nullptr terminating the array. template string GenNameToIndexCode(const T& entries, bool generate) { // No need for a static array if we're not supposed to generate the data. @@ -419,6 +420,16 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, // Generate metadata. const string arg_names_code = GenNameToIndexCode(config.feed(), opts.gen_name_to_index); + + auto variable_copy = config.variable(); + for (auto& var : variable_copy) { + if (var.name().empty()) { + var.set_name(var.node_name()); + } + } + const string variable_names_code = + GenNameToIndexCode(variable_copy, opts.gen_name_to_index); + const string result_names_code = GenNameToIndexCode(config.fetch(), opts.gen_name_to_index); const string include_xla_data_proto = @@ -507,6 +518,9 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction { // Number of input arguments for the compiled computation. static constexpr size_t kNumArgs = {{ARG_NUM}}; + // Number of variables for the compiled computation. + static constexpr size_t kNumVariables = {{VARIABLE_NUM}}; + // Byte size of each argument buffer. There are kNumArgs entries. static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) { return BufferInfos()[ArgIndexToBufferIndex()[index]].size(); @@ -522,8 +536,10 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction { set_static_data_num_buffers(data, kNumBuffers); set_static_data_arg_index_table(data, ArgIndexToBufferIndex()); set_static_data_num_args(data, kNumArgs); + set_static_data_num_variables(data, kNumVariables); set_static_data_result_index(data, kResultIndex); set_static_data_arg_names(data, StaticArgNames()); + set_static_data_variable_names(data, StaticVariableNames()); set_static_data_result_names(data, StaticResultNames()); set_static_data_program_shape(data, StaticProgramShape()); set_static_data_hlo_profile_printer_data( @@ -626,6 +642,9 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction { // Array of names of each positional argument, terminated by nullptr. static const char** StaticArgNames() {{ARG_NAMES_CODE}} + // Array of names of each positional variable, terminated by nullptr. + static const char** StaticVariableNames() {{VARIABLE_NAMES_CODE}} + // Array of names of each positional result, terminated by nullptr. 
static const char** StaticResultNames() {{RESULT_NAMES_CODE}} @@ -654,6 +673,7 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction { {"{{ARG_BYTES_TOTAL}}", absl::StrCat(arg_bytes_total)}, {"{{ARG_NAMES_CODE}}", arg_names_code}, {"{{ARG_NUM}}", absl::StrCat(arg_index_table.size())}, + {"{{VARIABLE_NUM}}", absl::StrCat(config.variable_size())}, {"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")}, {"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size}, {"{{CLASS}}", opts.class_name}, @@ -673,6 +693,7 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction { {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(xla::ProgramShape(ps))}, {"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}", metadata_result.program_shape_access_shim}, + {"{{VARIABLE_NAMES_CODE}}", variable_names_code}, {"{{RESULT_INDEX}}", absl::StrCat(result_index)}, {"{{RESULT_NAMES_CODE}}", result_names_code}, {"{{TEMP_BYTES_ALIGNED}}", absl::StrCat(temp_bytes_aligned)}, diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index 6206f68faf9..babbd7fb2f5 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -156,17 +156,14 @@ static void CompareWithGoldenFile( // bazel test --test_strategy=local \ // third_party/tensorflow/compiler/aot:codegen_test const bool update_golden = false; - string golden_file_name; + string golden_file_name = + GetDataDependencyFilepath(tensorflow_relative_golden_file_name); if (update_golden) { - golden_file_name = io::JoinPath(testing::TensorFlowSrcRoot(), - tensorflow_relative_golden_file_name); TF_EXPECT_OK( WriteStringToFile(Env::Default(), golden_file_name, expected_contents)); } - golden_file_name = - GetDataDependencyFilepath(tensorflow_relative_golden_file_name); string golden_file_contents; TF_ASSERT_OK(ReadFileToString(Env::Default(), golden_file_name, &golden_file_contents)); @@ -220,10 +217,16 @@ TEST(CodegenTest, Golden) { {}, {BufferInfo::MakeTempBuffer(1), BufferInfo::MakeEntryParameter(/*size=*/8, /*param_number=*/0), - BufferInfo::MakeTempBuffer(2), + BufferInfo::MakeTempBuffer(1), BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1), - BufferInfo::MakeTempBuffer(3), BufferInfo::MakeTempBuffer(120)}, - 5, {})); + BufferInfo::MakeTempBuffer(1), + BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/2), + BufferInfo::MakeTempBuffer(1), + BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/3), + BufferInfo::MakeTempBuffer(1), + BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/4), + BufferInfo::MakeTempBuffer(1), BufferInfo::MakeTempBuffer(120)}, + 11, {})); compile_result.program_shape = xla::ShapeUtil::MakeProgramShape( { diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index 1669e728d1a..af58ca233f0 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -55,14 +55,17 @@ namespace bar { // ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): f32[1], (unknown): f32[1], (unknown): s32[5]) -> (u32[5,6], f32[1], s32[5]) // // Memory stats: -// arg bytes total: 104 -// arg bytes aligned: 192 +// arg bytes total: 392 +// arg bytes aligned: 576 // temp bytes total: 126 -// temp bytes aligned: 320 +// temp bytes aligned: 512 class MyClass final : public tensorflow::XlaCompiledCpuFunction { public: // Number of input arguments for the compiled computation. 
- static constexpr size_t kNumArgs = 2; + static constexpr size_t kNumArgs = 5; + + // Number of variables for the compiled computation. + static constexpr size_t kNumVariables = 3; // Byte size of each argument buffer. There are kNumArgs entries. static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) { @@ -79,8 +82,10 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { set_static_data_num_buffers(data, kNumBuffers); set_static_data_arg_index_table(data, ArgIndexToBufferIndex()); set_static_data_num_args(data, kNumArgs); + set_static_data_num_variables(data, kNumVariables); set_static_data_result_index(data, kResultIndex); set_static_data_arg_names(data, StaticArgNames()); + set_static_data_variable_names(data, StaticVariableNames()); set_static_data_result_names(data, StaticResultNames()); set_static_data_program_shape(data, StaticProgramShape()); set_static_data_hlo_profile_printer_data( @@ -295,16 +300,22 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { private: // Number of buffers for the compiled computation. - static constexpr size_t kNumBuffers = 6; + static constexpr size_t kNumBuffers = 12; static const ::xla::cpu_function_runtime::BufferInfo* BufferInfos() { static const ::xla::cpu_function_runtime::BufferInfo kBufferInfos[kNumBuffers] = { ::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}), ::xla::cpu_function_runtime::BufferInfo({34ULL, 0ULL}), -::xla::cpu_function_runtime::BufferInfo({9ULL, ~0ULL}), +::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}), ::xla::cpu_function_runtime::BufferInfo({386ULL, 1ULL}), -::xla::cpu_function_runtime::BufferInfo({13ULL, ~0ULL}), +::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}), +::xla::cpu_function_runtime::BufferInfo({386ULL, 2ULL}), +::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}), +::xla::cpu_function_runtime::BufferInfo({386ULL, 3ULL}), +::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}), +::xla::cpu_function_runtime::BufferInfo({386ULL, 4ULL}), +::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}), ::xla::cpu_function_runtime::BufferInfo({481ULL, ~0ULL}) }; return kBufferInfos; @@ -312,13 +323,13 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { static const ::tensorflow::int32* ArgIndexToBufferIndex() { static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = { -1, 3 +1, 3, 5, 7, 9 }; return kArgIndexToBufferIndex; } // The 0-based index of the result tuple in the temporary buffers. - static constexpr size_t kResultIndex = 5; + static constexpr size_t kResultIndex = 11; // Array of names of each positional argument, terminated by nullptr. static const char** StaticArgNames() { @@ -326,6 +337,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { return kNames; } + // Array of names of each positional variable, terminated by nullptr. + static const char** StaticVariableNames() { + static const char* kNames[] = {"myvar_readonly", "myvar", "myvar2", nullptr}; + return kNames; + } + // Array of names of each positional result, terminated by nullptr. static const char** StaticResultNames() { static const char* kNames[] = {"myfetch", nullptr}; diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index de58c7f8a87..d6d012dcc71 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "llvm-c/Target.h" #include "tensorflow/compiler/aot/codegen.h" #include "tensorflow/compiler/aot/flags.h" +#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h" #include "tensorflow/compiler/tf2xla/tf2xla.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/client/client_library.h" @@ -39,6 +40,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/proto_serialization.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -105,14 +107,18 @@ Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config, .ValueOrDie(); xla::XlaComputation computation; if (flags.mlir_components == "Bridge") { - TF_RETURN_IF_ERROR( - ConvertGraphDefToXlaViaMlir(graph_def, config, &computation)); + TF_RETURN_IF_ERROR(ConvertGraphDefToXlaViaMlir( + graph_def, config, &computation, flags.debug_info, + flags.debug_info_path_begin_marker)); } else if (flags.mlir_components.empty() || flags.mlir_components == "None") { TF_RETURN_IF_ERROR(ConvertGraphDefToXla(std::move(graph_def), config, client, &computation)); } else { return errors::Unknown("Unknown mlir_components ", flags.mlir_components); } + if (flags.quantize) { + TF_RETURN_IF_ERROR(mlir::xla_hlo::XlaQuantize(config, &computation)); + } if (!flags.out_session_module.empty()) { TF_ASSIGN_OR_RETURN(std::unique_ptr module, computation.Snapshot()); @@ -166,6 +172,23 @@ static void InitializeTargets() { LLVMInitializeX86AsmPrinter(); } +// Replaces {{tag.type tag.name}} in the error message with tag_name. +// TODO(bixia): We currently only handlge tag.type == "node". +// +// In the error message, a graph node is represented as {{tag.type, tag.name}}, +// to allow a Python debugger to insert source information about the graph node. +// For example, a Python add expression may be represented as +// {{node, x_y_sum}} = Add(x, y) in the error message. See routine interpolate +// in tensorflow/python/framework/error_interpolation.py for more detail. +static std::string InterpolateErrorMessage(std::string message) { + // See _NAME_REGEX in tensorflow/python/framework/error_interpolation.py + // Change "prefix {{node tag.name}} suffix" to "prefix tag.name suffix". + static LazyRE2 pattern{"(.*){{node (.*)}}(.*)"}; + RE2::GlobalReplace(&message, *pattern, "\\1\\2\\3"); + + return message; +} + Status Main(const MainFlags& flags) { absl::call_once(targets_init, &InitializeTargets); @@ -192,8 +215,13 @@ Status Main(const MainFlags& flags) { GraphDef graph_def; TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def)); CompileResult compile_result; - TF_RETURN_IF_ERROR( - CompileGraph(std::move(graph_def), config, flags, &compile_result)); + + Status status = + CompileGraph(std::move(graph_def), config, flags, &compile_result); + if (!status.ok()) { + return Status(status.code(), + InterpolateErrorMessage(status.error_message())); + } // Write output files. Env* env = Env::Default(); diff --git a/tensorflow/compiler/aot/flags.cc b/tensorflow/compiler/aot/flags.cc index e7040d12b8b..e8168bf706e 100644 --- a/tensorflow/compiler/aot/flags.cc +++ b/tensorflow/compiler/aot/flags.cc @@ -24,6 +24,13 @@ void AppendMainFlags(std::vector* flag_list, MainFlags* flags) { "Input GraphDef file. 
If the file ends in '.pbtxt' it is expected to " "be in the human-readable proto text format, otherwise it is expected " "to be in the proto binary format."}, + {"debug_info", &flags->debug_info, + "Graph debug info file. If the file ends in '.pbtxt' it is expected to " + "be in the human-readable proto text format, otherwise it is expected " + "to be in the proto binary format."}, + {"debug_info_path_begin_marker", &flags->debug_info_path_begin_marker, + "If not none, only keep the file path in the debug information after the" + " marker. The default value is empty"}, {"config", &flags->config, "Input file containing Config proto. If the file ends in '.pbtxt' it " "is expected to be in the human-readable proto text format, otherwise " @@ -70,6 +77,8 @@ void AppendMainFlags(std::vector* flag_list, MainFlags* flags) { "Output session module proto."}, {"mlir_components", &flags->mlir_components, "The MLIR components to enable. Currently only Bridge is supported."}, + {"quantize", &flags->quantize, + "If set, quantization will be applied before HLO code generation."}, {"gen_name_to_index", &flags->gen_name_to_index, "Generate name-to-index data for Lookup{Arg,Result}Index methods."}, {"gen_program_shape", &flags->gen_program_shape, diff --git a/tensorflow/compiler/aot/flags.h b/tensorflow/compiler/aot/flags.h index 451a0455977..96395c7501b 100644 --- a/tensorflow/compiler/aot/flags.h +++ b/tensorflow/compiler/aot/flags.h @@ -28,6 +28,8 @@ namespace tfcompile { struct MainFlags { string graph; + string debug_info; + string debug_info_path_begin_marker; string config; bool dump_fetch_nodes = false; string target_triple; @@ -40,6 +42,7 @@ struct MainFlags { string out_header; string out_session_module; string mlir_components; + bool quantize = false; // C++ codegen options bool gen_name_to_index = false; diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index 6f438f0e271..0c44ed8bf37 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -1,11 +1,37 @@ load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( default_visibility = ["//visibility:private"], licenses = ["notice"], # Apache 2.0 ) +glob_lit_tests( + data = [":filecheck_test_utilities"], + driver = "@llvm-project//mlir:run_lit.sh", + tags_override = { + "test_error_message.lit.pbtxt": ["no_oss"], # TODO(b/150957738): to be fixed on oss. + }, + test_file_exts = ["lit.pbtxt"], +) + +# Bundle together all of the test utilities that are used by tests. +filegroup( + name = "filecheck_test_utilities", + testonly = True, + srcs = [ + "test_error_message.lit.pbtxt.config.pbtxt", + "test_error_message.lit.pbtxt.debug.pbtxt", + "test_error_message.lit.pbtxt.fake_py.debug", + ], + data = [ + "//tensorflow/compiler/aot:tfcompile", + "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:not", + ], +) + # We disable some tfcompile tests in the open source build with the # "manual" tag to avoid making our OSS users build LLVM twice # (once for host and once for target). 
@@ -60,6 +86,7 @@ genrule( testonly = 1, outs = [ "test_graph_tfadd.pb", + "test_debuginfo_tfadd.pb", "test_graph_tfadd_with_ckpt.ckpt", "test_graph_tfadd_with_ckpt.pb", "test_graph_tfadd_with_ckpt_saver.ckpt", @@ -317,6 +344,7 @@ tf_library( testonly = 1, config = "test_graph_tfadd.config.pbtxt", cpp_class = "AddComp", + debug_info = "test_debuginfo_tfadd.pb", graph = "test_graph_tfadd.pb", include_standard_runtime_deps = False, mlir_components = "Bridge", diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py index a96ba0e6919..629239d6e4a 100644 --- a/tensorflow/compiler/aot/tests/make_test_graphs.py +++ b/tensorflow/compiler/aot/tests/make_test_graphs.py @@ -30,6 +30,7 @@ from tensorflow.core.protobuf import saver_pb2 from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import error_interpolation from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -184,7 +185,22 @@ def tfvariable_sequential_updates(_): array_ops.identity(updates, name='result') -def write_graph(build_graph, out_dir): +def export_debug_info(exported_graph): + """Exports debug information from a graph. + + Args: + exported_graph: A Graph that has been created by tracing a saveable view. + + Returns: + Corresponding GraphDebugInfo with traces for all ops in exported_graph. + """ + exported_operations = [] + for op in exported_graph.get_operations(): + exported_operations.append(('', op)) + return error_interpolation.create_graph_debug_info_def(exported_operations) + + +def write_graph(build_graph, out_dir, debug_info=False): """Build a graph using build_graph and write it out.""" g = ops.Graph() with g.as_default(): @@ -193,10 +209,19 @@ def write_graph(build_graph, out_dir): with open(filename, 'wb') as f: f.write(six.ensure_binary(g.as_graph_def().SerializeToString())) + if debug_info: + filename_debuginfo = os.path.join( + out_dir, 'test_debuginfo_%s.pb' % build_graph.__name__) + test_debuginfo = export_debug_info(g) + with open(filename_debuginfo, 'wb') as f: + f.write( + six.ensure_binary( + test_debuginfo.SerializeToString(deterministic=True))) + def main(_): control_flow_util.enable_control_flow_v2() - write_graph(tfadd, FLAGS.out_dir) + write_graph(tfadd, FLAGS.out_dir, debug_info=True) write_graph(tfadd_with_ckpt, FLAGS.out_dir) write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir) write_graph(tfassert_eq, FLAGS.out_dir) diff --git a/tensorflow/compiler/aot/tests/test_error_message.lit.pbtxt b/tensorflow/compiler/aot/tests/test_error_message.lit.pbtxt new file mode 100644 index 00000000000..5b05eb4b33d --- /dev/null +++ b/tensorflow/compiler/aot/tests/test_error_message.lit.pbtxt @@ -0,0 +1,69 @@ +# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=Bridge --debug_info=%s.debug.pbtxt 2>&1 | FileCheck %s -dump-input-on-failure +# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=None 2>&1 | FileCheck -check-prefix=OLD %s -dump-input-on-failure + +# Checks the error message produced by tfcompile with mlir_component +# Checks that source debug information is used in the output error message and +# the node x_y_sum = Add +# CHECK: INVALID ARGUMENTS: Dimensions must be equal, but are 2 and 3 for 'x_y_sum = Add[T=DT_INT32](aot_feed_0/x, aot_feed_0/y)' +# CHECK: math_ops.add(x, y, name='x_y_sum') +# CHECK: 
build_graph(out_dir) + +# Checks the error message produced by tfcompile without mlir_component +# OLD: INVALID ARGUMENTS: Incompatible shapes: [2] vs. [3] +# OLD: x_y_sum + +node: { + name: "x" + op: "Placeholder" + attr: { + key: "shape" + value: { + shape: { + dim: { + size: -1 + } + } + } + } + attr: { + key: "dtype" + value: { + type: DT_INT32 + } + } +} +node: { + name: "y" + op: "Placeholder" + attr: { + key: "shape" + value: { + shape: { + dim: { + size: -1 + } + } + } + } + attr: { + key: "dtype" + value: { + type: DT_INT32 + } + } +} +node: { + name: "x_y_sum" + op: "Add" + input: "x" + input: "y" + attr: { + key: "T" + value: { + type: DT_INT32 + } + } +} +versions: { + producer: 321 +} diff --git a/tensorflow/compiler/aot/tests/test_error_message.lit.pbtxt.config.pbtxt b/tensorflow/compiler/aot/tests/test_error_message.lit.pbtxt.config.pbtxt new file mode 100644 index 00000000000..2694e67da06 --- /dev/null +++ b/tensorflow/compiler/aot/tests/test_error_message.lit.pbtxt.config.pbtxt @@ -0,0 +1,16 @@ +# Text form of tensorflow.tf2xla.Config proto. +feed { + id { node_name: "x" } + shape { + dim { size: 2 } + } +} +feed { + id { node_name: "y" } + shape { + dim { size: 3 } + } +} +fetch { + id { node_name: "x_y_sum" } +} diff --git a/tensorflow/compiler/aot/tests/test_error_message.lit.pbtxt.debug.pbtxt b/tensorflow/compiler/aot/tests/test_error_message.lit.pbtxt.debug.pbtxt new file mode 100644 index 00000000000..7acc8287950 --- /dev/null +++ b/tensorflow/compiler/aot/tests/test_error_message.lit.pbtxt.debug.pbtxt @@ -0,0 +1,28 @@ +files: "org_tensorflow/tensorflow/compiler/aot/tests/test_error_message.lit.pbtxt.fake_py.debug" +traces: { + key: "x@" + value: { + file_line_cols: { + line: 1 + } + } +} +traces: { + key: "x_y_sum@" + value: { + file_line_cols: { + line: 3 + } + file_line_cols: { + line: 4 + } + } +} +traces: { + key: "y@" + value: { + file_line_cols: { + line: 2 + } + } +} diff --git a/tensorflow/compiler/aot/tests/test_error_message.lit.pbtxt.fake_py.debug b/tensorflow/compiler/aot/tests/test_error_message.lit.pbtxt.fake_py.debug new file mode 100644 index 00000000000..083e8d522d5 --- /dev/null +++ b/tensorflow/compiler/aot/tests/test_error_message.lit.pbtxt.fake_py.debug @@ -0,0 +1,4 @@ + x = value + y = value + math_ops.add(x, y, name='x_y_sum') + build_graph(out_dir) diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 80606b6c5ee..35a054a1aab 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -26,6 +26,7 @@ def tf_library( name, graph, config, + debug_info = None, freeze_checkpoint = None, freeze_saver = None, cpp_class = None, @@ -191,12 +192,15 @@ def tf_library( mlir_flag = "--mlir_components=" + mlir_components + srcs = [tfcompile_graph, config] + debug_info_flag = "" + if debug_info: + srcs.append(debug_info) + debug_info_flag = " --debug_info=$(location " + debug_info + ")" + native.genrule( name = ("gen_" + name), - srcs = [ - tfcompile_graph, - config, - ], + srcs = srcs, outs = [ header_file, metadata_object_file, @@ -206,6 +210,7 @@ def tf_library( "CUDA_VISIBLE_DEVICES='' " + "$(location " + tfcompile_tool + ")" + " --graph=$(location " + tfcompile_graph + ")" + + debug_info_flag + " --config=$(location " + config + ")" + " --entry_point=" + ep + " --cpp_class=" + cpp_class + @@ -237,10 +242,7 @@ def tf_library( session_module_pb = name + "_session_module.pb" native.genrule( name = (name + "_session_module"), - srcs = [ - tfcompile_graph, - config, - ], + srcs 
= srcs, outs = [ session_module_pb, ], @@ -248,6 +250,7 @@ def tf_library( "CUDA_VISIBLE_DEVICES='' " + "$(location " + tfcompile_tool + ")" + " --graph=$(location " + tfcompile_graph + ")" + + debug_info_flag + " --config=$(location " + config + ")" + " --entry_point=" + ep + " --cpp_class=" + cpp_class + diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index d027bae5d04..f0cf8f2ded9 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -65,6 +65,7 @@ int main(int argc, char** argv) { flags.out_metadata_object = "out_helper.o"; flags.out_header = "out.h"; flags.entry_point = "entry"; + flags.debug_info_path_begin_marker = ""; std::vector flag_list; AppendMainFlags(&flag_list, &flags); @@ -81,12 +82,10 @@ int main(int argc, char** argv) { tensorflow::port::InitMain(usage.c_str(), &argc, &argv); QCHECK(argc == 1) << "\nERROR: This command does not take any arguments " - "other than flags\n\n" - << usage; + "other than flags. See --help.\n\n"; tensorflow::Status status = tensorflow::tfcompile::Main(flags); if (status.code() == tensorflow::error::INVALID_ARGUMENT) { - std::cerr << "INVALID ARGUMENTS: " << status.error_message() << "\n\n" - << usage; + std::cerr << "INVALID ARGUMENTS: " << status.error_message() << "\n\n"; return 1; } else { TF_QCHECK_OK(status); diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index acbd2d27a45..f71331af0df 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -184,6 +184,7 @@ XLA_DEVICE_DEPS = [ "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", "//tensorflow/core:functional_ops_op_lib", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", diff --git a/tensorflow/compiler/jit/graphcycles/BUILD b/tensorflow/compiler/jit/graphcycles/BUILD index 618fafe759b..61d0c0de35f 100644 --- a/tensorflow/compiler/jit/graphcycles/BUILD +++ b/tensorflow/compiler/jit/graphcycles/BUILD @@ -18,6 +18,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles.cc b/tensorflow/compiler/jit/graphcycles/graphcycles.cc index 6ec9b5a477a..6c5e3a745e2 100644 --- a/tensorflow/compiler/jit/graphcycles/graphcycles.cc +++ b/tensorflow/compiler/jit/graphcycles/graphcycles.cc @@ -368,14 +368,20 @@ bool GraphCycles::CanContractEdge(int32 a, int32 b) { return !reachable; } -bool GraphCycles::ContractEdge(int32 a, int32 b) { +absl::optional GraphCycles::ContractEdge(int32 a, int32 b) { CHECK(HasEdge(a, b)); RemoveEdge(a, b); if (IsReachableNonConst(a, b)) { // Restore the graph to its original state. InsertEdge(a, b); - return false; + return absl::nullopt; + } + + if (rep_->nodes_[b]->in.Size() + rep_->nodes_[b]->out.Size() > + rep_->nodes_[a]->in.Size() + rep_->nodes_[a]->out.Size()) { + // Swap "a" and "b" to minimize copying. + std::swap(a, b); } Node* nb = rep_->nodes_[b]; @@ -399,7 +405,8 @@ bool GraphCycles::ContractEdge(int32 a, int32 b) { InsertEdge(y, a); } - return true; + // Note, if the swap happened it might be what originally was called "b". 
+ return a; } absl::Span GraphCycles::Successors(int32 node) const { diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles.h b/tensorflow/compiler/jit/graphcycles/graphcycles.h index bbf61016fb3..3e20c4e641c 100644 --- a/tensorflow/compiler/jit/graphcycles/graphcycles.h +++ b/tensorflow/compiler/jit/graphcycles/graphcycles.h @@ -40,6 +40,7 @@ limitations under the License. // FindPath() is linear in the size of the graph. // The current implementation uses O(|V|+|E|) space. +#include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -80,11 +81,11 @@ class GraphCycles { // Return whether there is an edge directly from source_node to dest_node. bool HasEdge(int32 source_node, int32 dest_node) const; - // Contracts the edge from 'a' to node 'b', merging nodes 'a' and 'b'. 'b' is - // removed from the graph, and edges to/from 'b' are replaced with edges - // to/from 'a'. If contracting the edge would create a cycle, does nothing - // and returns false. - bool ContractEdge(int32 a, int32 b); + // Contracts the edge from 'a' to node 'b', merging nodes 'a' and 'b'. One of + // the nodes is removed from the graph, and edges to/from it are added to + // the remaining one, which is returned. If contracting the edge would create + // a cycle, does nothing and return no value. + absl::optional ContractEdge(int32 a, int32 b); // Return true if can contract edge, otherwise return false. bool CanContractEdge(int32 a, int32 b); diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc b/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc index 274f5938a12..5b7eec19e27 100644 --- a/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc +++ b/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" +#include #include #include #include @@ -479,19 +480,21 @@ TEST_F(GraphCyclesTest, ContractEdge) { ASSERT_TRUE(AddEdge(2, 4)); ASSERT_TRUE(AddEdge(3, 4)); - EXPECT_FALSE(g_.ContractEdge(1, 3)); + EXPECT_FALSE(g_.ContractEdge(1, 3).has_value()); CHECK(g_.CheckInvariants()); EXPECT_TRUE(g_.HasEdge(1, 3)); - EXPECT_TRUE(g_.ContractEdge(1, 2)); + // Node (2) has more edges. + EXPECT_EQ(g_.ContractEdge(1, 2).value(), 2); CHECK(g_.CheckInvariants()); - EXPECT_TRUE(g_.HasEdge(1, 3)); - EXPECT_TRUE(g_.HasEdge(1, 4)); + EXPECT_TRUE(g_.HasEdge(2, 3)); + EXPECT_TRUE(g_.HasEdge(2, 4)); EXPECT_TRUE(g_.HasEdge(3, 4)); - EXPECT_TRUE(g_.ContractEdge(1, 3)); + // Node (2) has more edges. + EXPECT_EQ(g_.ContractEdge(2, 3).value(), 2); CHECK(g_.CheckInvariants()); - EXPECT_TRUE(g_.HasEdge(1, 4)); + EXPECT_TRUE(g_.HasEdge(2, 4)); } TEST_F(GraphCyclesTest, CanContractEdge) { @@ -527,3 +530,26 @@ static void BM_StressTest(int iters, int num_nodes) { } } BENCHMARK(BM_StressTest)->Range(2048, 1048576); + +static void BM_ContractEdge(int iters, int num_nodes) { + while (iters-- > 0) { + tensorflow::testing::StopTiming(); + tensorflow::GraphCycles g; + std::vector nodes; + nodes.reserve(num_nodes); + for (int i = 0; i < num_nodes; i++) { + nodes.push_back(g.NewNode()); + } + // All edges point toward the last one. 
+ for (int i = 0; i < num_nodes - 1; ++i) { + g.InsertEdge(nodes[i], nodes[num_nodes - 1]); + } + + tensorflow::testing::StartTiming(); + int node = num_nodes - 1; + for (int i = 0; i < num_nodes - 1; ++i) { + node = g.ContractEdge(nodes[i], node).value(); + } + } +} +BENCHMARK(BM_ContractEdge)->Arg(1000)->Arg(10000); diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 2b58a9260ba..c64f4d32535 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -172,8 +172,9 @@ class XlaExecutableClosureStore { private: mutex mutex_; - int64 key_counter_ GUARDED_BY(mutex_); - absl::flat_hash_map closures_ GUARDED_BY(mutex_); + int64 key_counter_ TF_GUARDED_BY(mutex_); + absl::flat_hash_map closures_ + TF_GUARDED_BY(mutex_); TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore); }; diff --git a/tensorflow/compiler/jit/kernels/xla_ops.h b/tensorflow/compiler/jit/kernels/xla_ops.h index 836cb7e6862..112408226a8 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.h +++ b/tensorflow/compiler/jit/kernels/xla_ops.h @@ -165,7 +165,8 @@ class XlaCompileOp : public OpKernel { // error when compiling the cluster this _XlaCompile is supposed to compile. // If `cannot_compile_cluster_` is true then we avoid compiling this cluster // on any future calls to _XlaCompile. - bool cannot_compile_cluster_ GUARDED_BY(cannot_compile_cluster_mu_) = false; + bool cannot_compile_cluster_ TF_GUARDED_BY(cannot_compile_cluster_mu_) = + false; mutex cannot_compile_cluster_mu_; }; diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 08dc1b13db6..2a29527bfef 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -161,6 +161,11 @@ class MarkForCompilationPassImpl { // The ID of the cluster as represented in `cycles_graph_`. int cycles_graph_node_id() const { return cycles_graph_node_id_; } + // Sets the ID of the cluster as represented in `cycles_graph_`. + void set_cycles_graph_node_id(int cycles_graph_node_id) { + cycles_graph_node_id_ = cycles_graph_node_id; + } + // The size of the cluster excluding constant and identity nodes. int effective_cluster_size() const { return effective_cluster_size_; } @@ -381,14 +386,16 @@ class MarkForCompilationPassImpl { // R, B} cluster. string DescribePotentialCycle(int from, int to); - // Merge the clusters `cluster_from` and `cluster_to`. After this step the - // larger combined cluster is represented by `cluster_from`'s ID in - // `cycles_graph_`. + // Merge the clusters `cluster_from` and `cluster_to`. After this step the + // larger combined cluster is represented by `cluster_from`, but it can have + // `cycles_graph_`'s ID of either `cluster_from` or `cluster_to`, depending on + // which way requires fewer operations. bool MergeClusters(Cluster* cluster_from, Cluster* cluster_to) { int from = cluster_from->cycles_graph_node_id(); int to = cluster_to->cycles_graph_node_id(); - if (!cycles_graph_.ContractEdge(from, to)) { + auto optional_merged_node = cycles_graph_.ContractEdge(from, to); + if (!optional_merged_node.has_value()) { VLOG(3) << "Could not contract " << cluster_from->DebugString(*graph_) << " -> " << cluster_to->DebugString(*graph_) << " because contracting the edge would create a cycle via " @@ -398,6 +405,8 @@ // Merge the clusters.
cluster_from->Merge(cluster_to); + // Update `cycle_graph_`'s ID. + cluster_from->set_cycles_graph_node_id(optional_merged_node.value()); // Merge the UnionFind. cluster_for_node_[from].Merge(&cluster_for_node_[to]); @@ -1911,6 +1920,7 @@ absl::flat_hash_set GetKnownXLAWhitelistOp() { "LinSpace", "ListDiff", "LogMatrixDeterminant", + "LowerBound", "MatMul", "MatrixBandPart", "MatrixDiag", @@ -2037,6 +2047,7 @@ absl::flat_hash_set GetKnownXLAWhitelistOp() { "TensorScatterUpdate", "TridiagonalSolve", "TruncatedNormal", + "UpperBound", "UnsortedSegmentMax", "UnsortedSegmentMin", "UnsortedSegmentProd", diff --git a/tensorflow/compiler/jit/xla_activity_listener.cc b/tensorflow/compiler/jit/xla_activity_listener.cc index a1ea6a6bf8e..ae28bf10fb2 100644 --- a/tensorflow/compiler/jit/xla_activity_listener.cc +++ b/tensorflow/compiler/jit/xla_activity_listener.cc @@ -18,13 +18,15 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "tensorflow/compiler/jit/xla_activity.pb.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/thread_annotations.h" namespace tensorflow { namespace { // The list of all registered `XlaActivityListener`s. struct XlaActivityListenerList { absl::Mutex mutex; - std::vector> listeners GUARDED_BY(mutex); + std::vector> listeners + TF_GUARDED_BY(mutex); }; void FlushAllListeners(); diff --git a/tensorflow/compiler/jit/xla_cluster_util_test.cc b/tensorflow/compiler/jit/xla_cluster_util_test.cc index acac2f7d055..6333499b0c8 100644 --- a/tensorflow/compiler/jit/xla_cluster_util_test.cc +++ b/tensorflow/compiler/jit/xla_cluster_util_test.cc @@ -50,7 +50,7 @@ TEST(CreateCycleDetectionGraph, ConnectivityThroughEnterExitRegion) { GraphCycles cycles; TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles).status()); - EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id())); + EXPECT_FALSE(cycles.CanContractEdge(a.node()->id(), b.node()->id())); } TEST(CreateCycleDetectionGraph, ConnectivityThroughMultipleEnterExitRegions) { @@ -69,7 +69,7 @@ TEST(CreateCycleDetectionGraph, ConnectivityThroughMultipleEnterExitRegions) { GraphCycles cycles; TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles).status()); - EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id())); + EXPECT_FALSE(cycles.CanContractEdge(a.node()->id(), b.node()->id())); } TEST(CreateCycleDetectionGraph, ReachingEnterExit) { diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 03a9a3ad3a4..5540fee7276 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_activity_listener.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/util.h" @@ -33,6 +34,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/metrics.h" #include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/hash/hash.h" @@ -202,6 +204,52 @@ static bool ShouldBeMegamorphic(int64 compile_count, int64 execution_count) { execution_count < kMinExecutionsPerCompile * compile_count; } +// Creates a simple graph using the specified op as the only op apart from the +// arg and retval nodes. +static xla::StatusOr> CreateGraph( + const NodeDef& node_def, absl::Span args, + absl::Span result_types) { + // TODO(b/74182462): We implement this by creating a new dummy Graph including + // _Arg nodes, and let CompileGraph walk it. This could be optimized. + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + Status status; + // First create the actual node we care about computing. + Node* main_node = graph->AddNode(node_def, &status); + TF_RETURN_IF_ERROR(status); + + // Create dummy _Arg nodes. Link these to `node` and also via a control + // dependency edge to the _SOURCE node. + for (int64 i = 0; i < args.size(); ++i) { + Node* node; + string arg_name = absl::StrCat("_arg", i); + Status status = + NodeBuilder(arg_name, FunctionLibraryDefinition::kArgOp) + .ControlInput(graph->source_node()) + .Attr("T", args[i].kind == XlaCompiler::Argument::kResource + ? DT_RESOURCE + : args[i].type) + .Attr("index", i) + .Finalize(graph.get(), &node); + TF_RETURN_IF_ERROR(status); + graph->AddEdge(node, 0, main_node, i); + } + + // Similarly with return values, create dummy _Retval nodes fed by `node`. + for (int64 i = 0; i < result_types.size(); ++i) { + Node* node; + string retval_name = absl::StrCat("_retval", i); + Status status = NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp) + .Input(main_node, i) + .Attr("T", result_types[i]) + .Attr("index", i) + .Finalize(graph.get(), &node); + TF_RETURN_IF_ERROR(status); + } + FixupSourceAndSinkEdges(graph.get()); + return graph; +} + Status XlaCompilationCache::CompileSingleOp( const XlaCompiler::Options& options, absl::Span args, OpKernelContext* ctx, @@ -222,8 +270,11 @@ Status XlaCompilationCache::CompileSingleOp( for (int i = 0; i < result_dtypes.size(); ++i) { result_dtypes[i] = ctx->expected_output_dtype(i); } - return compiler->CompileSingleOp(compile_options, ctx->op_kernel().def(), - args, result_dtypes, result); + + const NodeDef& node_def = ctx->op_kernel().def(); + TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes)); + return compiler->CompileGraph(compile_options, node_def.name(), + std::move(graph), args, result); }; return CompileImpl(options, name, args, compile_op, /*compile_threshold=*/absl::nullopt, diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index b3653a2006a..83a0bda97d5 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -151,19 +151,19 @@ class XlaCompilationCache : public ResourceBase { int64 request_count = 0; // Did compilation succeed? - Status compilation_status GUARDED_BY(mu); + Status compilation_status TF_GUARDED_BY(mu); // Output of the XlaCompiler. - XlaCompiler::CompilationResult compilation_result GUARDED_BY(mu); + XlaCompiler::CompilationResult compilation_result TF_GUARDED_BY(mu); // The XLA executable compiled from . 
May be null if no // executable has been built. - std::unique_ptr executable GUARDED_BY(mu); + std::unique_ptr executable TF_GUARDED_BY(mu); }; mutex compile_cache_mu_; absl::flat_hash_map, Signature::Hash> cache_ - GUARDED_BY(compile_cache_mu_); + TF_GUARDED_BY(compile_cache_mu_); struct ClusterCompileStats { // Number of times the cluster has been (re-)compiled. @@ -185,7 +185,7 @@ class XlaCompilationCache : public ResourceBase { // Maps cluster names to compilation statistics for said cluster. absl::flat_hash_map cluster_compile_stats_ - GUARDED_BY(cluster_compile_stats_mu_); + TF_GUARDED_BY(cluster_compile_stats_mu_); // The number of times a lazy compilation must be requested for a specific // signature before we attempt to compile it. diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 830aaf74186..0cc462678b1 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -83,7 +83,7 @@ class XlaDeviceAllocatorState { std::unordered_map, std::unique_ptr, hash>> - allocators_ GUARDED_BY(allocator_mutex_); + allocators_ TF_GUARDED_BY(allocator_mutex_); TF_DISALLOW_COPY_AND_ASSIGN(XlaDeviceAllocatorState); }; diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index 546df476d7f..30f9a99e36a 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -137,7 +137,7 @@ class XlaDevice : public LocalDevice { ~XlaDevice() override; Allocator* GetAllocator(AllocatorAttributes attr) override - LOCKS_EXCLUDED(mu_); + TF_LOCKS_EXCLUDED(mu_); void Compute(OpKernel* op_kernel, OpKernelContext* context) override; void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, AsyncOpKernel::DoneCallback done) override; @@ -145,18 +145,18 @@ class XlaDevice : public LocalDevice { void Sync(const DoneCallback& done) override; Status TryGetDeviceContext(DeviceContext** out_context) override - LOCKS_EXCLUDED(mu_); + TF_LOCKS_EXCLUDED(mu_); Status MakeTensorFromProto(const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, - Tensor* tensor) override LOCKS_EXCLUDED(mu_); + Tensor* tensor) override TF_LOCKS_EXCLUDED(mu_); // Allocate tensor on fast memory space. This is only applied to the new TPU // hardware which has faster read/write memory. If the hardware doesn't // have such memory space, we fallback to the ordinary memory space. Status MakeFastMemTensorFromProto(const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, - Tensor* tensor) LOCKS_EXCLUDED(mu_); + Tensor* tensor) TF_LOCKS_EXCLUDED(mu_); const Metadata& metadata() { return xla_metadata_; } @@ -166,34 +166,35 @@ class XlaDevice : public LocalDevice { // // TODO(b/111859745): The Eager context needs to call this method to recover // from failures. - Status EnsureDeviceContextOk() LOCKS_EXCLUDED(mu_); + Status EnsureDeviceContextOk() TF_LOCKS_EXCLUDED(mu_); // Instructs this XlaDevice to set a GpuDeviceInfo, which holds extra // information for GPU and TPU devices. - Status UseGpuDeviceInfo() LOCKS_EXCLUDED(mu_); + Status UseGpuDeviceInfo() TF_LOCKS_EXCLUDED(mu_); // Instructs this XlaDevice to return 'sync_on_completion' for // AllowsSyncOnCompletion(). 
- void SetAllowsSyncOnCompletion(bool sync_on_completion) LOCKS_EXCLUDED(mu_); - bool AllowsSyncOnCompletion() const override LOCKS_EXCLUDED(mu_); + void SetAllowsSyncOnCompletion(bool sync_on_completion) + TF_LOCKS_EXCLUDED(mu_); + bool AllowsSyncOnCompletion() const override TF_LOCKS_EXCLUDED(mu_); // Installs an error handling callback when RefreshStatus sees !status.ok(). void SetHandleDeviceErrorCallback(std::function callback); - Status RefreshStatus() override LOCKS_EXCLUDED(mu_); + Status RefreshStatus() override TF_LOCKS_EXCLUDED(mu_); private: xla::StatusOr GetOrCreateClient() const; Allocator* GetAllocatorLocked(AllocatorAttributes attr) - EXCLUSIVE_LOCKS_REQUIRED(mu_); + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); Status EnsureStreamOkLocked(xla::Backend* backend, const string& name, std::shared_ptr* stream, bool* stream_was_changed) - EXCLUSIVE_LOCKS_REQUIRED(mu_); + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Return a pair of device context, the second one is fast_mem device context. xla::StatusOr> - GetDeviceContextLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_); + GetDeviceContextLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); static Status GetMetadataFromDevice(DeviceBase* device, const XlaDevice::Metadata** metadata); @@ -218,13 +219,13 @@ class XlaDevice : public LocalDevice { // Intra-op threads to spawn (from SessionOptions). const int intra_op_parallelism_threads_; // Memory allocator associated with this device. - Allocator* xla_allocator_ GUARDED_BY(mu_) = nullptr; // Not owned. + Allocator* xla_allocator_ TF_GUARDED_BY(mu_) = nullptr; // Not owned. // Stream associated with this device. Operations enqueued on this // stream are executed on the device. Operations include data // copying back and forth between CPU and the device, and // computations enqueued by XLA. - std::shared_ptr stream_ GUARDED_BY(mu_); + std::shared_ptr stream_ TF_GUARDED_BY(mu_); // If false, only stream_ is valid and all computation and transfers use // stream_. If true, computation is performed by stream_ and transfers are // performed by host_to_device/device_to_device stream or borrowing a stream @@ -232,36 +233,36 @@ class XlaDevice : public LocalDevice { const bool use_multiple_streams_; // If use_multiple_streams_, host to device transfers are performed using this // stream. - std::shared_ptr host_to_device_stream_ GUARDED_BY(mu_); + std::shared_ptr host_to_device_stream_ TF_GUARDED_BY(mu_); // If use_multiple_streams_, transfers between different devices are performed // using these streams. std::vector> device_to_device_streams_ - GUARDED_BY(mu_); + TF_GUARDED_BY(mu_); const XlaCompiler::ShapeRepresentationFn shape_representation_fn_; // The device context accessed by all users of the XlaDevice, set by calls to // EnsureDeviceContextOk. If gpu_device_info_ is non-null, this pointer is // also filled in to that struct. XlaDeviceContext is a ref-counted object. - XlaDeviceContext* device_context_ GUARDED_BY(mu_) = nullptr; + XlaDeviceContext* device_context_ TF_GUARDED_BY(mu_) = nullptr; // The device context will allocate memory on fast memory space on TPU. // XlaDeviceContext is a ref-counted object. - XlaDeviceContext* fast_mem_device_context_ GUARDED_BY(mu_) = nullptr; + XlaDeviceContext* fast_mem_device_context_ TF_GUARDED_BY(mu_) = nullptr; // Holds extra information for GPU and TPU devices, e.g. the device context. 
- bool use_gpu_device_info_ GUARDED_BY(mu_) = false; - std::unique_ptr gpu_device_info_ GUARDED_BY(mu_); + bool use_gpu_device_info_ TF_GUARDED_BY(mu_) = false; + std::unique_ptr gpu_device_info_ TF_GUARDED_BY(mu_); // Thread pool used for running closures std::unique_ptr thread_pool_; // True if the device allows XlaDevice::Sync to be called on completion // regardless of status. - bool sync_on_completion_ GUARDED_BY(mu_) = true; + bool sync_on_completion_ TF_GUARDED_BY(mu_) = true; // A callback that will be invoked when RefreshStatus sees a status error. - std::function device_error_callback_ GUARDED_BY(mu_); + std::function device_error_callback_ TF_GUARDED_BY(mu_); // Set of devices to use. This controls which of the devices on the given // platform will have resources allocated. For GPUs this will be diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 6871f7ec614..e8df09c7b4d 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/framework/tensor_reference.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/stream_executor/platform/port.h" diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index b90a3ad2e16..05d8dfa7556 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -117,7 +117,7 @@ class XlaDeviceContext : public DeviceContext { bool use_fast_mem_; absl::Mutex mu_; - int next_stream_ GUARDED_BY(mu_) = 0; + int next_stream_ TF_GUARDED_BY(mu_) = 0; }; } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 81d63d299ee..511e0f1451a 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -18,7 +18,6 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_ #define TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_ -#include "absl/base/thread_annotations.h" #include "tensorflow/compiler/jit/xla_compilation_cache.h" #include "tensorflow/compiler/jit/xla_tensor.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -30,6 +29,7 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/stream_executor/device_memory_allocator.h" namespace tensorflow { @@ -102,7 +102,7 @@ class VariableInfo { // `variables` is allowed to contain instances that don't track a resource // variable (i.e. variables[i].var() can be null for some i). Status LockVariables(absl::Span variables) - EXCLUSIVE_LOCK_FUNCTION(); + TF_EXCLUSIVE_LOCK_FUNCTION(); // Helper class to perform the marshalling of TensorFlow inputs and outputs to // ShapedBuffers suitable for passing to an XLA computation. 
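The changes in these JIT headers (and in xla_tensor.h below) consistently migrate the bare Clang thread-safety annotations (GUARDED_BY, LOCKS_EXCLUDED, EXCLUSIVE_LOCKS_REQUIRED, EXCLUSIVE_LOCK_FUNCTION) to the TF_-prefixed macros from tensorflow/core/platform/thread_annotations.h. The short sketch below is illustrative only and not part of this change; the SharedCounter class is hypothetical, but it shows how the three most common macros are typically combined with tensorflow::mutex.

#include <cstdint>

#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"

// Hypothetical example class; not part of this change.
class SharedCounter {
 public:
  // Callers must not already hold mu_; the lock is acquired here.
  void Increment() TF_LOCKS_EXCLUDED(mu_) {
    tensorflow::mutex_lock lock(mu_);
    IncrementLocked();
  }

 private:
  // Callers must hold mu_ before calling.
  void IncrementLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { ++value_; }

  tensorflow::mutex mu_;
  int64_t value_ TF_GUARDED_BY(mu_) = 0;  // Reads and writes require mu_.
};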
diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h index 8a4eb7493be..7f7d97e3b3f 100644 --- a/tensorflow/compiler/jit/xla_tensor.h +++ b/tensorflow/compiler/jit/xla_tensor.h @@ -122,7 +122,7 @@ class XlaTensor { std::shared_ptr definition_event_; // A list of all streams for which the tensor's content is defined for any // newly enqueued command. - absl::InlinedVector streams_defined_on_ GUARDED_BY(mu_); + absl::InlinedVector streams_defined_on_ TF_GUARDED_BY(mu_); mutex mu_; }; diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index 90c60a85ba2..2ed1c274f75 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -74,12 +74,15 @@ cc_library( "//tensorflow/compiler/mlir/xla:hlo", "//tensorflow/compiler/mlir/xla:hlo_legalize_to_lhlo", "//tensorflow/compiler/mlir/xla:lhlo", + "//tensorflow/compiler/mlir/xla:lhlo_copy_removal", "//tensorflow/compiler/mlir/xla:lhlo_fuse_linalg", "//tensorflow/compiler/mlir/xla:lhlo_legalize_to_affine", "//tensorflow/compiler/mlir/xla:lhlo_legalize_to_gpu", + "//tensorflow/compiler/mlir/xla:lhlo_legalize_to_parallel_loops", "//tensorflow/compiler/mlir/xla:xla_dialect_registration", "//tensorflow/compiler/mlir/xla:xla_legalize_control_flow", "//tensorflow/compiler/mlir/xla:xla_legalize_tf", + "//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla", "//tensorflow/compiler/mlir/xla:xla_legalize_to_linalg", "//tensorflow/compiler/mlir/xla:xla_legalize_to_standard", "//tensorflow/compiler/mlir/xla:xla_lower", @@ -100,6 +103,38 @@ cc_library( ], ) +cc_library( + name = "mlir_graph_optimization_pass", + srcs = ["mlir_graph_optimization_pass.cc"], + hdrs = ["mlir_graph_optimization_pass.h"], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", + "//tensorflow/compiler/mlir/tensorflow:device_util", + "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", + "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", + "//tensorflow/core:core_cpu", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + ], + alwayslink = 1, +) + +cc_library( + name = "mlir_graph_optimization_pass_registration", + srcs = [ + "mlir_graph_optimization_pass_registration.cc", + ], + deps = [ + ":mlir_graph_optimization_pass", + "//tensorflow/core:core_cpu", + ], + alwayslink = 1, +) + tf_cc_binary( name = "tf-opt", deps = [ diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 5e6348e3ac0..c917af71f92 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -30,7 +30,8 @@ filegroup( "ir/tfl_ops.td", "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files", "@llvm-project//mlir:OpBaseTdFiles", - "@llvm-project//mlir:include/mlir/Transforms/LoopLikeInterface.td", + "@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td", + "@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td", ], ) @@ -221,19 +222,21 @@ cc_library( deps = [ ":tensorflow_lite_ops_inc_gen", ":validators", - "@llvm-project//llvm:support", - "@llvm-project//mlir:Analysis", - "@llvm-project//mlir:Dialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:QuantOps", - "@llvm-project//mlir:StandardOps", - "@llvm-project//mlir:Support", - # TODO(jpienaar): Move this out after splitting out 
LoopLikeOpInterface. - "@llvm-project//mlir:Transforms", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/lite/schema:schema_fbs", + "@llvm-project//llvm:support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:DerivedAttributeOpInterface", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LoopLikeInterface", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:SideEffects", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", ], alwayslink = 1, ) @@ -325,8 +328,8 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:convert_tensor", "//tensorflow/compiler/mlir/tensorflow:mangling_util", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/compiler/mlir/tensorflow:unroll_batch_matmul_pass", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:statusor", "//tensorflow/core:framework", @@ -436,7 +439,7 @@ genrule( srcs = [ "ir/tfl_ops.td", "ir/tfl_op_interfaces.td", - "@llvm-project//mlir:include/mlir/Transforms/LoopLikeInterface.td", + "@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td", "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files", ], outs = [ @@ -516,6 +519,7 @@ cc_library( "@com_google_absl//absl/strings", "@flatbuffers", "@llvm-project//llvm:support", + "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:IR", "@llvm-project//mlir:TransformUtils", ], @@ -580,7 +584,7 @@ cc_library( "//tensorflow/lite/delegates/flex:whitelisted_flex_ops_lib", "//tensorflow/lite/kernels/internal:kernel_utils", "//tensorflow/lite/schema:schema_fbs", - "//tensorflow/lite/tools/versioning:op_version", + "//tensorflow/lite/tools/versioning", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -697,11 +701,12 @@ cc_library( "//tensorflow/compiler/mlir/lite/quantization:quantization_config", "//tensorflow/compiler/mlir/lite/quantization:quantization_passes", "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:decode_constant_pass", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", - "//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib", - "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", "//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass", "//tensorflow/compiler/mlir/tensorflow:translate_lib", + "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", @@ -725,6 +730,7 @@ cc_library( ":tensorflow_lite_quantize", "//tensorflow/compiler/mlir/lite/quantization:quantization_config", "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:decode_constant_pass", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib", "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", @@ -734,11 +740,10 @@ cc_library( "//tensorflow/lite/tools/optimize:quantize_weights", "//tensorflow/stream_executor/lib", "@llvm-project//llvm:support", - "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", 
"@llvm-project//mlir:Pass", - "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", ], diff --git a/tensorflow/compiler/mlir/lite/README.md b/tensorflow/compiler/mlir/lite/README.md index 224727621d6..b9c58b28a79 100644 --- a/tensorflow/compiler/mlir/lite/README.md +++ b/tensorflow/compiler/mlir/lite/README.md @@ -1,9 +1,9 @@ -# Experimental code for the new TF-Lite convertor, and MLIR dialects and utilities for TensorFlow Lite. +# The new [MLIR](https://github.com/llvm/llvm-project/tree/master/mlir) based +TensorFlow to TensorFlow Lite converter This directory contains: -1. Experimental code for the new TF-Lite convertor. -2. Code for the TF-lite dialect [MLIR](https://github.com/tensorflow/mlir). +1. MLIR dialects, transformation passes and utilities for TensorFlow Lite. ## API: @@ -11,7 +11,8 @@ The API for converting TensorFlow models to TensorFlow Lite will be through `tf.lite.TFLiteConverter`. All the conversion code is open sourced, and the API will be integrated soon. -### The conversion process from TensorFlow to TensorFlow Lite includes the following major passes: +### The conversion process from TensorFlow to TensorFlow Lite includes the +following major passes: - Import from GraphDef, in .pb or .pbtxt format, into MLIR. - Raise to Control-flow-graph. Converts TF Control Flow dialect to TF dialect. @@ -28,3 +29,6 @@ TensorFlow Lite models). - The Export pass writes out TensorFlow Lite FlatBuffer format. This pass operates on MLIR TensorFlow Lite dialect and is simple/direct translation. +See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +for the full list of MLIR passes for conversion from TensorFlow to +TensorFlow Lite. diff --git a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h index b14041e8067..322aeadfa37 100644 --- a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h +++ b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h @@ -34,9 +34,9 @@ struct PassConfig { quant_specs(std::move(specs)), skip_control_dialect(false), form_clusters(false), - inline_functions(true), unfold_batch_matmul(true), - legalize_tf_while(true) {} + legalize_tf_while(true), + shape_inference(false) {} // If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be // added, which produces TF Lite ops. @@ -56,9 +56,6 @@ struct PassConfig { // are formed by grouping consecutive ops of the same device, under a // `tf_device.launch` op. bool form_clusters; - // Inline function calls within the main function in the MLIR module, prior - // to legalization to TFLite. - bool inline_functions; // if `unfold_batch_matmul` is true, the tf.BatchMatMul is unfolded to a set // of tfl.fully_connected ops. bool unfold_batch_matmul; @@ -66,6 +63,8 @@ struct PassConfig { // Note: This is staging step and will be removed. // TODO(b/137395003): Remove post switching legalization. bool legalize_tf_while; + // Whether to do shape inference. + bool shape_inference; }; } // namespace TFL diff --git a/tensorflow/compiler/mlir/lite/converter_gen.cc b/tensorflow/compiler/mlir/lite/converter_gen.cc index 02d9ef45591..8ecff8757b7 100644 --- a/tensorflow/compiler/mlir/lite/converter_gen.cc +++ b/tensorflow/compiler/mlir/lite/converter_gen.cc @@ -119,6 +119,12 @@ static void EmitOptionBuilders(const RecordKeeper &record_keeper, // conversion generation and so the simplicity was chosen over the // flexibility. 
StringRef arg_name = arg_values->getArgNameStr(i); + // Skip any "intermiadiateXXX" attribute as they are specially handled + // in the exporter. They are special because though they are attributes + // in the MLIR they are expressed as tensors in the flatbuffer instead + // of option. + if (op_name == "LSTMOp" && arg_name.take_back(12) == "intermediate") + continue; os << formatv( " auto {0} = Convert{1}ForOptionWriter(op.{0}(), fbb);\n", arg_name, mlir::tblgen::Attribute(arg_def).getAttrDefName()); @@ -164,17 +170,24 @@ static void EmitOperatorBuilders(const std::vector &defs, for (const auto *def : defs) { StringRef op_name = def->getName().drop_front(4); + const bool has_intermediates = op_name == "LSTMOp"; // Signature os << "static flatbuffers::Offset " << GetOperatorBuilderName(def->getName()) << "(mlir::TFL::" << op_name << " tflOp, uint32_t opcode_index, " << "const std::vector& operands," << "const std::vector& results," + << (has_intermediates ? "const std::vector& intermediate_index," + : "") << "flatbuffers::FlatBufferBuilder *fbb) {\n"; // Inputs & outputs os << " auto inputs = fbb->CreateVector(operands);\n" " auto outputs = fbb->CreateVector(results);\n\n"; + // Intermediates for LSTM. + if (has_intermediates) { + os << " auto intermediates = fbb->CreateVector(intermediate_index);\n"; + } // Build the FlatBuffer operator os << " return tflite::CreateOperator(\n" @@ -191,9 +204,9 @@ static void EmitOperatorBuilders(const std::vector &defs, // Only builtin ops' builders are auto-generated. custom_options are only // used by custom or flex ops and those ops are handled manually. os << " /*custom_options=*/0, " - "tflite::CustomOptionsFormat_FLEXBUFFERS,\n" - " /*mutating_variable_inputs=*/0);\n" - "}\n\n"; + << "tflite::CustomOptionsFormat_FLEXBUFFERS,\n" + << " /*mutating_variable_inputs=*/0" + << (has_intermediates ? ", intermediates" : "") << ");\n}\n\n"; } } @@ -244,6 +257,7 @@ static void EmitGetBuiltinOpCode(const std::vector &defs, // uint32_t opcode_index, // const std::vector& operands, // const std::vector& results, +// const std::vector& intermediates, // flatbuffers::FlatBufferBuilder *fbb); static void EmitBuildOperator(const std::vector &defs, raw_ostream *ostream) { @@ -255,6 +269,7 @@ static void EmitBuildOperator(const std::vector &defs, "uint32_t opcode_index, " "const std::vector& operands," "const std::vector& results," + "const std::vector& intermediates," "flatbuffers::FlatBufferBuilder *fbb) {\n"; for (const auto *def : defs) { @@ -264,7 +279,8 @@ static void EmitBuildOperator(const std::vector &defs, os << " if (auto tflOp = llvm::dyn_cast(op))\n" << " return " << GetOperatorBuilderName(def->getName()) - << "(tflOp, opcode_index, operands, results, fbb);\n"; + << "(tflOp, opcode_index, operands, results, " + << (op_name == "LSTMOp" ? "intermediates, " : "") << "fbb);\n"; } os << " return llvm::None;\n" @@ -307,6 +323,10 @@ static void EmitBuiltinOptionsToAttributes(const RecordKeeper &record_keeper, if (!arg_def) continue; if (arg_def->getDef()->isSubClassOf(attr_type)) { StringRef arg_name = arg_values->getArgNameStr(i); + // Already handle this case in flatbuffer_import.cc. 
+ if (option_name == "LSTMOptions" && + arg_name.take_back(12) == "intermediate") + continue; StringRef attr_type = mlir::tblgen::Attribute(arg_def).getAttrDefName(); os << formatv( " attributes.emplace_back(builder.getNamedAttr(\"{0}\"," diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index 6753ab9e728..29233f86e4a 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -547,6 +547,7 @@ bool IsCustomOp(const std::string& op_name) { // TODO(krzysd) Handle function calls StatusOr ConvertOp( const tflite::OperatorT& op, const std::vector& vals_map, + const std::vector& intermediate_types, Value optional_arg_marker, const std::vector& op_names, const std::vector& func_names, const std::vector>& tensors, Location loc, @@ -608,6 +609,28 @@ StatusOr ConvertOp( if (op_name == "tfl.lstm") { // TODO(b/147587779): add the right region if region is empty. op_state.addRegion(); + if (!op.intermediates.empty()) { + if (op.intermediates.size() != 5) { + auto err = errors::InvalidArgument( + "operator has intermediate tensors but the number of them is not " + "five."); + return emitError(loc, err.ToString()), err; + } + // Create intermediate value + + const llvm::SmallVector kIntermediateNames = { + "input_to_input_intermediate", "input_to_forget_intermediate", + "input_to_cell_intermediate", "input_to_output_intermediate", + "effective_hidden_scale_intermediate"}; + for (auto type_and_name : + llvm::zip(intermediate_types, kIntermediateNames)) { + mlir::TypeAttr type_attr = + mlir::TypeAttr::get(std::get<0>(type_and_name)); + auto named_attr = + builder.getNamedAttr(std::get<1>(type_and_name), type_attr); + op_state.addAttribute(named_attr.first, named_attr.second); + } + } } llvm::SmallVector attrs; @@ -893,6 +916,18 @@ StatusOr ConvertSubgraph( } } + // Intermediate tensors for tfl.lstm are used to carry quantization range + // in their types, so we only need and extract their types. + std::vector intermediate_types; + intermediate_types.reserve(5); + for (auto intermediate : op->intermediates) { + TF_ASSIGN_OR_RETURN( + auto type, GetTensorType(*subgraph.tensors[intermediate], builder, + /*shapeless_are_scalars=*/true, + /*is_constant=*/true)); + intermediate_types.emplace_back(type); + } + // The NameLoc corresponding to the name of the first output tensor auto op_loc = op->outputs.empty() @@ -902,8 +937,8 @@ StatusOr ConvertSubgraph( // to a valid Value TF_ASSIGN_OR_RETURN( auto* mlir_op, - ConvertOp(*op, vals_map, maybe_optional_arg_marker, op_names, - func_names, subgraph.tensors, op_loc, op_builder)); + ConvertOp(*op, vals_map, intermediate_types, maybe_optional_arg_marker, + op_names, func_names, subgraph.tensors, op_loc, op_builder)); // Add the results to the value maps. There are two cases: 1. 
the result // tensor does not have min/max values, the original op result is used diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_operator.h b/tensorflow/compiler/mlir/lite/flatbuffer_operator.h index fdc0fd81f8f..4e8e3f6424e 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_operator.h +++ b/tensorflow/compiler/mlir/lite/flatbuffer_operator.h @@ -44,6 +44,7 @@ llvm::Optional GetBuiltinOpCode(Operation *mlir_op); llvm::Optional> CreateFlatBufferOperator( Operation *mlir_op, uint32_t opcode_index, const std::vector &operands, const std::vector &results, + const std::vector &intermediates, flatbuffers::FlatBufferBuilder *fbb); // Populates the array of mlir::NamedAttributes corresponding to the given diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc index ac20ab68eaa..9e9330e2c96 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc @@ -43,6 +43,7 @@ limitations under the License. #include "llvm/Support/ToolOutputFile.h" #include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/Function.h" // TF:llvm-project #include "mlir/IR/Location.h" // TF:llvm-project @@ -75,6 +76,7 @@ limitations under the License. #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/string_util.h" #include "tensorflow/lite/tools/versioning/op_version.h" +#include "tensorflow/lite/tools/versioning/runtime_version.h" #include "tensorflow/lite/version.h" using llvm::dyn_cast; @@ -179,8 +181,6 @@ static StatusOr GetTFLiteType(Type type, return tflite::TensorType_FLOAT16; case mlir::TF::TensorFlowTypes::STRING: return tflite::TensorType_STRING; - case mlir::TF::TensorFlowTypes::UINT8: - return tflite::TensorType_UINT8; case mlir::TF::TensorFlowTypes::QUINT8: return tflite::TensorType_UINT8; case mlir::StandardTypes::Complex: { @@ -196,7 +196,8 @@ static StatusOr GetTFLiteType(Type type, case 1: return tflite::TensorType_BOOL; case 8: - return tflite::TensorType_INT8; + return itype.isUnsigned() ? tflite::TensorType_UINT8 + : tflite::TensorType_INT8; case 16: return tflite::TensorType_INT16; case 32: @@ -404,6 +405,11 @@ class Translator { // and returns llvm::None on failure. Optional> BuildBuffer(Operation* inst); + // Build TFLite tensor from the given type. This function is for tfl.lstm + // intermediates, which should have UniformQuantizedType. + Optional> BuildTensorFromType( + mlir::Type type, const std::string& name); + // Builds TFLite tensor from the given value. `buffer_idx` is index of the // corresponding buffer. Emits error and returns llvm::None on failure. Optional> BuildTensor(Value value, @@ -469,7 +475,8 @@ class Translator { // tensor indices. Emits an error and returns llvm::None on failure. Optional> BuildOperator( Operation* inst, const std::vector& operands, - const std::vector& results); + const std::vector& results, + const std::vector& intermediates); // Build a subgraph with a given name out of the region either corresponding // to a function's body or while op. 
@@ -581,6 +588,34 @@ Optional> Translator::BuildBuffer( return tflite::CreateBuffer(builder_, buffer_data); } +Optional> Translator::BuildTensorFromType( + mlir::Type type, const std::string& name) { + auto tensor_type = type.cast(); + + if (!tensor_type.hasStaticShape()) { + return llvm::None; + } + llvm::ArrayRef shape_ref = tensor_type.getShape(); + std::vector shape(shape_ref.begin(), shape_ref.end()); + + auto element_type = tensor_type.getElementType(); + tflite::TensorType tflite_element_type = + GetTFLiteType(tensor_type.getElementType()).ValueOrDie(); + BufferOffset q_params; + auto qtype = element_type.dyn_cast(); + if (!qtype) { + return llvm::None; + } + q_params = tflite::CreateQuantizationParameters( + builder_, /*min=*/0, /*max=*/0, + builder_.CreateVector({static_cast(qtype.getScale())}), + builder_.CreateVector({qtype.getZeroPoint()})); + return tflite::CreateTensor( + builder_, builder_.CreateVector(shape), tflite_element_type, + /*buffer=*/0, builder_.CreateString(name), q_params, + /*is_variable=*/false); +} + Optional> Translator::BuildTensor( Value value, const std::string& name, unsigned buffer_idx) { auto type = value.getType().cast(); @@ -933,7 +968,8 @@ uint32_t Translator::GetOpcodeIndex(const std::string& op_name, Optional> Translator::BuildOperator( Operation* inst, const std::vector& operands, - const std::vector& results) { + const std::vector& results, + const std::vector& intermediates) { const auto* dialect = inst->getDialect(); if (!dialect) { inst->emitOpError("dialect is not registered"); @@ -986,7 +1022,7 @@ Optional> Translator::BuildOperator( std::string op_name = inst->getName().getStringRef().str(); uint32_t opcode_index = GetOpcodeIndex(op_name, *builtin_code); auto offset = CreateFlatBufferOperator(inst, opcode_index, operands, - results, &builder_); + results, intermediates, &builder_); if (!offset) { inst->emitOpError("is not a supported TFLite op"); } @@ -1171,6 +1207,29 @@ Optional> Translator::BuildSubGraph( bool failed_once = false; for (auto& inst : bb) { if (inst.isKnownTerminator()) break; + std::vector intermediates; + // Build intermediate tensors for tfl.lstm and insert these tensors into + // flatbuffer. 
+ if (llvm::isa(inst)) { + std::vector intermediate_names = { + "input_to_input_intermediate", "input_to_forget_intermediate", + "input_to_cell_intermediate", "input_to_output_intermediate", + "effective_hidden_scale_intermediate"}; + for (const std::string& intermediate : intermediate_names) { + auto intermediate_attr = inst.getAttr(intermediate); + if (auto attr = intermediate_attr.dyn_cast_or_null()) { + Type qtype = attr.getValue(); + auto tensor_or = BuildTensorFromType( + qtype, name_mapper_.GetUniqueName(intermediate).str()); + if (!tensor_or.hasValue()) { + continue; + } else { + intermediates.push_back(tensors.size()); + tensors.push_back(tensor_or.getValue()); + } + } + } + } for (auto val : inst.getResults()) { std::string name = UniqueName(val); @@ -1195,7 +1254,8 @@ Optional> Translator::BuildSubGraph( results.push_back(tensor_index_map.lookup(result)); } - if (auto tfl_operator = BuildOperator(&inst, operands, results)) + if (auto tfl_operator = + BuildOperator(&inst, operands, results, intermediates)) operators.push_back(*tfl_operator); else failed_once = true; @@ -1230,27 +1290,58 @@ BufferOffset Translator::BuildMetadata(StringRef name, Optional>> Translator::CreateMetadataVector() { auto dict_attr = module_.getAttrOfType("tfl.metadata"); - if (!dict_attr) return VectorBufferOffset>(); - std::vector> metadata; - for (const auto& named_attr : dict_attr) { - StringRef name = named_attr.first; - mlir::Attribute attr = named_attr.second; - if (auto content = attr.dyn_cast()) { - metadata.push_back(BuildMetadata(name, content.getValue())); - } else { - module_.emitError( - "all values in tfl.metadata's dictionary key-value pairs should be " - "string attributes"); - return llvm::None; + if (dict_attr) { + for (const auto& named_attr : dict_attr) { + StringRef name = named_attr.first; + mlir::Attribute attr = named_attr.second; + if (auto content = attr.dyn_cast()) { + metadata.push_back(BuildMetadata(name, content.getValue())); + } else { + module_.emitError( + "all values in tfl.metadata's dictionary key-value pairs should be " + "string attributes"); + return llvm::None; + } } } + // Runtime version string is generated after we update the op + // versions. Here we put a 16-byte dummy string as a placeholder. We choose + // 16-byte because it's the alignment of buffers in flatbuffer, so it won't + // cause any waste of space if the actual string is shorter than 16 bytes. + metadata.push_back( + BuildMetadata("min_runtime_version", std::string(16, '\0'))); return builder_.CreateVector(metadata); } +bool UpdateEntryFunction(ModuleOp module) { + if (module.lookupSymbol("main") != nullptr) { + // We already have an entry function. + return true; + } + + int entry_func_count = 0; + FuncOp entry_func = nullptr; + for (auto fn : module.getOps()) { + auto attrs = fn.getAttrOfType("tf.entry_function"); + if (attrs && !attrs.empty()) { + entry_func_count++; + entry_func = fn; + } + } + + // We should have one & only have one entry function. + if (entry_func_count != 1) return false; + + // Update the entry func to main. 
+ entry_func.setName("main"); + return true; +} + Optional Translator::Translate( ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, OpOrArgNameMapper* op_or_arg_name_mapper) { + if (!UpdateEntryFunction(module)) return llvm::None; if (!IsValidTFLiteMlirModule(module)) return llvm::None; Translator translator(module, emit_builtin_tflite_ops, emit_select_tf_ops, emit_custom_ops, op_or_arg_name_mapper); @@ -1334,6 +1425,7 @@ Optional Translator::TranslateInternal() { builder_.CreateVector(buffers_), metadata_buffer, *metadata); tflite::FinishModelBuffer(builder_, model); tflite::UpdateOpVersion(builder_.GetBufferPointer()); + tflite::UpdateMinimumRuntimeVersionForModel(builder_.GetBufferPointer()); // Return serialized string for the built FlatBuffer. return std::string(reinterpret_cast(builder_.GetBufferPointer()), diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index 83e372e5732..5f8e9c35b94 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -36,6 +36,7 @@ limitations under the License. #include "mlir/IR/TypeUtilities.h" // TF:llvm-project #include "mlir/Support/LLVM.h" // TF:llvm-project #include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/Transforms/FoldUtils.h" // TF:llvm-project #include "mlir/Transforms/InliningUtils.h" // TF:llvm-project #include "mlir/Transforms/RegionUtils.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -66,13 +67,29 @@ struct TensorFlowLiteInlinerInterface : public DialectInlinerInterface { } }; +struct TensorFlowLiteOpFolderDialectInterface + : public OpFolderDialectInterface { + using OpFolderDialectInterface::OpFolderDialectInterface; + + // Registered hook to check if the given region, which is attached to an + // operation that is *not* isolated from above (i.e. no internal regions + // reference values defined in an enclosing region), should be used when + // materializing constants. + // In the TFLite dialect we materialize inside while regions as it is + // slightly more efficient computationally.
+ bool shouldMaterializeInto(Region *region) const final { + return isa(region->getParentOp()); + } +}; + TensorFlowLiteDialect::TensorFlowLiteDialect(mlir::MLIRContext *context) : Dialect(/*name=*/"tfl", context) { addOperations< #define GET_OP_LIST #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc" >(); - addInterfaces(); + addInterfaces(); } //===----------------------------------------------------------------------===// @@ -1269,6 +1286,20 @@ static LogicalResult Verify(UnidirectionalSequenceLSTMOp op) { "UnidirectionalSequenceLSTMOp expected to have two stateful operands"); } +//===----------------------------------------------------------------------===// +// BidirectionalSequenceLSTMOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(BidirectionalSequenceLSTMOp op) { + auto operands = op.GetStatefulOperands(); + if (operands.size() == 4 && operands[0] == 35 && operands[1] == 36 && + operands[2] == 37 && operands[3] == 38) { + return success(); + } + return op.emitError( + "BidirectionalSequenceLSTMOp expected to have four stateful operands"); +} + //===----------------------------------------------------------------------===// // UnidirectionalSequenceRNNOp //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h index 1e74a8c1a9e..cfe18a218bc 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h @@ -25,9 +25,11 @@ limitations under the License. #include "mlir/IR/Dialect.h" // TF:llvm-project #include "mlir/IR/OpImplementation.h" // TF:llvm-project #include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // TF:llvm-project +#include "mlir/Interfaces/LoopLikeInterface.h" // TF:llvm-project +#include "mlir/Interfaces/SideEffects.h" // TF:llvm-project #include "mlir/Support/Functional.h" // TF:llvm-project #include "mlir/Support/LLVM.h" // TF:llvm-project -#include "mlir/Transforms/LoopLikeInterface.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -39,6 +41,8 @@ class TensorFlowLiteDialect : public Dialect { public: explicit TensorFlowLiteDialect(MLIRContext *context); + static StringRef getDialectNamespace() { return "tfl"; } + // Registered hook to materialize a constant operation from a given attribute // value with the desired resultant type. Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type, diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 6c5981359b3..5624c7e2b73 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -19,7 +19,8 @@ limitations under the License. 
#define TFL_OPS include "mlir/IR/OpBase.td" -include "mlir/Transforms/LoopLikeInterface.td" +include "mlir/Interfaces/LoopLikeInterface.td" +include "mlir/Interfaces/SideEffects.td" include "tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td" include "tensorflow/compiler/mlir/lite/quantization/quantization.td" @@ -47,13 +48,6 @@ def TFL_Str : Type()">, "TFLite string type">, BuildableType<"getType()">; -//===----------------------------------------------------------------------===// -// TFLite dialect uint8 type - uses the TF uint8 type as implementation -//===----------------------------------------------------------------------===// -def TFL_Uint8 : Type()">, - "TFLite uint8 type">, - BuildableType<"getType()">; - //===----------------------------------------------------------------------===// // TFLite dialect quint8 type - uses the TF quint8 type as implementation //===----------------------------------------------------------------------===// @@ -141,7 +135,8 @@ class TFL_VariadicTensorOf allowedRuntimeTypes, Variadic>, TFL_RuntimeType>>; -def TFL_Int32Or64 : IntOfWidths<[32, 64]>; +def TFL_Uint8 : UI<8>; +def TFL_Int32Or64 : SignlessIntOfWidths<[32, 64]>; def TFL_BoolTensor : TFL_TensorOf<[I1]>; def TFL_FpOrI32OrI64Tensor : TFL_TensorOf<[AnyFloat, TFL_Int32Or64]>; @@ -223,9 +218,9 @@ class TFL_Operand0DOr1ElementTensor : class TFL_TFTypesWithSameBits : And<[ Or<[CPred<"getElementTypeOrSelf($_op.getResult(" # i # ")).isa()">, - CPred<"getElementTypeOrSelf($_op.getResult(" # i # ")).isa()">]>, + CPred<"getElementTypeOrSelf($_op.getResult(" # i # ")).isUnsignedInteger(" # num # ")">]>, Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa()">, - CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa()">]>]>; + CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isUnsignedInteger(" # num # ")">]>]>; class TFL_OperandHasRankLessThan : PredOpTrait<"operand " # n # " is maximum " # m # "-D", @@ -602,7 +597,7 @@ def TFL_ConcatenationOp : TFL_Op<"concatenation", let verifier = [{ return Verify(*this); }]; } -def TFL_ConstOp : Op { let summary = "Constant pseudo op."; @@ -1863,11 +1858,11 @@ def TFL_MulOp : TFL_Op<"mul", [ResultsBroadcastableShape, NoSideEffect, Commutat }]; let arguments = ( - ins AnyTensor:$lhs, - AnyTensor:$rhs, + ins TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$lhs, + TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$rhs, TFL_AFAttr:$fused_activation_function); - let results = (outs AnyTensor:$output); + let results = (outs TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$output); let hasFolder = 1; @@ -1887,9 +1882,9 @@ def TFL_NegOp: TFL_Op<"neg", [NoSideEffect, SameOperandsAndResultType]> { Computes element-wise negation of input }]; - let arguments = (ins AnyTensor:$x); + let arguments = (ins TFL_TensorOf<[F32, I32, I64]>:$x); - let results = (outs AnyTensor:$y); + let results = (outs TFL_TensorOf<[F32, I32, I64]>:$y); let hasOptions = 0b1; @@ -2039,10 +2034,10 @@ def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape, NoSideEffect, NoQuanti }]; let arguments = ( - ins AnyTensor:$lhs, - AnyTensor:$rhs); + ins TFL_TensorOf<[F32, I32]>:$lhs, + TFL_TensorOf<[F32, I32]>:$rhs); - let results = (outs AnyTensor:$output); + let results = (outs TFL_TensorOf<[F32, I32]>:$output); let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }]; @@ -2716,7 +2711,7 @@ def TFL_SplitOp : TFL_Op<"split", [ let arguments = (ins TFL_TensorOf<[I32]>:$split_dim, TFL_TensorOf<[F32, I16, I32, I64, QI8, QUI8, QI16]>:$value, - PositiveI32Attr:$num_splits + 
Confined:$num_splits ); let results = (outs @@ -2741,7 +2736,7 @@ def TFL_SplitVOp : TFL_Op<"split_v", [NoSideEffect, SameOperandsAndResultsScale] TFL_TensorOf<[F32, I16, I32, I64, QI8, QUI8, QI16]>:$value, TFL_1DTensorOf<[I32], [I32]>:$size_splits, TFL_0DTensorOf<[I32], [I32]>:$split_dim, - PositiveI32Attr:$num_splits + Confined:$num_splits ); let results = (outs @@ -3246,7 +3241,15 @@ Ba et al. 'Layer Normalization' // Since this op is the FULL kernel only, constrain it. Confined< DefaultValuedAttr, - [TFL_LSTM_KT_FULL]>:$kernel_type + [TFL_LSTM_KT_FULL]>:$kernel_type, + + // Types of the optional intermediate tensors, which exist for fully + // quantized LSTM op and hold the ranges of the intermediate tensors. + OptionalAttr:$input_to_input_intermediate, + OptionalAttr:$input_to_forget_intermediate, + OptionalAttr:$input_to_cell_intermediate, + OptionalAttr:$input_to_output_intermediate, + OptionalAttr:$effective_hidden_scale_intermediate ); let results = (outs AnyTensor:$output); @@ -3350,6 +3353,156 @@ def TFL_UnidirectionalSequenceLSTMOp : }]; } +def BidiLstmMandatoryInputsConstraint : PredOpTrait< + "mandatory operands element types should match", + // TODO(ashwinm): Replace the indices with input tensor names when that + // support is available. + Or<[ + TCopVTEtAreSameAt<[0, 2, 3, 4, 6, 7, 8, 13, 14, 15, 19, 20, 21, 23, 24, 25, + 30, 31, 32, 35, 36, 37, 38]>, + Neg>]>>; + +def BidiLstmOptionalPeepholeWeightConstraint : PredOpTrait< + "the optional peephole weights should all be specified or none", + TCopVTEtAreSameAt<[9, 10, 11, 26, 27, 28]>>; + +def BidiLstmProjectionWeightBiasConstraint : PredOpTrait< + "either projection weight must be specified or both projection weight and " + "projection bias must not be specified", + Or<[ + And<[TypeIsPred<"fw_projection_weights", NoneType>, + TypeIsPred<"fw_projection_bias", NoneType>, + TypeIsPred<"bw_projection_weights", NoneType>, + TypeIsPred<"bw_projection_bias", NoneType>]>, + And<[ + Neg>, + Neg>, + ]> + ]>>; + +// BidirectionalSequenceLstm op. +// TODO(ashwinm): Add constraint to validate the combination of operands +// that are valid for hybrid vs fully quantized vs float only semantics +def TFL_BidirectionalSequenceLSTMOp : + TFL_Op<"bidirectional_sequence_lstm", + [BidiLstmMandatoryInputsConstraint, + BidiLstmOptionalPeepholeWeightConstraint, + BidiLstmProjectionWeightBiasConstraint, + LstmResultConstraint, + TFL_StatefulOp]> { + let summary = "Bidirectional sequence lstm operator"; + + let description = [{ + Bidirectional lstm is essentially two lstms, one running forward and the + other running backward. The output is the concatenation of the two + lstms.
+ }]; + + let arguments = ( + ins TFL_TensorOf<[F32, I8]>:$input, + + // Forward LSTM Weights + TFL_TensorOfOrNone<[F32, I8]>:$fw_input_to_input_weights, + TFL_TensorOf<[F32, I8]>:$fw_input_to_forget_weights, + TFL_TensorOf<[F32, I8]>:$fw_input_to_cell_weights, + TFL_TensorOf<[F32, I8]>:$fw_input_to_output_weights, + + // Forward Recurrent weights + TFL_TensorOfOrNone<[F32, I8]>:$fw_recurrent_to_input_weights, + TFL_TensorOf<[F32, I8]>:$fw_recurrent_to_forget_weights, + TFL_TensorOf<[F32, I8]>:$fw_recurrent_to_cell_weights, + TFL_TensorOf<[F32, I8]>:$fw_recurrent_to_output_weights, + + // Forward Cell weights + TFL_TensorOfOrNone<[F32, I8]>:$fw_cell_to_input_weights, + // Optional Forward cell weights + TFL_TensorOfOrNone<[F32, I8]>:$fw_cell_to_forget_weights, + // Optional Forward cell weights + TFL_TensorOfOrNone<[F32, I8]>:$fw_cell_to_output_weights, + + // Forward Bias + TFL_TensorOfOrNone<[F32]>:$fw_input_gate_bias, + TFL_TensorOf<[F32]>:$fw_forget_gate_bias, + TFL_TensorOf<[F32]>:$fw_cell_bias, + TFL_TensorOf<[F32]>:$fw_output_gate_bias, + + // Forward Projection weight and bias + TFL_TensorOfOrNone<[F32, I8]>:$fw_projection_weights, + // Forward Optional input + TFL_TensorOfOrNone<[F32]>:$fw_projection_bias, + + // Backward LSTM Weights + TFL_TensorOfOrNone<[F32, I8]>:$bw_input_to_input_weights, + TFL_TensorOf<[F32, I8]>:$bw_input_to_forget_weights, + TFL_TensorOf<[F32, I8]>:$bw_input_to_cell_weights, + TFL_TensorOf<[F32, I8]>:$bw_input_to_output_weights, + + // Backward Recurrent weights + TFL_TensorOfOrNone<[F32, I8]>:$bw_recurrent_to_input_weights, + TFL_TensorOf<[F32, I8]>:$bw_recurrent_to_forget_weights, + TFL_TensorOf<[F32, I8]>:$bw_recurrent_to_cell_weights, + TFL_TensorOf<[F32, I8]>:$bw_recurrent_to_output_weights, + + // Backward Cell weights + TFL_TensorOfOrNone<[F32, I8]>:$bw_cell_to_input_weights, + // Optional Forward cell weights + TFL_TensorOfOrNone<[F32, I8]>:$bw_cell_to_forget_weights, + // Optional Forward cell weights + TFL_TensorOfOrNone<[F32, I8]>:$bw_cell_to_output_weights, + + // Backward Bias + TFL_TensorOfOrNone<[F32]>:$bw_input_gate_bias, + TFL_TensorOf<[F32]>:$bw_forget_gate_bias, + TFL_TensorOf<[F32]>:$bw_cell_bias, + TFL_TensorOf<[F32]>:$bw_output_gate_bias, + + // Backward Projection weight and bias + TFL_TensorOfOrNone<[F32, I8]>:$bw_projection_weights, + // Backward Optional input + TFL_TensorOfOrNone<[F32]>:$bw_projection_bias, + + // Stateful activation and cell states. + TFL_StatefulTensor:$fw_input_activation_state, + TFL_StatefulTensor:$fw_input_cell_state, + TFL_StatefulTensor:$bw_input_activation_state, + TFL_StatefulTensor:$bw_input_cell_state, + + // Auxiliary input & weights. + TFL_TensorOfOrNone<[F32, I8]>:$aux_input, + // Auxiliary fw weights. + TFL_TensorOfOrNone<[F32, I8]>:$fw_aux_input_to_input_weights, + TFL_TensorOfOrNone<[F32, I8]>:$fw_aux_input_to_forget_weights, + TFL_TensorOfOrNone<[F32, I8]>:$fw_aux_input_to_cell_weights, + TFL_TensorOfOrNone<[F32, I8]>:$fw_aux_input_to_output_weights, + // Auxiliary bw weights. 
+ TFL_TensorOfOrNone<[F32, I8]>:$bw_aux_input_to_input_weights, + TFL_TensorOfOrNone<[F32, I8]>:$bw_aux_input_to_forget_weights, + TFL_TensorOfOrNone<[F32, I8]>:$bw_aux_input_to_cell_weights, + TFL_TensorOfOrNone<[F32, I8]>:$bw_aux_input_to_output_weights, + + // Attributes + TFL_AFAttr:$fused_activation_function, + DefaultValuedAttr:$cell_clip, + DefaultValuedAttr:$proj_clip, + BoolAttr:$merge_outputs, + BoolAttr:$time_major + ); + + let results = (outs + AnyTensor:$fw_output, + AnyTensor:$bw_output + ); + + let hasOptions = 1; + + let verifier = [{ return Verify(*this); }]; + + let extraClassDeclaration = [{ + // StatefulOpInterface: + std::vector GetStatefulOperands() { return {35, 36, 37, 38}; } + }]; +} + def RnnResultConstraint : PredOpTrait< "the input and result tensor elemental types must be same", TCresVTEtIsSameAsOp<0, 0>>; diff --git a/tensorflow/compiler/mlir/lite/python/BUILD b/tensorflow/compiler/mlir/lite/python/BUILD index 3080d74ee9c..638884634d5 100644 --- a/tensorflow/compiler/mlir/lite/python/BUILD +++ b/tensorflow/compiler/mlir/lite/python/BUILD @@ -10,11 +10,9 @@ package_group( ) cc_library( - name = "graphdef_to_tfl_flatbuffer", - srcs = ["graphdef_to_tfl_flatbuffer.cc"], - hdrs = [ - "graphdef_to_tfl_flatbuffer.h", - ], + name = "tf_tfl_flatbuffer_helpers", + srcs = ["tf_tfl_flatbuffer_helpers.cc"], + hdrs = ["tf_tfl_flatbuffer_helpers.h"], deps = [ "//tensorflow/compiler/mlir/lite:common", "//tensorflow/compiler/mlir/lite:tensorflow_lite", @@ -36,3 +34,61 @@ cc_library( "@llvm-project//mlir:Transforms", ], ) + +cc_library( + name = "graphdef_to_tfl_flatbuffer", + srcs = ["graphdef_to_tfl_flatbuffer.cc"], + hdrs = [ + "graphdef_to_tfl_flatbuffer.h", + ], + deps = [ + ":tf_tfl_flatbuffer_helpers", + "//tensorflow/compiler/mlir/lite:common", + "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "//tensorflow/compiler/mlir/lite:tf_tfl_passes", + "//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", + "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/lite/toco:model_flags_proto_cc", + "//tensorflow/lite/toco:toco_flags_proto_cc", + "//tensorflow/lite/toco:types_proto_cc", + "//tensorflow/stream_executor/lib", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "saved_model_to_tfl_flatbuffer", + srcs = ["saved_model_to_tfl_flatbuffer.cc"], + hdrs = [ + "saved_model_to_tfl_flatbuffer.h", + ], + deps = [ + ":tf_tfl_flatbuffer_helpers", + "//tensorflow/compiler/mlir/lite:common", + "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "//tensorflow/compiler/mlir/lite:tf_tfl_passes", + "//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", + "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/lite/toco:model_flags_proto_cc", + "//tensorflow/lite/toco:toco_flags_proto_cc", + "//tensorflow/lite/toco:types_proto_cc", + "//tensorflow/stream_executor/lib", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], +) diff --git 
a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc index a3b71fbe8d8..660f73e59e9 100644 --- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/Support/FileUtilities.h" // TF:llvm-project #include "mlir/Transforms/ViewOpGraph.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" +#include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h" #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" @@ -40,288 +41,7 @@ limitations under the License. #include "tensorflow/lite/toco/types.pb.h" #include "tensorflow/stream_executor/lib/statusor.h" -using stream_executor::port::StatusOr; - namespace tensorflow { - -namespace { -// Op def string for TFLite_Detection_PostProcess Op. -const char kDetectionPostProcessOp[] = - "name: 'TFLite_Detection_PostProcess' input_arg: { name: " - "'raw_outputs/box_encodings' type: DT_FLOAT } input_arg: { name: " - "'raw_outputs/class_predictions' type: DT_FLOAT } input_arg: { name: " - "'anchors' type: DT_FLOAT } output_arg: { name: " - "'TFLite_Detection_PostProcess' type: DT_FLOAT } output_arg: { name: " - "'TFLite_Detection_PostProcess:1' type: DT_FLOAT } output_arg: { name: " - "'TFLite_Detection_PostProcess:2' type: DT_FLOAT } output_arg: { name: " - "'TFLite_Detection_PostProcess:3' type: DT_FLOAT } attr : { name: " - "'h_scale' type: 'float'} attr : { name: 'max_classes_per_detection' " - "type: 'int'} attr : { name: 'max_detections' type: 'int'} attr : { " - "name: 'nms_iou_threshold' type: 'float'} attr : { name: " - "'nms_score_threshold' type: 'float'} attr : { name: 'num_classes' type: " - "'int'} attr : { name: 'w_scale' type: 'float'} attr : { name: 'x_scale' " - "type: 'float'} attr : { name: 'y_scale' type: 'float'} attr { name: " - "'detections_per_class' type: 'int' default_value { i : 100 }} attr { " - "name: 'use_regular_nms' type: 'bool' default_value { b : false }}"; - -const char kUnidirectionalSequenceLstmOp[] = - "name: 'UnidirectionalSequenceLstm' input_arg: {name: 'Input' type: " - "DT_FLOAT} input_arg: { name: 'InputToInputWeights' type: DT_FLOAT } " - "input_arg: { name: 'InputToForgetWeights' type: DT_FLOAT } input_arg: { " - "name: 'InputToCellWeights' type: DT_FLOAT} input_arg: { name: " - "'InputToOutputWeights' type: DT_FLOAT } input_arg: { name: " - "'RecurrentToInputWeights' type: DT_FLOAT} input_arg: { name: " - "'RecurrentToForgetWeights' type: DT_FLOAT} input_arg: { name: " - "'RecurrentToCellWeights' type: DT_FLOAT } input_arg: { name: " - "'RecurrentToOutputWeights' type: DT_FLOAT } input_arg: { name: " - "'CellToInputWeights' type: DT_FLOAT} input_arg: { name: " - "'CellToForgetWeights' type: DT_FLOAT } input_arg: { name: " - "'CellToOutputWeights' type: DT_FLOAT } input_arg: { name: 'InputGateBias' " - "type: DT_FLOAT } input_arg: { name: 'ForgetGateBias' type: DT_FLOAT } " - "input_arg: { name: 'kCellGateBias' type: DT_FLOAT } input_arg: { name: " - "'OutputGateBias' type: DT_FLOAT } input_arg: { name: 'ProjectionWeights' " - "type: DT_FLOAT } input_arg: { name: 'ProjectionBias' type: DT_FLOAT } " - "input_arg: { name: 'InputActivationState' type: DT_FLOAT} input_arg: { " - "name: 'InputCellStateTensor' type: 
DT_FLOAT } " - "output_arg: { name: 'Concat' type: DT_FLOAT} " - "output_arg: { name: " - "'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: DT_FLOAT} " - "attr : { name: '_tflite_input_indices' type: 'list(int)'}"; - -const char kUnidirectionalSequenceRnnOp[] = - "name: 'UnidirectionalSequenceRnn' input_arg: {name: 'Input' type: " - "DT_FLOAT} input_arg: { name: 'Weights' type: DT_FLOAT } " - "input_arg: { name: 'RecurrentWeights' type: DT_FLOAT } input_arg: { " - "name: 'Bias' type: DT_FLOAT} " - "input_arg: { name: 'HiddenState' type: DT_FLOAT} " - "output_arg: { name: " - "'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: " - "DT_FLOAT} " - "attr : { name: '_tflite_input_indices' type: 'list(int)'}"; - -// Converts the toco::IODataType to tensorflow::DataType. Only contains the -// conversion mapping for constants defined in TFLite Python API. -DataType ConvertIODataTypeToDataType(toco::IODataType dtype) { - switch (dtype) { - case toco::IODataType::FLOAT: - return DT_FLOAT; - case toco::IODataType::QUANTIZED_UINT8: - return DT_QUINT8; - case toco::IODataType::INT8: - return DT_QINT8; - case toco::IODataType::INT32: - return DT_INT32; - case toco::IODataType::INT64: - return DT_INT64; - case toco::IODataType::STRING: - return DT_STRING; - case toco::IODataType::BOOL: - return DT_BOOL; - default: - return DT_INVALID; - } -} - -StatusOr> InputStatsToMinMax(double mean, double std, - DataType type) { - // Only qint8 and quint8 are considered here. - double qmin, qmax; - if (type == DT_QUINT8) { - qmin = 0.0; - qmax = 255.0; - } else if (type == DT_QINT8) { - qmin = -128.0; - qmax = 127.0; - } else { - return errors::InvalidArgument("Only int8 and uint8 are considered."); - } - return std::make_pair((qmin - mean) / std, (qmax - mean) / std); -} - -// Give a warning for any unused flags that have been specified. -void WarningUnusedFlags(const toco::ModelFlags& model_flags, - const toco::TocoFlags& toco_flags) { - if (toco_flags.output_format()) { - LOG(WARNING) << "Ignored output_format."; - } - if (toco_flags.drop_control_dependency()) { - LOG(WARNING) << "Ignored drop_control_dependency."; - } - if (toco_flags.reorder_across_fake_quant()) { - LOG(WARNING) << "Ignored reorder_across_fake_quant."; - } - if (model_flags.change_concat_input_ranges()) { - LOG(WARNING) << "Ignored change_concat_input_ranges."; - } - if (toco_flags.dump_graphviz_include_video()) { - LOG(WARNING) << "Ignored dump_graphviz_video."; - } - if (model_flags.allow_nonexistent_arrays()) { - LOG(WARNING) << "Allow allow_nonexistent_arrays."; - } -} - -// Dumps the op graph of the `module` to `filename` in DOT format. 
-Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) { - std::string error_message; - auto output = mlir::openOutputFile(filename, &error_message); - if (!error_message.empty()) { - return errors::InvalidArgument("Failed to open file in %s.", filename); - } - mlir::PassManager pm(module.getContext()); - pm.addPass(mlir::createPrintOpGraphPass(output->os())); - if (failed(pm.run(module))) { - return errors::Unknown("Failed to dump Op Graph from MLIR module."); - } - output->keep(); - return Status::OK(); -} - -Status RegisterCustomBuiltinOps(const std::vector extra_tf_opdefs) { - for (const auto& tf_opdefs_string : extra_tf_opdefs) { - tensorflow::OpDef opdef; - if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string, - &opdef)) { - return errors::InvalidArgument("fail to parse extra OpDef"); - } - // Make sure the op is not already registered. If registered continue. - const OpRegistrationData* op_reg = - tensorflow::OpRegistry::Global()->LookUp(opdef.name()); - if (op_reg) continue; - - tensorflow::OpRegistry::Global()->Register( - [opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status { - *op_reg_data = tensorflow::OpRegistrationData(opdef); - return Status::OK(); - }); - } - return Status::OK(); -} - -Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags) { - // Register any custom OpDefs. - std::vector extra_tf_opdefs(toco_flags.custom_opdefs().begin(), - toco_flags.custom_opdefs().end()); - extra_tf_opdefs.push_back(kDetectionPostProcessOp); - extra_tf_opdefs.push_back(kUnidirectionalSequenceLstmOp); - extra_tf_opdefs.push_back(kUnidirectionalSequenceRnnOp); - return RegisterCustomBuiltinOps(extra_tf_opdefs); -} - -Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags, - const toco::TocoFlags& toco_flags, - mlir::TFL::QuantizationSpecs* quant_specs, - std::vector* node_names, - std::vector* node_dtypes, - std::vector>* node_shapes, - std::vector* node_mins, - std::vector* node_maxs) { - quant_specs->inference_input_type = - ConvertIODataTypeToDataType(toco_flags.inference_input_type()); - tensorflow::DataType inference_type = - ConvertIODataTypeToDataType(toco_flags.inference_type()); - // Use non-float flag `inference_input_type` to override the `inference_type` - // because we have to apply quantization to satisfy that. - if (quant_specs->inference_input_type != tensorflow::DT_FLOAT) { - inference_type = quant_specs->inference_input_type; - } - - for (auto& flag : model_flags.input_arrays()) { - node_names->push_back(flag.name()); - // TOCO doesn't required `data_type` to be filled for every input. - // If it's not filled, make it an empty string so the importer will use - // the data type in the NodeDef. 
- auto toco_data_type = flag.data_type(); - if (toco_data_type == ::toco::IODataType::IO_DATA_TYPE_UNKNOWN) { - node_dtypes->push_back(""); - } else { - node_dtypes->push_back( - DataType_Name(ConvertIODataTypeToDataType(toco_data_type))); - } - node_shapes->push_back(std::vector(flag.shape().dims().begin(), - flag.shape().dims().end())); - // Currently, only UINT8 and INT8 require inputs stats - if (inference_type == DT_QINT8 || inference_type == DT_QUINT8) { - TF_ASSIGN_OR_RETURN( - auto min_max, InputStatsToMinMax(flag.mean_value(), flag.std_value(), - inference_type)); - node_mins->push_back(min_max.first); - node_maxs->push_back(min_max.second); - } - } - - if (mlir::TFL::GetInputNodeQuantSpecs(*node_names, *node_mins, *node_maxs, - inference_type, quant_specs)) { - return errors::InvalidArgument("Failed to get input quant spec."); - } - - // Some extra flag related to post training quantization. If post-training - // quantization is enabled, `inference_type` and `inference_input_type` are - // not used by MLIR passes. - if (toco_flags.post_training_quantize()) { - quant_specs->weight_quantization = true; - if (toco_flags.quantize_to_float16()) { - quant_specs->inference_type = tensorflow::DT_HALF; - quant_specs->inference_input_type = tensorflow::DT_HALF; - } else { - quant_specs->inference_type = tensorflow::DT_QINT8; - quant_specs->inference_input_type = tensorflow::DT_QINT8; - } - } - - // Other flags. - if (toco_flags.has_default_ranges_min()) { - quant_specs->default_ranges.first = toco_flags.default_ranges_min(); - } - if (toco_flags.has_default_ranges_max()) { - quant_specs->default_ranges.second = toco_flags.default_ranges_max(); - } - - return ::tensorflow::Status::OK(); -} - -Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags, - mlir::OwningModuleRef module, - mlir::TFL::QuantizationSpecs quant_specs, - string* result) { - bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops(); - bool emit_select_tf_ops = toco_flags.enable_select_tf_ops(); - bool emit_custom_ops = toco_flags.allow_custom_ops(); - - if (toco_flags.has_dump_graphviz_dir()) { - TF_RETURN_IF_ERROR(DumpOpGraphToFile( - module.get(), - // rename once we enable the new converter feature flag. - absl::StrCat(toco_flags.dump_graphviz_dir(), "/toco_AT_IMPORT.dot"))); - } - - mlir::PassManager pm(module->getContext()); - mlir::TFL::PassConfig pass_config(quant_specs); - pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops; - pass_config.lower_tensor_list_ops = true; - - tensorflow::AddTFToTFLConversionPasses(pass_config, &pm); - // Convert back to outlined while format for export back to flatbuffer. - if (pass_config.legalize_tf_while) { - pm.addPass(mlir::TFL::CreateWhileOutlinePass()); - } - pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass()); - - auto status = ConvertTFExecutorToTFLOrFlatbuffer( - module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops, - emit_select_tf_ops, emit_custom_ops, quant_specs, result, &pm); - if (toco_flags.has_dump_graphviz_dir()) { - TF_RETURN_IF_ERROR(DumpOpGraphToFile( - // rename once we enable the new converter feature flag. 
- module.get(), absl::StrCat(toco_flags.dump_graphviz_dir(), - "/toco_AFTER_TRANSFORMATIONS.dot"))); - } - - return status; -} - -} // namespace - Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags, const GraphDebugInfo& debug_info, @@ -339,7 +59,7 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, std::vector node_maxs; // Populate quantization specs. - TF_RETURN_IF_ERROR(PopulateQuantizationSpecs( + TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs( model_flags, toco_flags, &quant_specs, &node_names, &node_dtypes, &node_shapes, &node_mins, &node_maxs)); @@ -356,16 +76,16 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, specs.convert_legacy_fed_inputs = true; specs.graph_as_function = false; specs.upgrade_legacy = true; - WarningUnusedFlags(model_flags, toco_flags); + internal::WarningUnusedFlags(model_flags, toco_flags); // Register all custom ops, including user-specified custom ops. - TF_RETURN_IF_ERROR(RegisterAllCustomOps(toco_flags)); + TF_RETURN_IF_ERROR(internal::RegisterAllCustomOps(toco_flags)); TF_ASSIGN_OR_RETURN( auto module, ConvertGraphdefToMlir(input, debug_info, specs, &context)); - return ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module), - quant_specs, result); + return internal::ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module), + quant_specs, result); } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc new file mode 100644 index 00000000000..a546dba3ff3 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -0,0 +1,78 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h" + +#include + +#include "llvm/Support/ToolOutputFile.h" +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Support/FileUtilities.h" // TF:llvm-project +#include "mlir/Transforms/ViewOpGraph.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" +#include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h" +#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" +#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" +#include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/graph_debug_info.pb.h" +#include "tensorflow/lite/toco/model_flags.pb.h" +#include "tensorflow/lite/toco/toco_flags.pb.h" +#include "tensorflow/lite/toco/types.pb.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace tensorflow { + +Status ConvertSavedModelToTFLiteFlatBuffer( + const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags, + const string& saved_model_dir, bool saved_model_v1, + const string& saved_model_tags, const string& saved_model_exported_names, + string* result) { + mlir::MLIRContext context; + mlir::TFL::QuantizationSpecs quant_specs; + + // Parse input arrays. + std::vector node_names; + std::vector node_dtypes; + std::vector> node_shapes; + std::vector node_mins; + std::vector node_maxs; + + // Populate quantization specs. + TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs( + model_flags, toco_flags, &quant_specs, &node_names, &node_dtypes, + &node_shapes, &node_mins, &node_maxs)); + + internal::WarningUnusedFlags(model_flags, toco_flags); + + // Register all custom ops, including user-specified custom ops. + TF_RETURN_IF_ERROR(internal::RegisterAllCustomOps(toco_flags)); + + const bool import_saved_model = !saved_model_v1; + TF_ASSIGN_OR_RETURN( + auto module, + ImportSavedModel(import_saved_model, saved_model_v1, saved_model_dir, + saved_model_tags, saved_model_exported_names, &context)); + return internal::ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module), + quant_specs, result); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h new file mode 100644 index 00000000000..dea5603dad0 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h @@ -0,0 +1,37 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_SAVED_MODEL_TO_TFL_FLATBUFFER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_SAVED_MODEL_TO_TFL_FLATBUFFER_H_ + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/protobuf/graph_debug_info.pb.h" +#include "tensorflow/lite/toco/model_flags.pb.h" +#include "tensorflow/lite/toco/toco_flags.pb.h" + +namespace tensorflow { + +// Converts the given saved_model(either v1 or v2) to a TF Lite FlatBuffer +// string according to the given model flags, toco flags and tags. Returns error +// status if it fails to convert the input. +Status ConvertSavedModelToTFLiteFlatBuffer( + const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags, + const string& saved_model_dir, bool saved_model_v1, + const string& saved_model_tags, const string& saved_model_exported_names, + string* result); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_SAVED_MODEL_TO_TFL_FLATBUFFER_H_ diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc new file mode 100644 index 00000000000..e0eb8004a01 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc @@ -0,0 +1,325 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h" + +#include +#include + +#include "llvm/Support/ToolOutputFile.h" +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Support/FileUtilities.h" // TF:llvm-project +#include "mlir/Transforms/ViewOpGraph.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" +#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" +#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" +#include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/graph_debug_info.pb.h" +#include "tensorflow/lite/toco/model_flags.pb.h" +#include "tensorflow/lite/toco/toco_flags.pb.h" +#include "tensorflow/lite/toco/types.pb.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +using stream_executor::port::StatusOr; + +namespace tensorflow { +namespace internal { +namespace { + +// Op def string for TFLite_Detection_PostProcess Op. +const char kDetectionPostProcessOp[] = + "name: 'TFLite_Detection_PostProcess' input_arg: { name: " + "'raw_outputs/box_encodings' type: DT_FLOAT } input_arg: { name: " + "'raw_outputs/class_predictions' type: DT_FLOAT } input_arg: { name: " + "'anchors' type: DT_FLOAT } output_arg: { name: " + "'TFLite_Detection_PostProcess' type: DT_FLOAT } output_arg: { name: " + "'TFLite_Detection_PostProcess:1' type: DT_FLOAT } output_arg: { name: " + "'TFLite_Detection_PostProcess:2' type: DT_FLOAT } output_arg: { name: " + "'TFLite_Detection_PostProcess:3' type: DT_FLOAT } attr : { name: " + "'h_scale' type: 'float'} attr : { name: 'max_classes_per_detection' " + "type: 'int'} attr : { name: 'max_detections' type: 'int'} attr : { " + "name: 'nms_iou_threshold' type: 'float'} attr : { name: " + "'nms_score_threshold' type: 'float'} attr : { name: 'num_classes' type: " + "'int'} attr : { name: 'w_scale' type: 'float'} attr : { name: 'x_scale' " + "type: 'float'} attr : { name: 'y_scale' type: 'float'} attr { name: " + "'detections_per_class' type: 'int' default_value { i : 100 }} attr { " + "name: 'use_regular_nms' type: 'bool' default_value { b : false }}"; + +const char kUnidirectionalSequenceLstmOp[] = + "name: 'UnidirectionalSequenceLstm' input_arg: {name: 'Input' type: " + "DT_FLOAT} input_arg: { name: 'InputToInputWeights' type: DT_FLOAT } " + "input_arg: { name: 'InputToForgetWeights' type: DT_FLOAT } input_arg: { " + "name: 'InputToCellWeights' type: DT_FLOAT} input_arg: { name: " + "'InputToOutputWeights' type: DT_FLOAT } input_arg: { name: " + "'RecurrentToInputWeights' type: DT_FLOAT} input_arg: { name: " + "'RecurrentToForgetWeights' type: DT_FLOAT} input_arg: { name: " + "'RecurrentToCellWeights' type: DT_FLOAT } input_arg: { name: " + "'RecurrentToOutputWeights' type: DT_FLOAT } input_arg: { name: " + "'CellToInputWeights' type: DT_FLOAT} input_arg: { name: " + "'CellToForgetWeights' type: DT_FLOAT } input_arg: { name: " + "'CellToOutputWeights' type: DT_FLOAT } input_arg: { name: 'InputGateBias' " + "type: 
DT_FLOAT } input_arg: { name: 'ForgetGateBias' type: DT_FLOAT } " + "input_arg: { name: 'kCellGateBias' type: DT_FLOAT } input_arg: { name: " + "'OutputGateBias' type: DT_FLOAT } input_arg: { name: 'ProjectionWeights' " + "type: DT_FLOAT } input_arg: { name: 'ProjectionBias' type: DT_FLOAT } " + "input_arg: { name: 'InputActivationState' type: DT_FLOAT} input_arg: { " + "name: 'InputCellStateTensor' type: DT_FLOAT } " + "output_arg: { name: 'Concat' type: DT_FLOAT} " + "output_arg: { name: " + "'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: DT_FLOAT} " + "attr : { name: '_tflite_input_indices' type: 'list(int)'}"; + +const char kUnidirectionalSequenceRnnOp[] = + "name: 'UnidirectionalSequenceRnn' input_arg: {name: 'Input' type: " + "DT_FLOAT} input_arg: { name: 'Weights' type: DT_FLOAT } " + "input_arg: { name: 'RecurrentWeights' type: DT_FLOAT } input_arg: { " + "name: 'Bias' type: DT_FLOAT} " + "input_arg: { name: 'HiddenState' type: DT_FLOAT} " + "output_arg: { name: " + "'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: " + "DT_FLOAT} " + "attr : { name: '_tflite_input_indices' type: 'list(int)'}"; + +// Converts the toco::IODataType to tensorflow::DataType. Only contains the +// conversion mapping for constants defined in TFLite Python API. +DataType ConvertIODataTypeToDataType(toco::IODataType dtype) { + switch (dtype) { + case toco::IODataType::FLOAT: + return DT_FLOAT; + case toco::IODataType::QUANTIZED_UINT8: + return DT_QUINT8; + case toco::IODataType::INT8: + return DT_QINT8; + case toco::IODataType::INT32: + return DT_INT32; + case toco::IODataType::INT64: + return DT_INT64; + case toco::IODataType::STRING: + return DT_STRING; + case toco::IODataType::BOOL: + return DT_BOOL; + default: + return DT_INVALID; + } +} + +StatusOr> InputStatsToMinMax(double mean, double std, + DataType type) { + // Only qint8 and quint8 are considered here. + double qmin, qmax; + if (type == DT_QUINT8) { + qmin = 0.0; + qmax = 255.0; + } else if (type == DT_QINT8) { + qmin = -128.0; + qmax = 127.0; + } else { + return errors::InvalidArgument("Only int8 and uint8 are considered."); + } + return std::make_pair((qmin - mean) / std, (qmax - mean) / std); +} + +Status RegisterCustomBuiltinOps(const std::vector extra_tf_opdefs) { + for (const auto& tf_opdefs_string : extra_tf_opdefs) { + tensorflow::OpDef opdef; + if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string, + &opdef)) { + return errors::InvalidArgument("fail to parse extra OpDef"); + } + // Make sure the op is not already registered. If registered continue. + const OpRegistrationData* op_reg = + tensorflow::OpRegistry::Global()->LookUp(opdef.name()); + if (op_reg) continue; + + tensorflow::OpRegistry::Global()->Register( + [opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status { + *op_reg_data = tensorflow::OpRegistrationData(opdef); + return Status::OK(); + }); + } + return Status::OK(); +} + +} // namespace + +Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags) { + // Register any custom OpDefs. 
+ std::vector extra_tf_opdefs(toco_flags.custom_opdefs().begin(), + toco_flags.custom_opdefs().end()); + extra_tf_opdefs.push_back(kDetectionPostProcessOp); + extra_tf_opdefs.push_back(kUnidirectionalSequenceLstmOp); + extra_tf_opdefs.push_back(kUnidirectionalSequenceRnnOp); + return RegisterCustomBuiltinOps(extra_tf_opdefs); +} + +Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags, + const toco::TocoFlags& toco_flags, + mlir::TFL::QuantizationSpecs* quant_specs, + std::vector* node_names, + std::vector* node_dtypes, + std::vector>* node_shapes, + std::vector* node_mins, + std::vector* node_maxs) { + quant_specs->inference_input_type = + ConvertIODataTypeToDataType(toco_flags.inference_input_type()); + tensorflow::DataType inference_type = + ConvertIODataTypeToDataType(toco_flags.inference_type()); + // Use non-float flag `inference_input_type` to override the `inference_type` + // because we have to apply quantization to satisfy that. + if (quant_specs->inference_input_type != tensorflow::DT_FLOAT) { + inference_type = quant_specs->inference_input_type; + } + + for (auto& flag : model_flags.input_arrays()) { + node_names->push_back(flag.name()); + // TOCO doesn't required `data_type` to be filled for every input. + // If it's not filled, make it an empty string so the importer will use + // the data type in the NodeDef. + auto toco_data_type = flag.data_type(); + if (toco_data_type == ::toco::IODataType::IO_DATA_TYPE_UNKNOWN) { + node_dtypes->push_back(""); + } else { + node_dtypes->push_back( + DataType_Name(ConvertIODataTypeToDataType(toco_data_type))); + } + node_shapes->push_back(std::vector(flag.shape().dims().begin(), + flag.shape().dims().end())); + // Currently, only UINT8 and INT8 require inputs stats + if (inference_type == DT_QINT8 || inference_type == DT_QUINT8) { + TF_ASSIGN_OR_RETURN( + auto min_max, InputStatsToMinMax(flag.mean_value(), flag.std_value(), + inference_type)); + node_mins->push_back(min_max.first); + node_maxs->push_back(min_max.second); + } + } + + if (mlir::TFL::GetInputNodeQuantSpecs(*node_names, *node_mins, *node_maxs, + inference_type, quant_specs)) { + return errors::InvalidArgument("Failed to get input quant spec."); + } + + // Some extra flag related to post training quantization. If post-training + // quantization is enabled, `inference_type` and `inference_input_type` are + // not used by MLIR passes. + if (toco_flags.post_training_quantize()) { + quant_specs->weight_quantization = true; + if (toco_flags.quantize_to_float16()) { + quant_specs->inference_type = tensorflow::DT_HALF; + quant_specs->inference_input_type = tensorflow::DT_HALF; + } else { + quant_specs->inference_type = tensorflow::DT_QINT8; + quant_specs->inference_input_type = tensorflow::DT_QINT8; + } + } + + // Other flags. + if (toco_flags.has_default_ranges_min()) { + quant_specs->default_ranges.first = toco_flags.default_ranges_min(); + } + if (toco_flags.has_default_ranges_max()) { + quant_specs->default_ranges.second = toco_flags.default_ranges_max(); + } + + return ::tensorflow::Status::OK(); +} + +// Dumps the op graph of the `module` to `filename` in DOT format. 
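+// The emitted .dot file can be rendered with Graphviz, e.g.
+//   dot -Tpdf toco_AT_IMPORT.dot -o toco_AT_IMPORT.pdf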
+Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) { + std::string error_message; + auto output = mlir::openOutputFile(filename, &error_message); + if (!error_message.empty()) { + return errors::InvalidArgument("Failed to open file in %s.", filename); + } + mlir::PassManager pm(module.getContext()); + pm.addPass(mlir::createPrintOpGraphPass(output->os())); + if (failed(pm.run(module))) { + return errors::Unknown("Failed to dump Op Graph from MLIR module."); + } + output->keep(); + return Status::OK(); +} + +Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags, + mlir::OwningModuleRef module, + mlir::TFL::QuantizationSpecs quant_specs, + string* result) { + bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops(); + bool emit_select_tf_ops = toco_flags.enable_select_tf_ops(); + bool emit_custom_ops = toco_flags.allow_custom_ops(); + + if (toco_flags.has_dump_graphviz_dir()) { + TF_RETURN_IF_ERROR(DumpOpGraphToFile( + module.get(), + // rename once we enable the new converter feature flag. + absl::StrCat(toco_flags.dump_graphviz_dir(), "/toco_AT_IMPORT.dot"))); + } + + mlir::PassManager pm(module->getContext()); + mlir::TFL::PassConfig pass_config(quant_specs); + pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops; + pass_config.lower_tensor_list_ops = true; + + tensorflow::AddTFToTFLConversionPasses(pass_config, &pm); + // Convert back to outlined while format for export back to flatbuffer. + if (pass_config.legalize_tf_while) { + pm.addPass(mlir::TFL::CreateWhileOutlinePass()); + } + pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass()); + + auto status = ConvertTFExecutorToTFLOrFlatbuffer( + module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops, + emit_select_tf_ops, emit_custom_ops, quant_specs, result, &pm); + if (toco_flags.has_dump_graphviz_dir()) { + TF_RETURN_IF_ERROR(DumpOpGraphToFile( + // rename once we enable the new converter feature flag. + module.get(), absl::StrCat(toco_flags.dump_graphviz_dir(), + "/toco_AFTER_TRANSFORMATIONS.dot"))); + } + + return status; +} + +void WarningUnusedFlags(const toco::ModelFlags& model_flags, + const toco::TocoFlags& toco_flags) { + if (toco_flags.output_format()) { + LOG(WARNING) << "Ignored output_format."; + } + if (toco_flags.drop_control_dependency()) { + LOG(WARNING) << "Ignored drop_control_dependency."; + } + if (toco_flags.reorder_across_fake_quant()) { + LOG(WARNING) << "Ignored reorder_across_fake_quant."; + } + if (model_flags.change_concat_input_ranges()) { + LOG(WARNING) << "Ignored change_concat_input_ranges."; + } + if (toco_flags.dump_graphviz_include_video()) { + LOG(WARNING) << "Ignored dump_graphviz_video."; + } + if (model_flags.allow_nonexistent_arrays()) { + LOG(WARNING) << "Allow allow_nonexistent_arrays."; + } +} + +} // namespace internal +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h new file mode 100644 index 00000000000..41846d8e846 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h @@ -0,0 +1,59 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_TF_TFL_FLATBUFFER_HELPERS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_TF_TFL_FLATBUFFER_HELPERS_H_ + +#include +#include + +#include "mlir/IR/Module.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" +#include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/lite/toco/model_flags.pb.h" +#include "tensorflow/lite/toco/toco_flags.pb.h" +#include "tensorflow/lite/toco/types.pb.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace tensorflow { +namespace internal { + +// Register all custom ops, including user-specified custom ops. +Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags); + +// Populate quantization specs (or not) given user-specified ranges for each +// input array. +Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags, + const toco::TocoFlags& toco_flags, + mlir::TFL::QuantizationSpecs* quant_specs, + std::vector* node_names, + std::vector* node_dtypes, + std::vector>* node_shapes, + std::vector* node_mins, + std::vector* node_maxs); + +// Convert the imported MLIR module to a TFLite flatbuffer. +// This also runs the relevant passes. +Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags, + mlir::OwningModuleRef module, + mlir::TFL::QuantizationSpecs quant_specs, + string* result); + +// Give a warning for any unused flags that have been specified. +void WarningUnusedFlags(const toco::ModelFlags& model_flags, + const toco::TocoFlags& toco_flags); +} // namespace internal +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_TF_TFL_FLATBUFFER_HELPERS_H_ diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization.td b/tensorflow/compiler/mlir/lite/quantization/quantization.td index 416c3d1719d..966740e605f 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization.td +++ b/tensorflow/compiler/mlir/lite/quantization/quantization.td @@ -20,7 +20,7 @@ limitations under the License. #define TF_Quantization include "mlir/IR/OpBase.td" -include "mlir/Dialect/QuantOps/QuantPredicates.td" +include "mlir/Dialect/QuantOps/QuantOpsBase.td" //===----------------------------------------------------------------------===// // QuantizedType definitions. diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc index a98d50bd07e..a321170349a 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" +#include #include #include #include @@ -147,29 +148,45 @@ static bool BroadcastVector(int target_size, SmallVectorImpl& data) { // Changes the axis of the input per-channel quantized type to match the // dimension of the target type. Returns nullptr if it fails.
static quant::UniformQuantizedPerAxisType ResetAxisAndBroadcast( - quant::UniformQuantizedPerAxisType qtype, Type target, int quant_dim) { + ArrayRef shape, quant::UniformQuantizedPerAxisType qtype, + Type target, int quant_dim) { auto shaped = target.dyn_cast(); if (!shaped) return {}; + ArrayRef new_shape = shaped.getShape(); SmallVector scales(qtype.getScales().begin(), qtype.getScales().end()); SmallVector zero_points(qtype.getZeroPoints().begin(), qtype.getZeroPoints().end()); - // Broadcast the scales and zero points to match the target size, which is - // usually the axis-th dimension of the target type. Currently, it covers two - // cases: - // - for Transpose, the data layout is changed so the `dim[axis]` still equals - // to the `scales_size`. The broadcast skips; - // - for Reshape, the data layout isn't changed but the innermost dimension is - // expand to cover the last two original dimensions. Thus we just need to be - // repeated the `scales` dim[2] times to covers the new dim length. - // - // TODO(b/141709944): after the fix, the `scales` can be for dim[2], thus we - // have to repeat each elements in the `scales` locally dim[3] times. - if (BroadcastVector(shaped.getDimSize(quant_dim), scales) || - BroadcastVector(shaped.getDimSize(quant_dim), zero_points)) { + + if (new_shape.size() == shape.size()) { // same rank + // Broadcast the scales and zero points to match the target size, which is + // usually the axis-th dimension of the target type. Currently, it covers + // two cases: + // - for Transpose, the data layout is changed so the `dim[axis]` still + // equals to the `scales_size`. The broadcast skips; + // - for Reshape, the data layout isn't changed but the innermost dimension + // is expand to cover the last two original dimensions. Thus we just need to + // be repeated the `scales` dim[2] times to covers the new dim length. + // + // TODO(b/141709944): after the fix, the `scales` can be for dim[2], thus we + // have to repeat each elements in the `scales` locally dim[3] times. + if (BroadcastVector(shaped.getDimSize(quant_dim), scales) || + BroadcastVector(shaped.getDimSize(quant_dim), zero_points)) { + return {}; + } + } else if ((new_shape.size() == shape.size() + 1) && new_shape.back() == 1) { + // This is a trivial shift left, then we shift the quant_dim as well. 
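+    // For example, a type quantized along quant_dim = -1 on a tensor<2x3>
+    // value that is reshaped to tensor<2x3x1> only gains a trailing unit
+    // dimension, so the quantized axis becomes index shape.size() - 1 == 1
+    // in the new shape and the scales/zero points are reused as-is.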
+ if (std::equal(shape.begin(), shape.end(), new_shape.begin()) && + quant_dim == -1) { + quant_dim = shape.size() + quant_dim; + } else { + return {}; + } + } else { return {}; } + return quant::UniformQuantizedPerAxisType::get( qtype.getFlags(), qtype.getStorageType(), qtype.getExpressedType(), scales, zero_points, quant_dim, qtype.getStorageTypeMin(), @@ -179,20 +196,21 @@ static quant::UniformQuantizedPerAxisType ResetAxisAndBroadcast( TypeAttr CastQuantizedTypeAttrFromExpressedType(Builder builder, TypeAttr source, Type target, int axis) { - if (auto source_type = source.getValue().dyn_cast_or_null()) { - auto src_ele_type = source_type.getElementType(); - if (auto quantized_type = src_ele_type.dyn_cast()) { - if (auto qtype = - quantized_type.dyn_cast()) { - quantized_type = ResetAxisAndBroadcast(qtype, target, axis); - if (!src_ele_type) return {}; - } - Type final_type = quantized_type.castFromExpressedType(target); - if (!final_type) return {}; - return TypeAttr::get(final_type); - } + auto source_type = source.getValue().dyn_cast_or_null(); + if (!source_type) return {}; + auto src_ele_type = source_type.getElementType(); + auto qtype = src_ele_type.dyn_cast(); + + // Reset the quantization dimensions if it is per-axis. + if (auto per_axis = + qtype.dyn_cast_or_null()) { + qtype = + ResetAxisAndBroadcast(source_type.getShape(), per_axis, target, axis); } - return {}; + if (!qtype) return {}; + Type final_type = qtype.castFromExpressedType(target); + if (!final_type) return {}; + return TypeAttr::get(final_type); } Type GetUniformQuantizedTypeForWeight(ElementsAttr attr, bool symmetric, diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/BUILD b/tensorflow/compiler/mlir/lite/quantization/xla/BUILD index 7616922b613..2c5bed86a84 100644 --- a/tensorflow/compiler/mlir/lite/quantization/xla/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/xla/BUILD @@ -9,6 +9,7 @@ package_group( name = "friends", includes = ["//third_party/mlir:subpackages"], packages = [ + "//tensorflow/compiler/aot/...", "//tensorflow/compiler/mlir/...", "//tensorflow/compiler/mlir/lite/...", ], @@ -38,3 +39,29 @@ cc_library( ], alwayslink = 1, ) + +cc_library( + name = "quantize", + srcs = [ + "quantize.cc", + ], + hdrs = [ + "quantize.h", + ], + deps = [ + "//tensorflow/compiler/mlir/xla:hlo", + "//tensorflow/compiler/mlir/xla:hlo_to_mlir_hlo", + "//tensorflow/compiler/tf2xla", + "//tensorflow/compiler/tf2xla:mlir_tf2xla", + "//tensorflow/compiler/tf2xla:tf2xla_proto_cc", + "//tensorflow/compiler/tf2xla:tf2xla_util", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/core/platform:status", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Transforms", + ], +) diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/quantize.cc b/tensorflow/compiler/mlir/lite/quantization/xla/quantize.cc new file mode 100644 index 00000000000..4640284fa5c --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/xla/quantize.cc @@ -0,0 +1,62 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h" + +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassManager.h" // TF:llvm-project +#include "mlir/Transforms/Passes.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h" +#include "tensorflow/compiler/tf2xla/tf2xla.h" +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" + +namespace mlir { +namespace xla_hlo { + +// Quantizes the model in the computation. +tensorflow::Status XlaQuantize(const tensorflow::tf2xla::Config& config, + xla::XlaComputation* computation) { + TF_ASSIGN_OR_RETURN(std::unique_ptr snapshot, + computation->Snapshot()); + + MLIRContext context; + OwningModuleRef module = ModuleOp::create(UnknownLoc::get(&context)); + auto status = xla::ConvertHloToMlirHlo( + module.get(), snapshot->mutable_hlo()->mutable_hlo_module()); + if (!status.ok()) { + LOG(ERROR) << "Hlo module import failed: " << status; + return status; + } + + PassManager pm(&context); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createInlinerPass()); + pm.addPass(createSymbolDCEPass()); + pm.addNestedPass(createCSEPass()); + + mlir::StatusScopedDiagnosticHandler diag_handler(&context); + LogicalResult result = pm.run(module.get()); + (void)result; + + module->dump(); + + return tensorflow::Status::OK(); +} + +} // namespace xla_hlo +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/quantize.h b/tensorflow/compiler/mlir/lite/quantization/xla/quantize.h new file mode 100644 index 00000000000..2ec5dbb02ce --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/xla/quantize.h @@ -0,0 +1,33 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_ + +#include "tensorflow/compiler/tf2xla/tf2xla.pb.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/core/platform/status.h" + +namespace mlir { +namespace xla_hlo { + +// Quantizes the model in the computation. 
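+// `config` carries the tf2xla feed/fetch specification and `computation`
+// holds the XLA computation to be quantized. As of this change (see
+// quantize.cc) the HLO is imported into MLIR HLO and cleanup passes
+// (canonicalize, inline, symbol-DCE, CSE) are run; the quantization rewrite
+// itself is still in progress.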
+tensorflow::Status XlaQuantize(const tensorflow::tf2xla::Config& config, + xla::XlaComputation* computation); + +} // namespace xla_hlo +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_ diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/tests/BUILD b/tensorflow/compiler/mlir/lite/quantization/xla/tests/BUILD index 4faa8d2efe8..4b6b4212567 100644 --- a/tensorflow/compiler/mlir/lite/quantization/xla/tests/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/xla/tests/BUILD @@ -3,8 +3,14 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package(licenses = ["notice"]) glob_lit_tests( - data = [":test_utilities"], + data = [ + ":graph_config_files", + ":test_utilities", + ], driver = "@llvm-project//mlir:run_lit.sh", + tags_override = { + "fadd_quant.mlir": ["no_oss"], # TODO(b/150957738): to be fixed on oss. + }, test_file_exts = ["mlir"], ) @@ -13,7 +19,17 @@ filegroup( name = "test_utilities", testonly = True, data = [ + "//tensorflow/compiler/aot:tfcompile", "//tensorflow/compiler/mlir:tf-opt", "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:not", ], ) + +# Bundle together all the graph files that are used by the tests. +filegroup( + name = "graph_config_files", + srcs = glob( + ["**/*.pbtxt"], + ), +) diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/tests/fadd_quant.mlir b/tensorflow/compiler/mlir/lite/quantization/xla/tests/fadd_quant.mlir new file mode 100644 index 00000000000..6b9ccfceddd --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/xla/tests/fadd_quant.mlir @@ -0,0 +1,15 @@ +# RUN: not tfcompile --graph=%s.pbtxt --config=%s.config.pbtxt --quantize --cpp_class="::test::fadd_quant" 2>&1 | FileCheck %s -dump-input-on-failure + +# TODO(fengliuai): update this file with the progress of the implementation +// CHECK: func @main +// CHECK: %cst = constant dense<0.000000e+00> : tensor +// CHECK: %cst_0 = constant dense<1.270000e+02> : tensor +// CHECK: %cst_1 = constant dense<8> : tensor +// CHECK: %cst_2 = constant dense : tensor +// CHECK: %0 = "xla_hlo.custom_call"(%arg0, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.9"} : (tensor<2x4xf32>, tensor, tensor, tensor, tensor) -> tensor<2x4xf32> +// CHECK: %1 = "xla_hlo.custom_call"(%arg1, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.14"} : (tensor<2x4xf32>, tensor, tensor, tensor, tensor) -> tensor<2x4xf32> +// CHECK: %2 = xla_hlo.add %0, %1 {name = "add.15"} : tensor<2x4xf32> +// CHECK: %3 = "xla_hlo.custom_call"(%2, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.20"} : (tensor<2x4xf32>, tensor, tensor, tensor, tensor) -> tensor<2x4xf32> +// CHECK: %4 = "xla_hlo.tuple"(%3) {name = "tuple.22"} : (tensor<2x4xf32>) -> tuple> +// CHECK: return %4 : tuple> +// CHECK: } diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/tests/fadd_quant.mlir.config.pbtxt b/tensorflow/compiler/mlir/lite/quantization/xla/tests/fadd_quant.mlir.config.pbtxt new file mode 100644 index 00000000000..1e97c1fa326 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/xla/tests/fadd_quant.mlir.config.pbtxt @@ -0,0 +1,26 @@ +feed { + id { node_name: "input0" } + shape { + dim { size: 2 } + dim { size: 4 } + } +} +feed { + id { node_name: "input1" 
} + shape { + dim { size: 2 } + dim { size: 4 } + } +} + +fetch { + id { node_name: "Add/FakeQuantWithMinMaxVars" } + shape { + dim { size: 2 } + dim { size: 4 } + } +} + +conversion_options { + custom_fake_quant_op_calls: true +} diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/tests/fadd_quant.mlir.pbtxt b/tensorflow/compiler/mlir/lite/quantization/xla/tests/fadd_quant.mlir.pbtxt new file mode 100644 index 00000000000..6995c861fd0 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/xla/tests/fadd_quant.mlir.pbtxt @@ -0,0 +1,218 @@ +node: { + name: "Add/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "Add" + input: "Add/FakeQuantWithMinMaxVars/min" + input: "Add/FakeQuantWithMinMaxVars/max" + attr: { + key: "num_bits" + value: { + i: 8 + } + } + attr: { + key: "narrow_range" + value: { + b: false + } + } +} +node: { + name: "Add/FakeQuantWithMinMaxVars/min" + op: "Const" + attr: { + key: "value" + value: { + tensor: { + dtype: DT_FLOAT + tensor_shape: { + } + float_val: 0.0 + } + } + } + attr: { + key: "dtype" + value: { + type: DT_FLOAT + } + } +} +node: { + name: "Add/FakeQuantWithMinMaxVars/max" + op: "Const" + attr: { + key: "value" + value: { + tensor: { + dtype: DT_FLOAT + tensor_shape: { + } + float_val: 127.0 + } + } + } + attr: { + key: "dtype" + value: { + type: DT_FLOAT + } + } +} +node { + name: "Add" + op: "Add" + input: "input0/FakeQuantWithMinMaxVars" + input: "input1/FakeQuantWithMinMaxVars" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node: { + name: "input0/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "input0" + input: "input0/FakeQuantWithMinMaxVars/min" + input: "input0/FakeQuantWithMinMaxVars/max" + attr: { + key: "num_bits" + value: { + i: 8 + } + } + attr: { + key: "narrow_range" + value: { + b: false + } + } +} +node: { + name: "input0/FakeQuantWithMinMaxVars/min" + op: "Const" + attr: { + key: "value" + value: { + tensor: { + dtype: DT_FLOAT + tensor_shape: { + } + float_val: 0.0 + } + } + } + attr: { + key: "dtype" + value: { + type: DT_FLOAT + } + } +} +node: { + name: "input0/FakeQuantWithMinMaxVars/max" + op: "Const" + attr: { + key: "value" + value: { + tensor: { + dtype: DT_FLOAT + tensor_shape: { + } + float_val: 127.0 + } + } + } + attr: { + key: "dtype" + value: { + type: DT_FLOAT + } + } +} +node { + name: "input0" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } +} +node: { + name: "input1/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "input1" + input: "input1/FakeQuantWithMinMaxVars/min" + input: "input1/FakeQuantWithMinMaxVars/max" + attr: { + key: "num_bits" + value: { + i: 8 + } + } + attr: { + key: "narrow_range" + value: { + b: false + } + } +} +node: { + name: "input1/FakeQuantWithMinMaxVars/min" + op: "Const" + attr: { + key: "value" + value: { + tensor: { + dtype: DT_FLOAT + tensor_shape: { + } + float_val: 0.0 + } + } + } + attr: { + key: "dtype" + value: { + type: DT_FLOAT + } + } +} +node: { + name: "input1/FakeQuantWithMinMaxVars/max" + op: "Const" + attr: { + key: "value" + value: { + tensor: { + dtype: DT_FLOAT + tensor_shape: { + } + float_val: 127.0 + } + } + } + attr: { + key: "dtype" + value: { + type: DT_FLOAT + } + } +} +node { + name: "input1" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } +} +versions { + producer: 27 +} diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/add.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/add.pbtxt index 44ef85bfac2..902d1c98cab 
100644 --- a/tensorflow/compiler/mlir/lite/tests/end2end/add.pbtxt +++ b/tensorflow/compiler/mlir/lite/tests/end2end/add.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf_tfl_translate -tf-input-arrays=input0,input1 -tf-input-shapes=4:4 -tf-input-data-types=DT_INT32,DT_INT32 -tf-output-arrays=Add %s -o - | flatbuffer_to_string - | FileCheck %s +# RUN: tf_tfl_translate -tf-input-arrays=input0,input1 -tf-input-shapes=4:4 -tf-input-data-types=DT_INT32,DT_INT32 -tf-output-arrays=Add %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s # Add two tensor<4xi32> inputs and return the result @@ -90,5 +90,11 @@ versions { # CHECK-EMPTY: # CHECK-NEXT: }, { # CHECK-EMPTY: +# CHECK-NEXT: }, { +# CHECK-NEXT: data: [ 49, 46, 53, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +# CHECK-NEXT: } ], +# CHECK-NEXT: metadata: [ { +# CHECK-NEXT: name: "min_runtime_version", +# CHECK-NEXT: buffer: 4 # CHECK-NEXT: } ] # CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/constants.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/constants.mlir index 4225e360d58..a113c318d80 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/constants.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/constants.mlir @@ -61,11 +61,11 @@ func @i64() -> tensor<4xi64> { // the same sort of opaque round-trip we get for complex64, but it might be good // to check -func @uint8() -> tensor<4x!tf.uint8> { +func @uint8() -> tensor<4xui8> { // CHECK-LABEL: @uint8 - // CHECK: value = opaque<"tf", "0x746674656E736F722464747970653A2044545F55494E54382074656E736F725F7368617065207B2064696D207B2073697A653A2034207D207D2074656E736F725F636F6E74656E743A20225C3333365C3235355C3237365C33353722"> : tensor<4x!tf.uint8> - %0 = "tfl.pseudo_const"() { value = opaque<"tf", "0x746674656E736F722464747970653A2044545F55494E54382074656E736F725F7368617065207B2064696D207B2073697A653A2034207D207D2074656E736F725F636F6E74656E743A20225C3333365C3235355C3237365C33353722"> : tensor<4x!tf.uint8> } : () -> tensor<4x!tf.uint8> - return %0 : tensor<4x!tf.uint8> + // CHECK: value = dense<[222, 173, 190, 239]> : tensor<4xui8> + %0 = "tfl.pseudo_const"() {value = dense<[222, 173, 190, 239]> : tensor<4xui8>} : () -> tensor<4xui8> + return %0 : tensor<4xui8> } func @qi32_per_axis() -> tensor<3x3x!quant.uniform> { diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.mlir index 6003471f106..f58c0535f7c 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.mlir @@ -13,3 +13,16 @@ func @main(%arg0: tensor<4 x f32>, %arg1: tensor<4 x f32>, %arg2: tensor<4 x f32 // CHECK: return %[[RES0]] } + +// ----- + +func @testFullyQuantizedLSTM(%arg0: tensor<1x528x!quant.uniform>, %arg1: tensor<2048x528x!quant.uniform:f32, 0.059801999479532242>>, %arg2: tensor<2048x528x!quant.uniform:f32, 0.031925998628139496>>, %arg3: tensor<2048x528x!quant.uniform:f32, 0.056272000074386597>>, %arg4: tensor<2048x528x!quant.uniform:f32, 0.063763998448848724>>, %arg5: tensor<2048x640x!quant.uniform:f32, 0.013358999975025654>>, %arg6: tensor<2048x640x!quant.uniform:f32, 0.022830000147223473>>, %arg7: tensor<2048x640x!quant.uniform:f32, 0.032276000827550888>>, %arg8: tensor<2048x640x!quant.uniform:f32, 0.035427000373601913>>, %arg9: tensor<2048x!quant.uniform>, %arg10: tensor<2048x!quant.uniform>, %arg11: tensor<2048x!quant.uniform>, %arg12: tensor<2048x!quant.uniform>, %arg13: 
tensor<640x2048x!quant.uniform:f32, 0.021174000576138496>>, %arg14: tensor<640x!quant.uniform>, %arg15: tensor<2048x!quant.uniform>, %arg16: tensor<2048x!quant.uniform>, %arg17: tensor<2048x!quant.uniform>, %arg18: tensor<2048x!quant.uniform>, %arg19: tensor<1x640x!quant.uniform>, %arg20: tensor<1x2048x!quant.uniform>) -> tensor<1x640x!quant.uniform> { + %cst = constant unit + %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ({}) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", input_to_input_intermediate = tensor<0x!quant.uniform>, input_to_forget_intermediate = tensor<0x!quant.uniform>, input_to_cell_intermediate = tensor<0x!quant.uniform>, input_to_output_intermediate = tensor<0x!quant.uniform>, effective_hidden_scale_intermediate = tensor<0x!quant.uniform:f32, 0.0075630000792443752:2>>, kernel_type = "FULL", proj_clip = 0.01 : f32} : (tensor<1x528x!quant.uniform>, tensor<2048x528x!quant.uniform:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<640x2048x!quant.uniform:f32, 0.021174000576138496>>, tensor<640x!quant.uniform>, tensor<1x640x!quant.uniform>, tensor<1x2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>) -> tensor<1x640x!quant.uniform> + return %0 : tensor<1x640x!quant.uniform> +// CHECK-LABEL: testFullyQuantizedLSTM +// CHECK: %cst = constant unit +// CHECK: %[[RES0:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) +// CHECK: }) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform:f32, 0.0075630000792443752:2>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0x!quant.uniform>, input_to_forget_intermediate = tensor<0x!quant.uniform>, input_to_input_intermediate = tensor<0x!quant.uniform>, input_to_output_intermediate = tensor<0x!quant.uniform>, kernel_type = "FULL", proj_clip = 0.00999999977 : f32} : (tensor<1x528x!quant.uniform>, tensor<2048x528x!quant.uniform>, tensor<2048x528x!quant.uniform>, tensor<2048x528x!quant.uniform>, tensor<2048x528x!quant.uniform>, tensor<2048x640x!quant.uniform>, tensor<2048x640x!quant.uniform>, tensor<2048x640x!quant.uniform>, tensor<2048x640x!quant.uniform>, none, none, none, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<640x2048x!quant.uniform>, tensor<640x!quant.uniform>, tensor<1x640x!quant.uniform>, tensor<1x2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>) -> tensor<1x640x!quant.uniform> +} + diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf-while.mlir 
b/tensorflow/compiler/mlir/lite/tests/legalize-tf-while.mlir index 22d2c39535a..8f30aef8287 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf-while.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf-while.mlir @@ -58,15 +58,15 @@ func @while_cond_10_frozen0(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>, %arg2: t // CANON-SAME: (tensor, tensor<256x256xf32>, tensor) // CANON: [[VAL_1:%.*]] = constant dense<1.000000e+00> : tensor<256x256xf32> // CANON: [[VAL_2:%.*]] = constant dense<0> : tensor -// CANON: [[VAL_3:%.*]] = constant dense<10> : tensor -// CANON: [[VAL_4:%.*]] = constant dense<1> : tensor -// CANON: [[VAL_5:%.*]] = "tf.Const"() {value = dense<2.560000e+02> : tensor<256x256xf32>} : () -> tensor // CANON: [[VAL_6:%.*]]:3 = "tfl.while"([[VAL_2]], [[VAL_2]], [[VAL_0]]) ( { // CANON: ^bb0([[VAL_7:%.*]]: tensor<*xi32>, [[VAL_8:%.*]]: tensor<*xi32>, [[VAL_9:%.*]]: tensor<*xf32>): +// CANON: [[VAL_3:%.*]] = constant dense<10> : tensor // CANON: [[VAL_10:%.*]] = "tf.Less"([[VAL_8]], [[VAL_3]]) // CANON: "tfl.yield"([[VAL_10]]) : (tensor<*xi1>) -> () // CANON: }, { // CANON: ^bb0([[VAL_11:%.*]]: tensor<*xi32>, [[VAL_12:%.*]]: tensor<*xi32>, [[VAL_13:%.*]]: tensor<*xf32>): +// CANON: [[VAL_4:%.*]] = constant dense<1> : tensor +// CANON: [[VAL_5:%.*]] = "tf.Const"() {value = dense<2.560000e+02> : tensor<256x256xf32>} : () -> tensor // CANON: [[VAL_14:%.*]] = "tf.AddV2"([[VAL_12]], [[VAL_4]]) // CANON: [[VAL_15:%.*]] = "tf.AddV2"([[VAL_13]], [[VAL_5]]) // CANON: [[VAL_16:%.*]] = "tf.AddV2"([[VAL_11]], [[VAL_4]]) diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index 1256571c3b4..d236c8169b8 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -1,22 +1,11 @@ // RUN: tf-opt %s -tfl-legalize-tf | FileCheck %s --dump-input-on-failure -func @addRelu(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { +func @add(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { %0 = "tf.Add"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> - %1 = "tf.Add"(%arg0, %0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> - %2 = "tf.Relu"(%1) : (tensor<1xf32>) -> tensor<1xf32> - %3 = "tf.Relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> - %4 = "tf.Add"(%3, %2) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> - %5 = "tf.Relu6"(%4) : (tensor<1xf32>) -> tensor<1xf32> - %6 = "tfl.add"(%5, %3) {fused_activation_function = "NONE"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> - %7 = "tf.Relu6"(%6) : (tensor<1xf32>) -> tensor<1xf32> - return %7: tensor<1xf32> + return %0: tensor<1xf32> -// CHECK-LABEL: addRelu +// CHECK-LABEL: add // CHECK: tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xf32> -// CHECK: %1 = tfl.add %arg0, %0 {fused_activation_function = "RELU"} : tensor<1xf32> -// CHECK: %2 = "tfl.relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> -// CHECK: %3 = tfl.add %2, %1 {fused_activation_function = "RELU6"} : tensor<1xf32> -// CHECK: %4 = tfl.add %3, %2 {fused_activation_function = "RELU6"} : tensor<1xf32> // CHECK: return } @@ -30,13 +19,10 @@ func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> { func @biasAdd(%arg0: tensor<1x10x10x32xf32>, %arg1: tensor<32xf32>) -> tensor<1x10x10x32xf32> { %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x10x10x32xf32>, tensor<32xf32>) -> tensor<1x10x10x32xf32> - %1 = "tf.BiasAdd"(%0, %arg1) {T = "tfdtype$DT_FLOAT", 
data_format = "NHWC"} : (tensor<1x10x10x32xf32>, tensor<32xf32>) -> tensor<1x10x10x32xf32> - %2 = "tf.Relu6"(%1) : (tensor<1x10x10x32xf32>) -> tensor<1x10x10x32xf32> - return %2 : tensor<1x10x10x32xf32> + return %0 : tensor<1x10x10x32xf32> // CHECK-LABEL: biasAdd // CHECK: "tfl.add"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<1x10x10x32xf32>, tensor<32xf32>) -> tensor<1x10x10x32xf32> -// CHECK: %1 = "tfl.add"(%0, %arg1) {fused_activation_function = "RELU6"} : (tensor<1x10x10x32xf32>, tensor<32xf32>) -> tensor<1x10x10x32xf32> } func @biasAddInt(%arg0: tensor<1x10x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x10x10x32xi32> { @@ -55,9 +41,9 @@ func @squeezeAndReshape(%arg0: tensor<1x1x10xf32>, %arg1: tensor) -> i %4 = "some_op"(%1, %3) : (tensor<*xf32>, tensor<2x5xf32>) -> i32 return %4 : i32 // CHECK-LABEL: squeezeAndReshape -// CHECK: %cst = constant dense<[2, 5]> : tensor<2xi32> // CHECK: "tfl.squeeze"(%arg0) {squeeze_dims = [0]} : (tensor<1x1x10xf32>) -> tensor<1x10xf32> // CHECK: %1 = "tfl.squeeze"(%arg1) {squeeze_dims = []} : (tensor) -> tensor<*xf32> +// CHECK: %cst = constant dense<[2, 5]> : tensor<2xi32> // CHECK: %2 = "tfl.reshape"(%0, %cst) : (tensor<1x10xf32>, tensor<2xi32>) -> tensor<2x5xf32> // CHECK: %3 = "some_op"(%1, %2) : (tensor<*xf32>, tensor<2x5xf32>) -> i32 // CHECK: return @@ -88,7 +74,7 @@ func @dynamicReshapeI64Fold(%arg0: tensor<*xf32>) -> tensor<1x2xf32> { return %0 : tensor<1x2xf32> // CHECK-LABEL: dynamicReshapeI64Fold -// CHECK-NEXT: %[[cst:.*]] = constant dense<[1, 2]> : tensor<2xi32> +// CHECK: %[[cst:.*]] = constant dense<[1, 2]> : tensor<2xi32> // CHECK-NEXT: %[[reshape:.*]] = "tfl.reshape"(%arg0, %[[cst]]) : (tensor<*xf32>, tensor<2xi32>) -> tensor<1x2xf32> // CHECK-NEXT: return %[[reshape]] : tensor<1x2xf32> } @@ -128,10 +114,10 @@ func @softplus(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { return %0 : tensor<8x16xf32> // CHECK-LABEL: softplus -// CHECK-NEXT: %[[cst:.*]] = constant dense<1.000000e+00> : tensor -// CHECK-NEXT: %[[exp:.*]] = "tfl.exp"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32> -// CHECK-NEXT: %[[add:.*]] = "tfl.add"(%[[exp]], %[[cst]]) {fused_activation_function = "NONE"} : (tensor<8x16xf32>, tensor) -> tensor<8x16xf32> -// CHECK-NEXT: %[[log:.*]] = "tfl.log"(%[[add]]) : (tensor<8x16xf32>) -> tensor<8x16xf32> +// CHECK: %[[exp:.*]] = "tfl.exp"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32> +// CHECK: %[[cst:.*]] = constant dense<1.000000e+00> : tensor +// CHECK: %[[add:.*]] = "tfl.add"(%[[exp]], %[[cst]]) {fused_activation_function = "NONE"} : (tensor<8x16xf32>, tensor) -> tensor<8x16xf32> +// CHECK: %[[log:.*]] = "tfl.log"(%[[add]]) : (tensor<8x16xf32>) -> tensor<8x16xf32> } func @fakeQuantArgsFalse(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> { @@ -255,20 +241,12 @@ func @zeros_like(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { // CHECK: "tfl.zeros_like"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32> } -func @divRelu(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { +func @div(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { %0 = "tf.Div"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> - %1 = "tf.Div"(%arg0, %0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> - %2 = "tf.Relu"(%1) : (tensor<1xf32>) -> tensor<1xf32> - %3 = "tf.Relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> - %4 = "tf.Div"(%3, %2) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> - %5 = "tf.Relu6"(%4) : (tensor<1xf32>) -> tensor<1xf32> - return %5: tensor<1xf32> + return %0: tensor<1xf32> -// 
CHECK-LABEL: divRelu +// CHECK-LABEL: div // CHECK: tfl.div %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xf32> -// CHECK: %1 = tfl.div %arg0, %0 {fused_activation_function = "RELU"} : tensor<1xf32> -// CHECK: %2 = "tfl.relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> -// CHECK: %3 = tfl.div %2, %1 {fused_activation_function = "RELU6"} : tensor<1xf32> // CHECK: return } @@ -698,8 +676,9 @@ func @matrix_diag_v2_no_match(%arg0: tensor<8x16xf32>) -> tensor<8x16x16xf32> { // CHECK-SAME: [[VAL_0:%.*]]: tensor<8x16xf32>) -> tensor<8x16x16xf32> { // CHECK: [[VAL_1:%.*]] = constant dense<1> : tensor<1xi32> // CHECK: [[VAL_2:%.*]] = constant dense<-1> : tensor<1xi32> +// CHECK: [[VAL_5:%.*]] = constant dense<-1> : tensor<1xi32> // CHECK: [[VAL_3:%.*]] = constant dense<0> : tensor<2xi32> -// CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV2"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_2]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32> +// CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV2"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_5]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32> // CHECK: return [[VAL_4]] : tensor<8x16x16xf32> } @@ -731,8 +710,9 @@ func @matrix_diag_v3_no_match(%arg0: tensor<8x16xf32>) -> tensor<8x16x16xf32> { // CHECK-SAME: [[VAL_0:%.*]]: tensor<8x16xf32>) -> tensor<8x16x16xf32> { // CHECK: [[VAL_1:%.*]] = constant dense<1> : tensor<1xi32> // CHECK: [[VAL_2:%.*]] = constant dense<-1> : tensor<1xi32> +// CHECK: [[VAL_5:%.*]] = constant dense<-1> : tensor<1xi32> // CHECK: [[VAL_3:%.*]] = constant dense<0> : tensor<2xi32> -// CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV3"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_2]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32> +// CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV3"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_5]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32> // CHECK: return [[VAL_4]] : tensor<8x16x16xf32> } @@ -1295,7 +1275,8 @@ func @conv2d_backprop_input(%arg0: tensor<4xi32>, %arg1: tensor<3x3x1x32xf32>, % // CHECK: %[[CST:.*]] = constant dense<[2, 0, 1, 3]> : tensor<4xi32> // CHECK: %[[ARG0:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32> // CHECK: %[[ARG1:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>) -> tensor<15x28x28x1xf32> - // CHECK: %[[ARG2:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32> + // CHECK: %[[CST_1:.*]] = constant dense<[2, 0, 1, 3]> : tensor<4xi32> + // CHECK: %[[ARG2:.*]] = "tfl.transpose"(%arg1, %[[CST_1]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32> // CHECK: %[[ARG3:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG2]], %arg2) {padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>) -> tensor<15x28x28x1xf32> // CHECK: %[[RESULT:.*]] = tfl.add %[[ARG1]], %[[ARG3]] {fused_activation_function = "NONE"} : tensor<15x28x28x1xf32> // CHECK: return %[[RESULT]] : tensor<15x28x28x1xf32> @@ -1340,8 +1321,8 @@ func @reciprocal_f16(%arg0: tensor<8xf16>) -> tensor<8xf16> { return %0: tensor<8xf16> // CHECK-LABEL: reciprocal_f16 -// CHECK: %cst = constant dense<1.000000e+00> 
: tensor<1xf16> -// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<1xf16>, tensor<8xf16>) -> tensor<8xf16> +// CHECK: %cst = constant dense<1.000000e+00> : tensor +// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor, tensor<8xf16>) -> tensor<8xf16> // CHECK: return } @@ -1350,8 +1331,8 @@ func @reciprocal_f32(%arg0: tensor<8xf32>) -> tensor<8xf32> { return %0: tensor<8xf32> // CHECK-LABEL: reciprocal_f32 -// CHECK: %cst = constant dense<1.000000e+00> : tensor<1xf32> -// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<1xf32>, tensor<8xf32>) -> tensor<8xf32> +// CHECK: %cst = constant dense<1.000000e+00> : tensor +// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor, tensor<8xf32>) -> tensor<8xf32> // CHECK: return } @@ -1360,8 +1341,8 @@ func @reciprocal_complex_f32(%arg0: tensor<8xcomplex>) -> tensor<8xcomplex< return %0: tensor<8xcomplex> // CHECK-LABEL: reciprocal_complex_f32 -// CHECK: %cst = constant opaque<"tf", "0x746674656E736F722464747970653A2044545F434F4D504C455836342074656E736F725F7368617065207B2064696D207B2073697A653A2031207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C3230303F5C3030305C3030305C3030305C30303022"> : tensor<1xcomplex> -// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<1xcomplex>, tensor<8xcomplex>) -> tensor<8xcomplex> +// CHECK: %cst = constant opaque<"tf", "0x746674656E736F722464747970653A2044545F434F4D504C455836342074656E736F725F7368617065207B2064696D207B2073697A653A2031207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C3230303F5C3030305C3030305C3030305C30303022"> : tensor> +// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor>, tensor<8xcomplex>) -> tensor<8xcomplex> // CHECK: return } @@ -1370,8 +1351,8 @@ func @reciprocal_i32(%arg0: tensor<8xi32>) -> tensor<8xi32> { return %0: tensor<8xi32> // CHECK-LABEL: reciprocal_i32 -// CHECK: %cst = constant dense<1> : tensor<1xi32> -// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<1xi32>, tensor<8xi32>) -> tensor<8xi32> +// CHECK: %cst = constant dense<1> : tensor +// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor, tensor<8xi32>) -> tensor<8xi32> // CHECK: return } @@ -1380,8 +1361,8 @@ func @reciprocal_i64(%arg0: tensor<8xi64>) -> tensor<8xi64> { return %0: tensor<8xi64> // CHECK-LABEL: reciprocal_i64 -// CHECK: %cst = constant dense<1> : tensor<1xi64> -// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<1xi64>, tensor<8xi64>) -> tensor<8xi64> +// CHECK: %cst = constant dense<1> : tensor +// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor, tensor<8xi64>) -> tensor<8xi64> // CHECK: return } @@ -1436,7 +1417,7 @@ func @LstmWithoutProjection(%arg: tensor<28x1x28xf32>) -> (tensor<28x1x16xf32>) // CHECK: [[VAL_3:%.*]] = constant dense<0.000000e+00> : tensor<16xf32> // CHECK: [[VAL_4:%.*]] = constant dense<0.000000e+00> : tensor<1x16xf32> // CHECK: [[VAL_5:%.*]] = constant unit -// CHECK: [[VAL_6:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_1]], [[VAL_1]], [[VAL_1]], [[VAL_1]], [[VAL_2]], [[VAL_2]], [[VAL_2]], [[VAL_2]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_5]], [[VAL_5]], [[VAL_4]], [[VAL_4]], [[VAL_5]], [[VAL_5]], [[VAL_5]], [[VAL_5]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} 
: (tensor<28x1x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, none, none, tensor<1x16xf32>, tensor<1x16xf32>, none, none, none, none) -> tensor<28x1x16xf32> +// CHECK: [[VAL_6:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_1]], [[VAL_1]], [[VAL_1]], [[VAL_1]], [[VAL_2]], [[VAL_2]], [[VAL_2]], [[VAL_2]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_5]], [[VAL_5]], [[VAL_4]], [[VAL_4]], [[VAL_5]], [[VAL_5]], [[VAL_5]], [[VAL_5]]) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<28x1x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, none, none, tensor<1x16xf32>, tensor<1x16xf32>, none, none, none, none) -> tensor<28x1x16xf32> // CHECK: return [[VAL_6]] : tensor<28x1x16xf32> // CHECK: } @@ -1461,7 +1442,7 @@ func @LstmWithProjection(%arg: tensor<28x1x16xf32>) -> (tensor<28x1x8xf32>) { // CHECK: [[VAL_12:%.*]] = constant dense<0.000000e+00> : tensor<8x16xf32> // CHECK: [[VAL_13:%.*]] = constant dense<0.000000e+00> : tensor<1x8xf32> // CHECK: [[VAL_14:%.*]] = constant unit -// CHECK: [[VAL_15:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_7]], [[VAL_8]], [[VAL_8]], [[VAL_8]], [[VAL_8]], [[VAL_9]], [[VAL_9]], [[VAL_9]], [[VAL_9]], [[VAL_14]], [[VAL_14]], [[VAL_14]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_12]], [[VAL_14]], [[VAL_13]], [[VAL_11]], [[VAL_14]], [[VAL_14]], [[VAL_14]], [[VAL_14]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<28x1x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, none, none, none, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<8x16xf32>, none, tensor<1x8xf32>, tensor<1x16xf32>, none, none, none, none) -> tensor<28x1x8xf32> +// CHECK: [[VAL_15:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_7]], [[VAL_8]], [[VAL_8]], [[VAL_8]], [[VAL_8]], [[VAL_9]], [[VAL_9]], [[VAL_9]], [[VAL_9]], [[VAL_14]], [[VAL_14]], [[VAL_14]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_12]], [[VAL_14]], [[VAL_13]], [[VAL_11]], [[VAL_14]], [[VAL_14]], [[VAL_14]], [[VAL_14]]) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<28x1x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, none, none, none, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<8x16xf32>, none, tensor<1x8xf32>, tensor<1x16xf32>, none, none, none, none) -> tensor<28x1x8xf32> // CHECK: return [[VAL_15]] : tensor<28x1x8xf32> // CHECK: } @@ -1480,3 +1461,25 @@ func @UnidirectionalRnn(%arg: tensor<28x1x28xf32>) -> (tensor<28x1x28xf32>) { // CHECK: [[VAL_4:%.*]] = "tfl.unidirectional_sequence_rnn"([[VAL_0]], [[VAL_1]], [[VAL_1]], [[VAL_2]], [[VAL_3]]) {fused_activation_function = "TANH", time_major = true} : 
(tensor<28x1x28xf32>, tensor<28x28xf32>, tensor<28x28xf32>, tensor<28xf32>, tensor<1x28xf32>) -> tensor<28x1x28xf32> // CHECK: return [[VAL_4]] : tensor<28x1x28xf32> // CHECK: } + +func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> { + %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32> + return %0: tensor<3x3xf32> + +// CHECK-LABEL: broadcast_to_f32 +// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor +// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor) -> tensor<3x3xf32> +// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> +// CHECK: return [[MUL]] : tensor<3x3xf32> +} + +func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> { + %0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32> + return %0: tensor<3x3xi32> + +// CHECK-LABEL: broadcast_to_i32 +// CHECK: [[CST:%.*]] = constant dense<1> : tensor +// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor) -> tensor<3x3xi32> +// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32> +// CHECK: return [[MUL]] : tensor<3x3xi32> +} diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2exec/tfl_while_op.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2exec/tfl_while_op.mlir index 39a93e1d03b..3addd8a9248 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2exec/tfl_while_op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2exec/tfl_while_op.mlir @@ -12,11 +12,9 @@ // CHECK: Tensor 0 pconst kTfLiteInt32 kTfLiteMmapRo 4 bytes // CHECK-NEXT: Tensor 1 N kTfLiteInt32 kTfLiteMmapRo 4 bytes // CHECK-NEXT: Tensor 2 val kTfLiteFloat32 kTfLiteMmapRo 4 bytes -// CHECK-NEXT: Tensor 3 std.constant kTfLiteInt32 kTfLiteMmapRo 4 bytes -// CHECK-NEXT: Tensor 4 tfl.while kTfLiteInt32 kTfLiteArenaRw 4 bytes -// CHECK-NEXT: Tensor 5 result kTfLiteFloat32 kTfLiteArenaRw 4 bytes -// CHECK-NEXT: Tensor 6 tfl.while:2 kTfLiteInt32 kTfLiteArenaRw 4 bytes -// CHECK-NEXT: Tensor 7 tfl.while:3 kTfLiteInt32 kTfLiteArenaRw 4 bytes +// CHECK-NEXT: Tensor 3 tfl.while kTfLiteInt32 kTfLiteArenaRw 4 bytes +// CHECK-NEXT: Tensor 4 result kTfLiteFloat32 kTfLiteArenaRw 4 bytes +// CHECK-NEXT: Tensor 5 tfl.while:2 kTfLiteInt32 kTfLiteArenaRw 4 bytes // Verify while was not folded away: // ------------------------------------ diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/basic_lstm.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/basic_lstm.mlir index 5ede7c05234..47e1ccee3c9 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/basic_lstm.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/basic_lstm.mlir @@ -108,6 +108,12 @@ func @main(tensor<1x384xf32>, tensor<1x96xf32>, tensor<384x480xf32>, tensor<384x // CHECK-EMPTY: // CHECK-NEXT: }, { // CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 49, 46, 49, 48, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: metadata: [ { +// CHECK-NEXT: name: "min_runtime_version", +// CHECK-NEXT: buffer: 10 // CHECK-NEXT: } ] // CHECK-NEXT:} diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/convolution_2d_transpose_bias.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/convolution_2d_transpose_bias.mlir index 8d4c93fccc0..9d134a3fcad 100644 ---
a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/convolution_2d_transpose_bias.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/convolution_2d_transpose_bias.mlir @@ -1,4 +1,4 @@ -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck %s +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s // RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --check-prefix=MLIR %s @@ -61,6 +61,12 @@ func @main(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>, %arg2: // CHECK-EMPTY: // CHECK-NEXT: }, { // CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: metadata: [ { +// CHECK-NEXT: name: "min_runtime_version", +// CHECK-NEXT: buffer: 5 // CHECK-NEXT: } ] // CHECK-NEXT:} diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/custom_op_with_tflite_op.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/custom_op_with_tflite_op.mlir index ec6b9e313f6..1b46fa3d0e5 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/custom_op_with_tflite_op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/custom_op_with_tflite_op.mlir @@ -1,4 +1,4 @@ -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck %s +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s func @main(tensor<4xf32>) -> tensor<4xf32> { ^bb0(%arg0: tensor<4xf32>): @@ -90,6 +90,12 @@ func @main(tensor<4xf32>) -> tensor<4xf32> { // CHECK-EMPTY: // CHECK-NEXT: }, { // CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 49, 46, 55, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: metadata: [ { +// CHECK-NEXT: name: "min_runtime_version", +// CHECK-NEXT: buffer: 6 // CHECK-NEXT: } ] // CHECK-NEXT:} diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/depthwise_conv2d.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/depthwise_conv2d.mlir index 10a62121485..ffa379124e6 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/depthwise_conv2d.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/depthwise_conv2d.mlir @@ -82,6 +82,12 @@ func @main(tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> { // CHECK-EMPTY: // CHECK-NEXT: }, { // CHECK-EMPTY: + // CHECK-NEXT: }, { + // CHECK-NEXT: data: [ 49, 46, 49, 51, 46, 49, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] + // CHECK-NEXT: } ], + // CHECK-NEXT: metadata: [ { + // CHECK-NEXT: name: "min_runtime_version", + // CHECK-NEXT: buffer: 6 // CHECK-NEXT: } ] // CHECK-NEXT:} diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/depthwise_conv2d_v2.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/depthwise_conv2d_v2.mlir index ce079ccccf7..627de564931 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/depthwise_conv2d_v2.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/depthwise_conv2d_v2.mlir @@ -84,6 +84,12 @@ func @main(tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> { // CHECK-EMPTY: // CHECK-NEXT: }, { // CHECK-EMPTY: + // CHECK-NEXT: }, { + // CHECK-NEXT: data: [ 49, 46, 49, 51, 46, 49, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] + // CHECK-NEXT: } ], + // 
CHECK-NEXT: metadata: [ { + // CHECK-NEXT: name: "min_runtime_version", + // CHECK-NEXT: buffer: 6 // CHECK-NEXT: } ] // CHECK-NEXT:} diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_flex_enable_builtin.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_flex_enable_builtin.mlir index 236fc605c9d..13f8b998fff 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_flex_enable_builtin.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_flex_enable_builtin.mlir @@ -1,4 +1,4 @@ -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s func @main(tensor<4xf32>) -> tensor<4xf32> { ^bb0(%arg0: tensor<4xf32>): @@ -88,6 +88,12 @@ func @main(tensor<4xf32>) -> tensor<4xf32> { // CHECK-EMPTY: // CHECK-NEXT: }, { // CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 49, 46, 55, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: metadata: [ { +// CHECK-NEXT: name: "min_runtime_version", +// CHECK-NEXT: buffer: 6 // CHECK-NEXT: } ] // CHECK-NEXT:} diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fake_quant.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fake_quant.mlir index 2505f73ee31..48994bf4617 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fake_quant.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fake_quant.mlir @@ -1,4 +1,4 @@ -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s // RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate -tflite-flatbuffer-to-mlir - -o - | FileCheck --check-prefix=IMPORT %s func @main(tensor<4xf32>) -> tensor<4xf32> { @@ -46,6 +46,12 @@ func @main(tensor<4xf32>) -> tensor<4xf32> { // CHECK-EMPTY: // CHECK-NEXT: }, { // CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 49, 46, 53, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: metadata: [ { +// CHECK-NEXT: name: "min_runtime_version", +// CHECK-NEXT: buffer: 3 // CHECK-NEXT: } ] // CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_exclusively.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_exclusively.mlir index c98fdeb514e..9c4524586a5 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_exclusively.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_exclusively.mlir @@ -1,4 +1,4 @@ -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-select-tf-ops=true -emit-builtin-tflite-ops=false -o - | flatbuffer_to_string - | FileCheck %s +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-select-tf-ops=true -emit-builtin-tflite-ops=false -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s func @main(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32> { // CHECK: { @@ -39,6 +39,12 @@ func @main(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32> { // CHECK-EMPTY: // CHECK-NEXT: }, { // CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: metadata: [ { +// CHECK-NEXT: name: "min_runtime_version", +// CHECK-NEXT: buffer: 3 // CHECK-NEXT: } ] // 
CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_tflite_op.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_tflite_op.mlir index 0bde1879b10..6f1bafcd7a9 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_tflite_op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_tflite_op.mlir @@ -1,4 +1,4 @@ -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-select-tf-ops -o - | flatbuffer_to_string - | FileCheck %s +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-select-tf-ops -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s func @main(tensor<4xf32>) -> tensor<4xf32> { ^bb0(%arg0: tensor<4xf32>): @@ -89,6 +89,12 @@ func @main(tensor<4xf32>) -> tensor<4xf32> { // CHECK-EMPTY: // CHECK-NEXT: }, { // CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 49, 46, 55, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: metadata: [ { +// CHECK-NEXT: name: "min_runtime_version", +// CHECK-NEXT: buffer: 6 // CHECK-NEXT: } ] // CHECK-NEXT:} diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fully_connected.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fully_connected.mlir index 85ad8f01dbe..2015d694e7f 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fully_connected.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fully_connected.mlir @@ -61,6 +61,12 @@ func @main(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> { // CHECK-EMPTY: // CHECK-NEXT: }, { // CHECK-EMPTY: + // CHECK-NEXT: }, { + // CHECK-NEXT: data: [ 49, 46, 53, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] + // CHECK-NEXT: } ], + // CHECK-NEXT: metadata: [ { + // CHECK-NEXT: name: "min_runtime_version", + // CHECK-NEXT: buffer: 5 // CHECK-NEXT: } ] // CHECK-NEXT:} diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fully_connected_v2.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fully_connected_v2.mlir index 6f7fc9c967d..44c757d2fa8 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fully_connected_v2.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fully_connected_v2.mlir @@ -61,6 +61,12 @@ func @main(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> { // CHECK-EMPTY: // CHECK-NEXT: }, { // CHECK-EMPTY: + // CHECK-NEXT: }, { + // CHECK-NEXT: data: [ 49, 46, 49, 48, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] + // CHECK-NEXT: } ], + // CHECK-NEXT: metadata: [ { + // CHECK-NEXT: name: "min_runtime_version", + // CHECK-NEXT: buffer: 5 // CHECK-NEXT: } ] // CHECK-NEXT:} diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/if_op.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/if_op.mlir index 8ad0e1b0278..e325262eaa4 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/if_op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/if_op.mlir @@ -156,6 +156,12 @@ // CHECK-EMPTY: // CHECK-NEXT: }, { // CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 49, 46, 49, 52, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: metadata: [ { +// CHECK-NEXT: name: "min_runtime_version", +// CHECK-NEXT: buffer: 11 // CHECK-NEXT: } ] // CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/logical.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/logical.mlir index fd3f37eec73..4cedc6a218e 100644 --- 
a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/logical.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/logical.mlir @@ -1,4 +1,4 @@ -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s func @main(tensor<4xi1>) -> tensor<4xi1> { ^bb0(%arg0: tensor<4xi1>): @@ -78,6 +78,12 @@ func @main(tensor<4xi1>) -> tensor<4xi1> { // CHECK-EMPTY: // CHECK-NEXT: }, { // CHECK-EMPTY: + // CHECK-NEXT: }, { + // CHECK-NEXT: data: [ 49, 46, 49, 49, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] + // CHECK-NEXT: } ], + // CHECK-NEXT: metadata: [ { + // CHECK-NEXT: name: "min_runtime_version", + // CHECK-NEXT: buffer: 6 // CHECK-NEXT: } ] // CHECK-NEXT: } // CHECK-EMPTY: diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir index ed3c8a6f702..2ddb78dd4e5 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir @@ -1,4 +1,4 @@ -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) -> tensor<4 x f32> { // CHECK: { @@ -192,7 +192,8 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, t // CHECK-NEXT: builtin_options_type: LSTMOptions, // CHECK-NEXT: builtin_options: { // CHECK-EMPTY: -// CHECK-NEXT: } +// CHECK-NEXT: }, +// CHECK-NEXT: intermediates: [ ] // CHECK-NEXT: } ], // CHECK-NEXT: name: "main" // CHECK-NEXT: } ], @@ -249,6 +250,12 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, t // CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] // CHECK-NEXT: }, { // CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 49, 46, 55, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: metadata: [ { +// CHECK-NEXT: name: "min_runtime_version", +// CHECK-NEXT: buffer: 26 // CHECK-NEXT: } ] // CHECK-NEXT: } // CHECK-EMPTY: diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm_quantized.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm_quantized.mlir new file mode 100644 index 00000000000..6ae8ec8f3c7 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm_quantized.mlir @@ -0,0 +1,323 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s + +func @main(%arg0: tensor<1x528x!quant.uniform>, %arg1: tensor<2048x528x!quant.uniform:f32, 0.059801999479532242>>, %arg2: tensor<2048x528x!quant.uniform:f32, 0.031925998628139496>>, %arg3: tensor<2048x528x!quant.uniform:f32, 0.056272000074386597>>, %arg4: tensor<2048x528x!quant.uniform:f32, 0.063763998448848724>>, %arg5: tensor<2048x640x!quant.uniform:f32, 0.013358999975025654>>, %arg6: tensor<2048x640x!quant.uniform:f32, 
0.022830000147223473>>, %arg7: tensor<2048x640x!quant.uniform:f32, 0.032276000827550888>>, %arg8: tensor<2048x640x!quant.uniform:f32, 0.035427000373601913>>, %arg9: tensor<2048x!quant.uniform>, %arg10: tensor<2048x!quant.uniform>, %arg11: tensor<2048x!quant.uniform>, %arg12: tensor<2048x!quant.uniform>, %arg13: tensor<640x2048x!quant.uniform:f32, 0.021174000576138496>>, %arg14: tensor<640x!quant.uniform>, %arg15: tensor<2048x!quant.uniform>, %arg16: tensor<2048x!quant.uniform>, %arg17: tensor<2048x!quant.uniform>, %arg18: tensor<2048x!quant.uniform>, %arg19: tensor<1x640x!quant.uniform>, %arg20: tensor<1x2048x!quant.uniform>) -> tensor<1x640x!quant.uniform> { + %cst = constant unit + %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ({}) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", input_to_input_intermediate = tensor<0x!quant.uniform>, input_to_forget_intermediate = tensor<0x!quant.uniform>, input_to_cell_intermediate = tensor<0x!quant.uniform>, input_to_output_intermediate = tensor<0x!quant.uniform>, effective_hidden_scale_intermediate = tensor<0x!quant.uniform:f32, 0.0075630000792443752:2>>, kernel_type = "FULL", proj_clip = 0.01 : f32} : (tensor<1x528x!quant.uniform>, tensor<2048x528x!quant.uniform:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<640x2048x!quant.uniform:f32, 0.021174000576138496>>, tensor<640x!quant.uniform>, tensor<1x640x!quant.uniform>, tensor<1x2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>) -> tensor<1x640x!quant.uniform> + return %0 : tensor<1x640x!quant.uniform> +// CHECK: { +// CHECK-NEXT: version: 3, +// CHECK-NEXT: operator_codes: [ { +// CHECK-NEXT: builtin_code: LSTM, +// CHECK-NEXT: version: 1 +// CHECK-NEXT: } ], +// CHECK-NEXT: subgraphs: [ { +// CHECK-NEXT: tensors: [ { +// CHECK-NEXT: shape: [ 1, 528 ], +// CHECK-NEXT: type: INT8, +// CHECK-NEXT: buffer: 1, +// CHECK-NEXT: name: "arg0", +// CHECK-NEXT: quantization: { +// CHECK-NEXT: scale: [ 0.037248 ], +// CHECK-NEXT: zero_point: [ -19 ] +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 2048, 528 ], +// CHECK-NEXT: type: INT8, +// CHECK-NEXT: buffer: 2, +// CHECK-NEXT: name: "arg1", +// CHECK-NEXT: quantization: { +// CHECK-NEXT: scale: [ 0.059802 ], +// CHECK-NEXT: zero_point: [ 0 ] +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 2048, 528 ], +// CHECK-NEXT: type: INT8, +// CHECK-NEXT: buffer: 3, +// CHECK-NEXT: name: "arg2", +// CHECK-NEXT: quantization: { +// CHECK-NEXT: scale: [ 0.031926 ], +// CHECK-NEXT: zero_point: [ 0 ] +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 2048, 528 ], +// CHECK-NEXT: type: INT8, +// CHECK-NEXT: buffer: 4, +// CHECK-NEXT: name: "arg3", +// CHECK-NEXT: quantization: { +// CHECK-NEXT: scale: [ 0.056272 ], +// CHECK-NEXT: zero_point: [ 0 ] 
+// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 2048, 528 ], +// CHECK-NEXT: type: INT8, +// CHECK-NEXT: buffer: 5, +// CHECK-NEXT: name: "arg4", +// CHECK-NEXT: quantization: { +// CHECK-NEXT: scale: [ 0.063764 ], +// CHECK-NEXT: zero_point: [ 0 ] +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 2048, 640 ], +// CHECK-NEXT: type: INT8, +// CHECK-NEXT: buffer: 6, +// CHECK-NEXT: name: "arg5", +// CHECK-NEXT: quantization: { +// CHECK-NEXT: scale: [ 0.013359 ], +// CHECK-NEXT: zero_point: [ 0 ] +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 2048, 640 ], +// CHECK-NEXT: type: INT8, +// CHECK-NEXT: buffer: 7, +// CHECK-NEXT: name: "arg6", +// CHECK-NEXT: quantization: { +// CHECK-NEXT: scale: [ 0.02283 ], +// CHECK-NEXT: zero_point: [ 0 ] +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 2048, 640 ], +// CHECK-NEXT: type: INT8, +// CHECK-NEXT: buffer: 8, +// CHECK-NEXT: name: "arg7", +// CHECK-NEXT: quantization: { +// CHECK-NEXT: scale: [ 0.032276 ], +// CHECK-NEXT: zero_point: [ 0 ] +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 2048, 640 ], +// CHECK-NEXT: type: INT8, +// CHECK-NEXT: buffer: 9, +// CHECK-NEXT: name: "arg8", +// CHECK-NEXT: quantization: { +// CHECK-NEXT: scale: [ 0.035427 ], +// CHECK-NEXT: zero_point: [ 0 ] +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 2048 ], +// CHECK-NEXT: type: INT32, +// CHECK-NEXT: buffer: 10, +// CHECK-NEXT: name: "arg9", +// CHECK-NEXT: quantization: { +// CHECK-NEXT: scale: [ 0.0 ], +// CHECK-NEXT: zero_point: [ 0 ] +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 2048 ], +// CHECK-NEXT: type: INT32, +// CHECK-NEXT: buffer: 11, +// CHECK-NEXT: name: "arg10", +// CHECK-NEXT: quantization: { +// CHECK-NEXT: scale: [ 0.0 ], +// CHECK-NEXT: zero_point: [ 0 ] +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 2048 ], +// CHECK-NEXT: type: INT32, +// CHECK-NEXT: buffer: 12, +// CHECK-NEXT: name: "arg11", +// CHECK-NEXT: quantization: { +// CHECK-NEXT: scale: [ 0.0 ], +// CHECK-NEXT: zero_point: [ 0 ] +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 2048 ], +// CHECK-NEXT: type: INT32, +// CHECK-NEXT: buffer: 13, +// CHECK-NEXT: name: "arg12", +// CHECK-NEXT: quantization: { +// CHECK-NEXT: scale: [ 0.0 ], +// CHECK-NEXT: zero_point: [ 0 ] +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 640, 2048 ], +// CHECK-NEXT: type: INT8, +// CHECK-NEXT: buffer: 14, +// CHECK-NEXT: name: "arg13", +// CHECK-NEXT: quantization: { +// CHECK-NEXT: scale: [ 0.021174 ], +// CHECK-NEXT: zero_point: [ 0 ] +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 640 ], +// CHECK-NEXT: type: INT32, +// CHECK-NEXT: buffer: 15, +// CHECK-NEXT: name: "arg14", +// CHECK-NEXT: quantization: { +// CHECK-NEXT: scale: [ 0.00016 ], +// CHECK-NEXT: zero_point: [ 0 ] +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 2048 ], +// CHECK-NEXT: type: INT16, +// CHECK-NEXT: buffer: 16, +// CHECK-NEXT: name: "arg15", +// CHECK-NEXT: quantization: { +// CHECK-NEXT: scale: [ 0.000437 ], +// CHECK-NEXT: zero_point: [ 0 ] +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 2048 ], +// CHECK-NEXT: type: INT16, +// CHECK-NEXT: buffer: 17, +// CHECK-NEXT: name: "arg16", +// CHECK-NEXT: quantization: { +// CHECK-NEXT: scale: [ 0.00011 ], +// CHECK-NEXT: zero_point: [ 0 ] +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 2048 ], +// CHECK-NEXT: type: INT16, +// CHECK-NEXT: buffer: 18, +// CHECK-NEXT: 
name: "arg17", +// CHECK-NEXT: quantization: { +// CHECK-NEXT: scale: [ 0.000168 ], +// CHECK-NEXT: zero_point: [ 0 ] +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 2048 ], +// CHECK-NEXT: type: INT16, +// CHECK-NEXT: buffer: 19, +// CHECK-NEXT: name: "arg18", +// CHECK-NEXT: quantization: { +// CHECK-NEXT: scale: [ 0.000156 ], +// CHECK-NEXT: zero_point: [ 0 ] +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 1, 640 ], +// CHECK-NEXT: type: INT8, +// CHECK-NEXT: name: "arg19", +// CHECK-NEXT: quantization: { +// CHECK-NEXT: scale: [ 0.096711 ], +// CHECK-NEXT: zero_point: [ 10 ] +// CHECK-NEXT: }, +// CHECK-NEXT: is_variable: true +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 1, 2048 ], +// CHECK-NEXT: type: INT16, +// CHECK-NEXT: name: "arg20", +// CHECK-NEXT: quantization: { +// CHECK-NEXT: scale: [ 0.000488 ], +// CHECK-NEXT: zero_point: [ 0 ] +// CHECK-NEXT: }, +// CHECK-NEXT: is_variable: true +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 0 ], +// CHECK-NEXT: type: INT16, +// CHECK-NEXT: name: "input_to_input_intermediate", +// CHECK-NEXT: quantization: { +// CHECK-NEXT: scale: [ 0.004989 ], +// CHECK-NEXT: zero_point: [ 0 ] +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 0 ], +// CHECK-NEXT: type: INT16, +// CHECK-NEXT: name: "input_to_forget_intermediate", +// CHECK-NEXT: quantization: { +// CHECK-NEXT: scale: [ 0.007885 ], +// CHECK-NEXT: zero_point: [ 0 ] +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 0 ], +// CHECK-NEXT: type: INT16, +// CHECK-NEXT: name: "input_to_cell_intermediate", +// CHECK-NEXT: quantization: { +// CHECK-NEXT: scale: [ 0.008763 ], +// CHECK-NEXT: zero_point: [ 0 ] +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 0 ], +// CHECK-NEXT: type: INT16, +// CHECK-NEXT: name: "input_to_output_intermediate", +// CHECK-NEXT: quantization: { +// CHECK-NEXT: scale: [ 0.005753 ], +// CHECK-NEXT: zero_point: [ 0 ] +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 0 ], +// CHECK-NEXT: type: INT8, +// CHECK-NEXT: name: "effective_hidden_scale_intermediate", +// CHECK-NEXT: quantization: { +// CHECK-NEXT: scale: [ 0.007563 ], +// CHECK-NEXT: zero_point: [ 2 ] +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 1, 640 ], +// CHECK-NEXT: type: INT8, +// CHECK-NEXT: buffer: 22, +// CHECK-NEXT: name: "tfl.lstm", +// CHECK-NEXT: quantization: { +// CHECK-NEXT: scale: [ 0.096711 ], +// CHECK-NEXT: zero_point: [ 10 ] +// CHECK-NEXT: } +// CHECK-NEXT: } ], +// CHECK-NEXT: inputs: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 ], +// CHECK-NEXT: outputs: [ 26 ], +// CHECK-NEXT: operators: [ { +// CHECK-NEXT: inputs: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, -1, -1, -1, 9, 10, 11, 12, 13, 14, 19, 20, 15, 16, 17, 18 ], +// CHECK-NEXT: outputs: [ 26 ], +// CHECK-NEXT: builtin_options_type: LSTMOptions, +// CHECK-NEXT: builtin_options: { +// CHECK-NEXT: fused_activation_function: TANH, +// CHECK-NEXT: cell_clip: 10.0, +// CHECK-NEXT: proj_clip: 0.01 +// CHECK-NEXT: }, +// CHECK-NEXT: intermediates: [ 21, 22, 23, 24, 25 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: name: "main" +// CHECK-NEXT: } ], +// CHECK-NEXT: description: "MLIR Converted.", +// CHECK-NEXT: buffers: [ { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// 
CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 49, 46, 55, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: metadata: [ { +// CHECK-NEXT: name: "min_runtime_version", +// CHECK-NEXT: buffer: 23 +// CHECK-NEXT: } ] +// CHECK-NEXT: } +} diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/math.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/math.mlir index d39a8353c6f..6c9dd515ca8 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/math.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/math.mlir @@ -128,6 +128,12 @@ func @main(tensor<4xf32>) -> tensor<4xf32> { // CHECK-EMPTY: // CHECK-NEXT: }, { // CHECK-EMPTY: + // CHECK-NEXT: }, { + // CHECK-NEXT: data: [ 49, 46, 49, 51, 46, 49, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] + // CHECK-NEXT: } ], + // CHECK-NEXT: metadata: [ { + // CHECK-NEXT: name: "min_runtime_version", + // CHECK-NEXT: buffer: 8 // CHECK-NEXT: } ] // CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/max_pooling_with_arg_max_2d.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/max_pooling_with_arg_max_2d.mlir index 47935358512..fc7ef307bae 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/max_pooling_with_arg_max_2d.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/max_pooling_with_arg_max_2d.mlir @@ -1,4 +1,4 @@ -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck %s +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s // RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --check-prefix=MLIR %s func @main(%arg0: tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) { @@ -50,6 +50,12 @@ func @main(%arg0: tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x3 // CHECK-EMPTY: // CHECK-NEXT: }, { // CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: metadata: [ { +// CHECK-NEXT: name: "min_runtime_version", +// CHECK-NEXT: buffer: 4 // CHECK-NEXT: } ] // CHECK-NEXT:} diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/max_unpool_2d.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/max_unpool_2d.mlir index be2cc62e156..0dc6f7ea165 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/max_unpool_2d.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/max_unpool_2d.mlir @@ -1,4 +1,4 @@ -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck %s +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s // RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | 
flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --check-prefix=MLIR %s func @main(%arg0: tensor<1x8x8x128xf32>, %arg1: tensor<1x8x8x128xf32>) -> tensor<1x8x8x128xf32> { @@ -50,6 +50,12 @@ func @main(%arg0: tensor<1x8x8x128xf32>, %arg1: tensor<1x8x8x128xf32>) -> tensor // CHECK-EMPTY: // CHECK-NEXT: }, { // CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: metadata: [ { +// CHECK-NEXT: name: "min_runtime_version", +// CHECK-NEXT: buffer: 4 // CHECK-NEXT: } ] // CHECK-NEXT:} diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/metadata.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/metadata.mlir index 560d849ece3..8d2f63a8f15 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/metadata.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/metadata.mlir @@ -20,6 +20,8 @@ module attributes { // CHECK-NEXT: data: [ 118, 97, 108, 117, 101, 49 ] // CHECK-NEXT: }, { // CHECK-NEXT: data: [ 118, 97, 108, 117, 101, 50 ] +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 49, 46, 54, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] // CHECK-NEXT: } ], // CHECK-NEXT: metadata: [ { // CHECK-NEXT: name: "key1", @@ -27,4 +29,8 @@ module attributes { // CHECK-NEXT: }, { // CHECK-NEXT: name: "key2", // CHECK-NEXT: buffer: 5 +// CHECK-NEXT: }, { +// CHECK-NEXT: name: "min_runtime_version", +// CHECK-NEXT: buffer: 6 // CHECK-NEXT: } ] +// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/mul_v2.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/mul_v2.mlir index 2f77163e7a9..3879fc3f1aa 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/mul_v2.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/mul_v2.mlir @@ -58,6 +58,12 @@ func @main(tensor<3x!quant.uniform>) -> tensor<3x!quant.uniform>) -> tensor<3x!quant.uniform) -> tensor<1x1x1x16xf32> { ^bb0(%arg0: tensor<1x6x6x16xf32>): @@ -47,6 +47,12 @@ func @main(tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> { // CHECK-EMPTY: // CHECK-NEXT: }, { // CHECK-EMPTY: + // CHECK-NEXT: }, { + // CHECK-NEXT: data: [ 49, 46, 53, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] + // CHECK-NEXT: } ], + // CHECK-NEXT: metadata: [ { + // CHECK-NEXT: name: "min_runtime_version", + // CHECK-NEXT: buffer: 3 // CHECK-NEXT: } ] // CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/numeric_verify.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/numeric_verify.mlir index 8b2f6ea8b0e..f7830acabf7 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/numeric_verify.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/numeric_verify.mlir @@ -1,4 +1,4 @@ -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s // CHECK: { // CHECK-NEXT: version: 3, @@ -40,6 +40,12 @@ // CHECK-EMPTY: // CHECK-NEXT: }, { // CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: metadata: [ { +// CHECK-NEXT: name: "min_runtime_version", +// CHECK-NEXT: buffer: 3 // CHECK-NEXT: } ] // CHECK-NEXT:} diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/quantization.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/quantization.mlir index 7d9b113de65..c50857fa2ea 100644 --- 
a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/quantization.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/quantization.mlir @@ -1,4 +1,4 @@ -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> { // CHECK: { @@ -153,6 +153,12 @@ func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> { // CHECK-EMPTY: // CHECK-NEXT: }, { // CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 49, 46, 49, 51, 46, 49, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: metadata: [ { +// CHECK-NEXT: name: "min_runtime_version", +// CHECK-NEXT: buffer: 10 // CHECK-NEXT: } ] // CHECK-NEXT:} diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/reshape.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/reshape.mlir index c019cf12f05..6ef628229a4 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/reshape.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/reshape.mlir @@ -1,4 +1,4 @@ -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s func @main(tensor<3x2xi32>) -> tensor<6xi32> { ^bb0(%arg0: tensor<3x2xi32>): @@ -51,6 +51,12 @@ func @main(tensor<3x2xi32>) -> tensor<6xi32> { // CHECK-NEXT: data: [ 6, 0, 0, 0 ] // CHECK-NEXT: }, { // CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 49, 46, 53, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: metadata: [ { +// CHECK-NEXT: name: "min_runtime_version", +// CHECK-NEXT: buffer: 4 // CHECK-NEXT: } ] // CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/simple.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/simple.mlir index ee731c383f3..148039a1b41 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/simple.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/simple.mlir @@ -97,6 +97,12 @@ func @main(tensor<3x2xi32>) -> tensor<3x2xi32> // CHECK-NEXT: data: [ 10, 0, 0, 0 ] // CHECK-NEXT: }, { // CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 49, 46, 54, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: metadata: [ { +// CHECK-NEXT: name: "min_runtime_version", +// CHECK-NEXT: buffer: 6 // CHECK-NEXT: } ] // CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/svdf.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/svdf.mlir index db9668b8e30..559f3745149 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/svdf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/svdf.mlir @@ -1,4 +1,4 @@ -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) -> tensor<4 x f32> { // CHECK: { @@ -79,6 +79,12 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) - // CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] // CHECK-NEXT: }, { // CHECK-EMPTY: +// CHECK-NEXT: }, { +// 
CHECK-NEXT: data: [ 49, 46, 53, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: metadata: [ { +// CHECK-NEXT: name: "min_runtime_version", +// CHECK-NEXT: buffer: 7 // CHECK-NEXT: } ] // CHECK-NEXT: } // CHECK-EMPTY: diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/svdf_v2.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/svdf_v2.mlir index 8967822e234..ebfd1807280 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/svdf_v2.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/svdf_v2.mlir @@ -1,4 +1,4 @@ -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s func @main(tensor<4 x f32>, tensor<4 x i8>, tensor<4 x f32>, tensor<4 x f32>) -> tensor<4 x f32> { // CHECK: { @@ -80,6 +80,12 @@ func @main(tensor<4 x f32>, tensor<4 x i8>, tensor<4 x f32>, tensor<4 x f32>) -> // CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] // CHECK-NEXT: }, { // CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 49, 46, 49, 52, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: metadata: [ { +// CHECK-NEXT: name: "min_runtime_version", +// CHECK-NEXT: buffer: 7 // CHECK-NEXT: } ] // CHECK-NEXT: } // CHECK-EMPTY: diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/tf_entry_function.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/tf_entry_function.mlir new file mode 100644 index 00000000000..f1dc92678ed --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/tf_entry_function.mlir @@ -0,0 +1,56 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s + +module { +func @serving_default(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32> attributes {tf.entry_function = {inputs = "serving_default_x", outputs = "outputs"}} { +// CHECK: { + +// CHECK-LABEL: version: 3, + +// CHECK-LABEL: operator_codes: [ { +// CHECK: version: 1 +// CHECK: } ], + +// CHECK-LABEL: subgraphs: [ { +// CHECK: tensors: [ { +// CHECK: shape: [ 3, 2 ], +// CHECK: buffer: 1, +// CHECK: name: "serving_default_x", +// CHECK: quantization: { +// CHECK: } +// CHECK: }, { +// CHECK: shape: [ 3, 2 ], +// CHECK: buffer: 2, +// CHECK: name: "tfl.pseudo_const", +// CHECK: quantization: { +// CHECK: } +// CHECK: }, { +// CHECK: shape: [ 3, 2 ], +// CHECK: buffer: 3, +// CHECK: name: "outputs", +// CHECK: quantization: { +// CHECK: } +// CHECK: } ], +// CHECK: inputs: [ 0 ], +// CHECK: outputs: [ 2 ], +// CHECK: operators: [ { +// CHECK: inputs: [ 1, 0 ], +// CHECK: outputs: [ 2 ], +// CHECK: builtin_options_type: AddOptions, +// CHECK: builtin_options: { +// CHECK: } +// CHECK: } ], +// CHECK: name: "main" +// CHECK: } ], +// CHECK-LABEL: description: "MLIR Converted.", +// CHECK-LABEL: buffers: [ { +// CHECK: }, { +// CHECK: }, { +// CHECK: data: [ 0, 0, 128, 63, 0, 0, 0, 64, 0, 0, 64, 64, 0, 0, 128, 64, 0, 0, 160, 64, 0, 0, 192, 64 ] +// CHECK: }, { +// CHECK: } ] +// CHECK: } + %0 = "tfl.pseudo_const" () {value = dense<[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]> : tensor<3x2xf32>} : () -> tensor<3x2xf32> + %1 = "tfl.add" (%0, %arg0) {fused_activation_function = "NONE"} : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32> + return %1 : tensor<3x2xf32> +} +} diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/tfl_while_op.mlir 
b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/tfl_while_op.mlir index 3ed6e43479e..bb9278c0d87 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/tfl_while_op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/tfl_while_op.mlir @@ -189,6 +189,12 @@ // CHECK-EMPTY: // CHECK-NEXT: }, { // CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 49, 46, 49, 52, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: metadata: [ { +// CHECK-NEXT: name: "min_runtime_version", +// CHECK-NEXT: buffer: 14 // CHECK-NEXT: } ] // CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_lstm.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_lstm.mlir index 019c96cab6c..8e579421b0b 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_lstm.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_lstm.mlir @@ -1,4 +1,4 @@ -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) -> tensor<4 x f32> { // CHECK: { @@ -249,6 +249,12 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, t // CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] // CHECK-NEXT: }, { // CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 49, 46, 49, 51, 46, 49, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: metadata: [ { +// CHECK-NEXT: name: "min_runtime_version", +// CHECK-NEXT: buffer: 26 // CHECK-NEXT: } ] // CHECK-NEXT: } // CHECK-EMPTY: diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_rnn.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_rnn.mlir index 88e31b2cf78..7ba24bd5c51 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_rnn.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_rnn.mlir @@ -1,4 +1,4 @@ -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) -> tensor<4 x f32> { // CHECK: { @@ -79,6 +79,12 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) - // CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] // CHECK-NEXT: }, { // CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 49, 46, 49, 52, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: metadata: [ { +// CHECK-NEXT: name: "min_runtime_version", +// CHECK-NEXT: buffer: 7 // CHECK-NEXT: } ] // CHECK-NEXT: } // CHECK-EMPTY: diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/while_op.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/while_op.mlir 
index 58f19b66370..b40c9fb2044 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/while_op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/while_op.mlir @@ -1,4 +1,4 @@ -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s // CHECK: { // CHECK-NEXT: version: 3, @@ -189,6 +189,12 @@ // CHECK-EMPTY: // CHECK-NEXT: }, { // CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 49, 46, 49, 52, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: metadata: [ { +// CHECK-NEXT: name: "min_runtime_version", +// CHECK-NEXT: buffer: 14 // CHECK-NEXT: } ] // CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir index 57e2340dd37..995f20c4a07 100644 --- a/tensorflow/compiler/mlir/lite/tests/ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir @@ -103,7 +103,7 @@ func @testAddN(tensor, tensor, tensor) -> tensor, tensor, tensor) -> tensor { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor): - // expected-error @+1 {{'tfl.add_n' op operand #0 must be tensor of 32-bit float or 32-bit integer or QI16 type or QUI16 type values}} + // expected-error @+1 {{'tfl.add_n' op operand #0 must be tensor of 32-bit float or 32-bit signless integer or QI16 type or QUI16 type values}} %0 = "tfl.add_n"(%arg0, %arg1, %arg2): (tensor, tensor, tensor) -> tensor return %0 : tensor } @@ -244,7 +244,7 @@ func @testLogicalNot(tensor) -> tensor { func @testLogicalNotWrongOperandType(tensor) -> tensor { ^bb0(%arg0: tensor): - // expected-error @+1 {{'tfl.logical_not' op operand #0 must be tensor of 1-bit integer values}} + // expected-error @+1 {{'tfl.logical_not' op operand #0 must be tensor of 1-bit signless integer values}} %0 = "tfl.logical_not"(%arg0) : (tensor) -> tensor return %0 : tensor } @@ -380,7 +380,7 @@ func @testLogicalAnd(tensor, tensor) -> tensor { func @testLogicalAndWrongOperandType(tensor, tensor) -> tensor { ^bb0(%arg0: tensor, %arg1: tensor): - // expected-error @+1 {{'tfl.logical_and' op operand #0 must be tensor of 1-bit integer values}} + // expected-error @+1 {{'tfl.logical_and' op operand #0 must be tensor of 1-bit signless integer values}} %0 = "tfl.logical_and"(%arg0, %arg1) : (tensor, tensor) -> tensor return %0 : tensor } @@ -399,7 +399,7 @@ func @testLogicalOr(tensor, tensor) -> tensor { func @testLogicalOrWrongOperandType(tensor, tensor) -> tensor { ^bb0(%arg0: tensor, %arg1: tensor): - // expected-error @+1 {{'tfl.logical_or' op operand #0 must be tensor of 1-bit integer values}} + // expected-error @+1 {{'tfl.logical_or' op operand #0 must be tensor of 1-bit signless integer values}} %0 = "tfl.logical_or"(%arg0, %arg1) : (tensor, tensor) -> tensor return %0 : tensor } @@ -593,6 +593,28 @@ func @testUnidirectionalSequenceLstmWithInvalidNoneType(%arg0: tensor, return %0 : tensor } +// ----- +// CHECK-LABEL: testLstmIntermediates + + +func @testLstmIntermediates(%arg0: tensor<1x528x!quant.uniform>, %arg1: tensor<2048x528x!quant.uniform:f32, 0.059801999479532242>>, %arg2: tensor<2048x528x!quant.uniform:f32, 0.031925998628139496>>, %arg3: tensor<2048x528x!quant.uniform:f32, 0.056272000074386597>>, %arg4: tensor<2048x528x!quant.uniform:f32, 0.063763998448848724>>, %arg5: tensor<2048x640x!quant.uniform:f32, 0.013358999975025654>>, %arg6: tensor<2048x640x!quant.uniform:f32, 0.022830000147223473>>, 
%arg7: tensor<2048x640x!quant.uniform:f32, 0.032276000827550888>>, %arg8: tensor<2048x640x!quant.uniform:f32, 0.035427000373601913>>, %arg9: tensor<2048x!quant.uniform>, %arg10: tensor<2048x!quant.uniform>, %arg11: tensor<2048x!quant.uniform>, %arg12: tensor<2048x!quant.uniform>, %arg13: tensor<640x2048x!quant.uniform:f32, 0.021174000576138496>>, %arg14: tensor<640x!quant.uniform>, %arg15: tensor<2048x!quant.uniform>, %arg16: tensor<2048x!quant.uniform>, %arg17: tensor<2048x!quant.uniform>, %arg18: tensor<2048x!quant.uniform>, %arg19: tensor<1x640x!quant.uniform>, %arg20: tensor<1x2048x!quant.uniform>) -> tensor<1x640x!quant.uniform> { + %cst = constant unit + %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ({}) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", input_to_input_intermediate = tensor<0x!quant.uniform>, input_to_forget_intermediate = tensor<0x!quant.uniform>, input_to_cell_intermediate = tensor<0x!quant.uniform>, input_to_output_intermediate = tensor<0x!quant.uniform>, effective_hidden_scale_intermediate = tensor<0x!quant.uniform:f32, 0.0075630000792443752:2>>, kernel_type = "FULL", proj_clip = 0.01 : f32} : (tensor<1x528x!quant.uniform>, tensor<2048x528x!quant.uniform:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<640x2048x!quant.uniform:f32, 0.021174000576138496>>, tensor<640x!quant.uniform>, tensor<1x640x!quant.uniform>, tensor<1x2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>) -> tensor<1x640x!quant.uniform> + return %0 : tensor<1x640x!quant.uniform> +// CHECK: %[[RES0:.*]] = constant unit +// CHECK: %[[RES1:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %[[RES0]], %[[RES0]], %[[RES0]], %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ( { +// CHECK: }) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform:f32, 0.0075630000792443752:2>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0x!quant.uniform>, input_to_forget_intermediate = tensor<0x!quant.uniform>, input_to_input_intermediate = tensor<0x!quant.uniform>, input_to_output_intermediate = tensor<0x!quant.uniform>, kernel_type = "FULL", proj_clip = 0.00999999977 : f32} : (tensor<1x528x!quant.uniform>, tensor<2048x528x!quant.uniform:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform:f32, 0.035427000373601913>>, none, none, none, 
tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<640x2048x!quant.uniform:f32, 0.021174000576138496>>, tensor<640x!quant.uniform>, tensor<1x640x!quant.uniform>, tensor<1x2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>) -> tensor<1x640x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: testBidirectionalSequenceLstm +func @testBidirectionalSequenceLstm(%arg0: tensor, %arg1: none, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor, %arg16: tensor, %arg17: tensor, %arg18: tensor, %arg19: tensor, %arg20: tensor, %arg21: tensor, %arg22: tensor, %arg23: tensor, %arg24: tensor, %arg25: tensor, %arg26: tensor, %arg27: tensor, %arg28: tensor, %arg29: tensor, %arg30: tensor, %arg31: tensor, %arg32: tensor, %arg33: tensor, %arg34: tensor, %arg35: tensor, %arg36: tensor, %arg37: tensor, %arg38: tensor, %arg39: tensor, %arg40: tensor, %arg41: tensor, %arg42: tensor, %arg43: tensor, %arg44: tensor, %arg45: tensor, %arg46: tensor, %arg47: tensor) -> tensor { + // CHECK: "tfl.bidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23, %arg24, %arg25, %arg26, %arg27, %arg28, %arg29, %arg30, %arg31, %arg32, %arg33, %arg34, %arg35, %arg36, %arg37, %arg38, %arg39, %arg40, %arg41, %arg42, %arg43, %arg44, %arg45, %arg46, %arg47) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", merge_outputs = true, time_major = false} : (tensor, none, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor) + %0:2 = "tfl.bidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23, %arg24, %arg25, %arg26, %arg27, %arg28, %arg29, %arg30, %arg31, %arg32, %arg33, %arg34, %arg35, %arg36, %arg37, %arg38, %arg39, %arg40, %arg41, %arg42, %arg43, %arg44, %arg45, %arg46, %arg47) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", merge_outputs = true, time_major = false} : (tensor, none, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor) + return %0#0 : tensor +} + // ----- // CHECK-LABEL: testLstmQuantizedType @@ -692,7 +714,7 @@ func @testSelectMultiDim(%cond : tensor, %arg0 : tensor, %arg1 : // ----- func @testSelectWithUnsupportedType(%cond : tensor, %arg0 : tensor, %arg1 : tensor) -> tensor { - // expected-error @+1 {{op operand #0 must be tensor of 1-bit integer values}} + // expected-error @+1 {{op operand #0 must be tensor of 
1-bit signless integer values}} %0 = "tfl.select"(%cond, %arg0, %arg1): (tensor,tensor,tensor) -> tensor return %0 : tensor } @@ -1141,8 +1163,8 @@ func @testStridedSliceWithQUI8(%arg0: tensor<12x2x2x5x!quant.uniform, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!tf.quint8> { - %0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!tf.uint8>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.quint8> +func @testStridedSliceTFType(%arg0: tensor<12x2x2x5xui8>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!tf.quint8> { + %0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xui8>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.quint8> return %0 : tensor<1x2x2x5x!tf.quint8> } @@ -1166,7 +1188,7 @@ func @testOneHot(%arg0: tensor<3xi32>, %arg1: tensor, %arg2: tensor, % // ----- func @testOneHotWithInvalidOutputType(%arg0: tensor<3xi32>, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> tensor<*xi8> { - // expected-error @+1 {{'tfl.one_hot' op result #0 must be tensor of 32-bit float or 32-bit integer or 64-bit integer or 1-bit integer values}} + // expected-error @+1 {{'tfl.one_hot' op result #0 must be tensor of 32-bit float or 32-bit signless integer or 64-bit signless integer or 1-bit signless integer values}} %0 = "tfl.one_hot"(%arg0, %arg1, %arg2, %arg3) {axis = -1 : i32} : (tensor<3xi32>, tensor, tensor, tensor) -> tensor<*xi8> return %0 : tensor<*xi8> } @@ -1239,7 +1261,7 @@ func @transpose(%arg0 : tensor<2x2xi32>, %arg1 : tensor<2xi32>) -> tensor<2x2xi3 // ----- func @transpose_perm_not_i32(%arg0 : tensor<2x2xi32>, %arg1 : tensor<2xf32>) -> tensor<2x2xi32> { - // expected-error @+1 {{op operand #1 must be tensor of 32-bit integer values}} + // expected-error @+1 {{op operand #1 must be tensor of 32-bit signless integer values}} %0 = "tfl.transpose"(%arg0, %arg1) : (tensor<2x2xi32>, tensor<2xf32>) -> tensor<2x2xi32> return %0 : tensor<2x2xi32> } @@ -1322,7 +1344,7 @@ func @transpose_1d_perm(%arg0 : tensor<2x2xi32>, %arg1 : tensor<2x2xi32>) -> ten // ----- func @anyWithI64Axis(%arg0: tensor<2x2xi1>, %arg1: tensor) -> tensor { - // expected-error @+1 {{tfl.reduce_any' op operand #1 must be tensor of 32-bit integer values}} + // expected-error @+1 {{tfl.reduce_any' op operand #1 must be tensor of 32-bit signless integer values}} %0 = "tfl.reduce_any"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor) -> tensor return %0 : tensor } @@ -1352,7 +1374,7 @@ func @testSplitVWithQuantizedTypes(%arg0 : tensor<10x!quant.uniform // ----- func @whereWithI32Input(%arg0: tensor<3x5xi32>) -> tensor { - // expected-error @+1 {{'tfl.where' op operand #0 must be tensor of 1-bit integer values}} + // expected-error @+1 {{'tfl.where' op operand #0 must be tensor of 1-bit signless integer values}} %0 = "tfl.where"(%arg0) : (tensor<3x5xi32>) -> tensor return %0 : tensor } @@ -1559,7 +1581,7 @@ func @testSliceBeginOutOfRange(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>) - func @testSplitOpWithBadNumSplits(%arg0 : tensor<16xf32>) -> () { %split_dim = constant dense<0> : tensor - // expected-error @+1 {{'tfl.split' op attribute 'num_splits' failed to satisfy constraint: positive 32-bit integer attribute}} + // 
expected-error @+1 {{'tfl.split' op attribute 'num_splits' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}} "tfl.split"(%split_dim, %arg0) {num_splits = 0 : i32} : (tensor, tensor<16xf32>) -> () return } @@ -1682,7 +1704,7 @@ func @testSplitOpWithValidTensorTypeDynamic(%arg0 : tensor<16x?xf32>) -> (tensor func @testSplitVOpWithBadNumSplits(%arg0 : tensor<16xf32>) -> () { %size_splits = constant dense<[]> : tensor<0xi32> %split_dim = constant dense<0> : tensor - // expected-error @+1 {{'tfl.split_v' op attribute 'num_splits' failed to satisfy constraint: positive 32-bit integer attribute}} + // expected-error @+1 {{'tfl.split_v' op attribute 'num_splits' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}} "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 0 : i32} : (tensor<16xf32>, tensor<0xi32>, tensor) -> () return } @@ -1702,7 +1724,7 @@ func @testSplitVOpWithMismatchedNumResults(%arg0 : tensor<16xf32>) -> (tensor<8x func @testSplitVOpWithBadSizeSplitsTensorType(%arg0: tensor<16x4x4xf32>) -> tensor<16x4x4xf32> { %size_splits = constant dense<[[8, 8], [2, 2]]> : tensor<2x2xi32> %split_dim = constant dense<0> : tensor - // expected-error @+1 {{'tfl.split_v' op operand #1 must be 1D tensor of 32-bit integer values}} + // expected-error @+1 {{'tfl.split_v' op operand #1 must be 1D tensor of 32-bit signless integer values}} %0 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 1 : i32} : (tensor<16x4x4xf32>, tensor<2x2xi32>, tensor) -> tensor<16x4x4xf32> return %0 : tensor<16x4x4xf32> } @@ -1711,7 +1733,7 @@ func @testSplitVOpWithBadSizeSplitsTensorType(%arg0: tensor<16x4x4xf32>) -> tens func @testSplitVOpWithBadSizeSplitsUnrankedTensorType(%arg0: tensor<16x4x4xf32>, %size_splits: tensor<*xi32>) -> tensor<16x4x4xf32> { %split_dim = constant dense<0> : tensor - // expected-error @+1 {{'tfl.split_v' op operand #1 must be 1D tensor of 32-bit integer values}} + // expected-error @+1 {{'tfl.split_v' op operand #1 must be 1D tensor of 32-bit signless integer values}} %0 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 1 : i32} : (tensor<16x4x4xf32>, tensor<*xi32>, tensor) -> tensor<16x4x4xf32> return %0 : tensor<16x4x4xf32> } @@ -1761,7 +1783,7 @@ func @testSplitVOpWithBadSizeSplitsSize(%arg0: tensor<16x4x4xf32>) -> tensor<15x func @testSplitVOpWithBadSplitDimTensorType(%arg0: tensor<16x4x4xf32>) -> tensor<16x4x4xf32> { %size_splits = constant dense<[16]> : tensor<1xi32> %split_dim = constant dense<0> : tensor<2x2xi32> - // expected-error @+1 {{'tfl.split_v' op operand #2 must be 0D tensor of 32-bit integer values}} + // expected-error @+1 {{'tfl.split_v' op operand #2 must be 0D tensor of 32-bit signless integer values}} %0 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 1 : i32} : (tensor<16x4x4xf32>, tensor<1xi32>, tensor<2x2xi32>) -> tensor<16x4x4xf32> return %0 : tensor<16x4x4xf32> } @@ -1770,7 +1792,7 @@ func @testSplitVOpWithBadSplitDimTensorType(%arg0: tensor<16x4x4xf32>) -> tensor func @testSplitVOpWithBadSplitDimUnrankedTensorType(%arg0: tensor<16x4x4xf32>, %split_dim : tensor<*xi32>) -> tensor<16x4x4xf32> { %size_splits = constant dense<[16]> : tensor<1xi32> - // expected-error @+1 {{'tfl.split_v' op operand #2 must be 0D tensor of 32-bit integer values}} + // expected-error @+1 {{'tfl.split_v' op operand #2 must be 0D tensor of 32-bit signless integer values}} %0 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 1 : i32} : (tensor<16x4x4xf32>, 
tensor<1xi32>, tensor<*xi32>) -> tensor<16x4x4xf32> return %0 : tensor<16x4x4xf32> } diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index aaf2664ea3c..ae5bd6ced5e 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -3,6 +3,9 @@ // Run optimize pass and then canonicalize pass, and make sure some folding is applied. // RUN: tf-opt %s -tfl-optimize -canonicalize | FileCheck --check-prefix=FOLD %s +// Run legalize pass and then optimize pass, and make sure some fusing is applied. +// RUN: tf-opt %s -tfl-legalize-tf -tfl-optimize | FileCheck --check-prefix=Fusing --dump-input-on-failure %s + // CHECK-LABEL: fusedConv2dRelu func @fusedConv2dRelu(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>) -> tensor<256x30x30x16xf32> { %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> @@ -680,6 +683,18 @@ func @InvalidFuseTileWithBinaryOp(%arg0: tensor<2x3xf32>) -> tensor<2x6xf32> { // CHECK: %[[TILE:[0-9].*]] = "tfl.tile" } +// CHECK-LABEL: InvalidFuseTileAlreadyBroadcastAlongTileDim +func @InvalidFuseTileAlreadyBroadcastAlongTileDim(%arg0: tensor<1x1x1x1xf32>) -> tensor<1x6x8x1xf32> { + %cst_1 = constant dense<[1, 6, 8, 1]> : tensor<4xi32> + %cst_2 = constant dense<[1, 1, 1, 46]> : tensor<4xi32> + %cst_20 = constant dense<4.600000e+01> : tensor + %0 = "tfl.tile"(%arg0, %cst_1) : (tensor<1x1x1x1xf32>, tensor<4xi32>) -> tensor<1x6x8x1xf32> + %1 = "tfl.mul"(%0, %cst_20) {fused_activation_function = "NONE"} : (tensor<1x6x8x1xf32>, tensor) -> tensor<1x6x8x1xf32> + return %1 : tensor<1x6x8x1xf32> + + // CHECK: %[[TILE:[0-9].*]] = "tfl.tile" +} + // CHECK-LABEL: FuseHardswish func @FuseHardswish(%arg0: tensor<1x112x112x16xf32>) -> tensor<1x56x56x16xf32> { %cst_0 = constant dense<3.0> : tensor @@ -835,3 +850,51 @@ func @NotfuseAddIntoConv2d_MultipleUsers(%arg0: tensor<256x32x32x3xf32>, %arg1: // CHECK: tfl.add // CHECK-NEXT: tfl.add } + +func @FusingaddRelu(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { + %0 = "tf.Add"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + %1 = "tf.Add"(%arg0, %0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + %2 = "tf.Relu"(%1) : (tensor<1xf32>) -> tensor<1xf32> + %3 = "tf.Relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> + %4 = "tf.Add"(%3, %2) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + %5 = "tf.Relu6"(%4) : (tensor<1xf32>) -> tensor<1xf32> + %6 = "tfl.add"(%5, %3) {fused_activation_function = "NONE"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + %7 = "tf.Relu6"(%6) : (tensor<1xf32>) -> tensor<1xf32> + return %7: tensor<1xf32> + +// Fusing-LABEL: FusingaddRelu +// Fusing: %[[add:[0-9].*]] = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xf32> +// Fusing: %[[add1:[0-9].*]] = tfl.add %arg0, %[[add]] {fused_activation_function = "RELU"} : tensor<1xf32> +// Fusing: %[[relu:[0-9].*]] = "tfl.relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> +// Fusing: %[[add2:[0-9].*]] = tfl.add %[[relu]], %[[add1]] {fused_activation_function = "RELU6"} : tensor<1xf32> +// Fusing: %[[add3:[0-9].*]] = tfl.add %[[add2]], %[[relu]] {fused_activation_function = "RELU6"} : tensor<1xf32> +// Fusing: return +} + +func @FusingbiasAdd(%arg0: 
tensor<1x10x10x32xf32>, %arg1: tensor<32xf32>) -> tensor<1x10x10x32xf32> { + %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x10x10x32xf32>, tensor<32xf32>) -> tensor<1x10x10x32xf32> + %1 = "tf.BiasAdd"(%0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x10x10x32xf32>, tensor<32xf32>) -> tensor<1x10x10x32xf32> + %2 = "tf.Relu6"(%1) : (tensor<1x10x10x32xf32>) -> tensor<1x10x10x32xf32> + return %2 : tensor<1x10x10x32xf32> + +// Fusing-LABEL: FusingbiasAdd +// Fusing: %[[add:[0-9].*]] = "tfl.add"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<1x10x10x32xf32>, tensor<32xf32>) -> tensor<1x10x10x32xf32> +// Fusing: %[[add1:[0-9].*]] = "tfl.add"(%[[add]], %arg1) {fused_activation_function = "RELU6"} : (tensor<1x10x10x32xf32>, tensor<32xf32>) -> tensor<1x10x10x32xf32> +} + +func @FusingdivRelu(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { + %0 = "tf.Div"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + %1 = "tf.Div"(%arg0, %0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + %2 = "tf.Relu"(%1) : (tensor<1xf32>) -> tensor<1xf32> + %3 = "tf.Relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> + %4 = "tf.Div"(%3, %2) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + %5 = "tf.Relu6"(%4) : (tensor<1xf32>) -> tensor<1xf32> + return %5: tensor<1xf32> + +// Fusing-LABEL: FusingdivRelu +// Fusing: %[[div:[0-9].*]] = tfl.div %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xf32> +// Fusing: %[[div1:[0-9].*]] = tfl.div %arg0, %[[div]] {fused_activation_function = "RELU"} : tensor<1xf32> +// Fusing: %[[relu:[0-9].*]] = "tfl.relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> +// Fusing: %[[div2:[0-9].*]] = tfl.div %[[relu]], %[[div1]] {fused_activation_function = "RELU6"} : tensor<1xf32> +// Fusing: return +} diff --git a/tensorflow/compiler/mlir/lite/tests/optimize_functional_ops.mlir b/tensorflow/compiler/mlir/lite/tests/optimize_functional_ops.mlir index 846f0126f21..dfd0c870a22 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize_functional_ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize_functional_ops.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -tfl-optimize-functional-ops -split-input-file | FileCheck %s +// RUN: tf-opt %s -tfl-optimize-functional-ops -split-input-file | FileCheck %s --dump-input-on-failure // CHECK-LABEL: main func @main(%arg0: tensor, %arg1: tensor) -> (tensor) { @@ -131,5 +131,80 @@ func @_functionalize_if_then_branch_00(%arg0: tensor<*xi1>, %arg1: tensor<*xf32> // CHECK: func @main // CHECK-NOT: tf.If // CHECK: return -// CHECK-NOT: func else_branch -// CHECK-NOT: func then_branch +// CHECK-NOT: func @_functionalize_if_else_branch_00 +// CHECK-NOT: func @_functionalize_if_then_branch_00 + +// ----- + +// Verify unused if with function with side-effects is not removed. 
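// A minimal sketch of the property the next test exercises (illustrative only, not part of the patch; @stateful_then and @stateful_else are made-up names):
// an unused "tf.If" marked is_stateless = false may still observe or mutate state through its branch functions, so -tfl-optimize-functional-ops has to leave
// both the op and its branch functions in place.
//
//   %unused = "tf.If"(%cond, %cond, %x, %y)
//       {Tcond = i1, else_branch = @stateful_else, then_branch = @stateful_then, is_stateless = false}
//     : (tensor<i1>, tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<i1>
//   // Kept: the op does not claim to be stateless, and a branch body may
//   // contain side-effecting ops (see the unknown op in the then-branch below).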
+ +func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32> + attributes {tf.entry_function = {inputs = "input", outputs = "Conv2D"}} { + %cst = constant dense<[0, 1, 2, 3]> : tensor<4xi32> + %cst_0 = constant dense<1.000000e+00> : tensor + %cst_1 = constant dense<0.000000e+00> : tensor<8xf32> + %cst_2 = constant dense<0.000000e+00> : tensor<8x3x3x3xf32> + %0 = "tfl.sub"(%arg0, %cst_0) {fused_activation_function = "NONE"} : (tensor<3x15x14x3xf32>, tensor) -> tensor<3x15x14x3xf32> + %1 = "tfl.greater_equal"(%arg0, %0) : (tensor<3x15x14x3xf32>, tensor<3x15x14x3xf32>) -> tensor<3x15x14x3xi1> + %2 = "tf.All"(%1, %cst) {Tidx = i32, device = "/device:CPU:0", keep_dims = false} : (tensor<3x15x14x3xi1>, tensor<4xi32>) -> tensor + %3 = "tf.If"(%2, %2, %arg0, %0) {Tcond = i1, + else_branch = @_functionalize_if_else_branch_01, is_stateless = false, + then_branch = @_functionalize_if_then_branch_01} : + (tensor, tensor, tensor<3x15x14x3xf32>, tensor<3x15x14x3xf32>) -> tensor + %4 = "tfl.conv_2d"(%arg0, %cst_2, %cst_1) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<3x15x14x3xf32>, tensor<8x3x3x3xf32>, tensor<8xf32>) -> tensor<3x15x14x8xf32> + return %4 : tensor<3x15x14x8xf32> +} + +func @_functionalize_if_else_branch_01(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor { + %cst = constant dense : tensor + return %cst : tensor +} + +func @_functionalize_if_then_branch_01(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor { + %0 = "my_unknown_op.blah"() : () -> tensor + return %0 : tensor +} + +// CHECK: func @main +// CHECK: tf.If +// CHECK: return +// CHECK: func @_functionalize_if_else_branch_01 +// CHECK: func @_functionalize_if_then_branch_01 + +// ----- + +// Verify unused if with function with side-effects is removed if op says +// stateless. 
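// Sketch of the stateless counterpart checked next (illustrative only, not part of the patch; @pure_then and @pure_else are made-up names): when the op
// carries is_stateless = true, the pass may treat the branches as free of observable side effects, so an unused "tf.If" can be erased together with its
// now-unreferenced branch functions, even though the then-branch body contains an unknown op.
//
//   %unused = "tf.If"(%cond, %cond, %x, %y)
//       {Tcond = i1, else_branch = @pure_else, then_branch = @pure_then, is_stateless = true}
//     : (tensor<i1>, tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<i1>
//   // Removable: the result has no uses and the op admits to touching no state.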
+ +func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32> + attributes {tf.entry_function = {inputs = "input", outputs = "Conv2D"}} { + %cst = constant dense<[0, 1, 2, 3]> : tensor<4xi32> + %cst_0 = constant dense<1.000000e+00> : tensor + %cst_1 = constant dense<0.000000e+00> : tensor<8xf32> + %cst_2 = constant dense<0.000000e+00> : tensor<8x3x3x3xf32> + %0 = "tfl.sub"(%arg0, %cst_0) {fused_activation_function = "NONE"} : (tensor<3x15x14x3xf32>, tensor) -> tensor<3x15x14x3xf32> + %1 = "tfl.greater_equal"(%arg0, %0) : (tensor<3x15x14x3xf32>, tensor<3x15x14x3xf32>) -> tensor<3x15x14x3xi1> + %2 = "tf.All"(%1, %cst) {Tidx = i32, device = "/device:CPU:0", keep_dims = false} : (tensor<3x15x14x3xi1>, tensor<4xi32>) -> tensor + %3 = "tf.If"(%2, %2, %arg0, %0) {Tcond = i1, + else_branch = @_functionalize_if_else_branch_02, is_stateless = true, + then_branch = @_functionalize_if_then_branch_02} : + (tensor, tensor, tensor<3x15x14x3xf32>, tensor<3x15x14x3xf32>) -> tensor + %4 = "tfl.conv_2d"(%arg0, %cst_2, %cst_1) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<3x15x14x3xf32>, tensor<8x3x3x3xf32>, tensor<8xf32>) -> tensor<3x15x14x8xf32> + return %4 : tensor<3x15x14x8xf32> +} + +func @_functionalize_if_else_branch_02(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor { + %cst = constant dense : tensor + return %cst : tensor +} + +func @_functionalize_if_then_branch_02(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor { + %0 = "my_unknown_op.blah"() : () -> tensor + return %0 : tensor +} + +// CHECK: func @main +// CHECK-NOT: tf.If +// CHECK: return +// CHECK-NOT: func @_functionalize_if_else_branch_02 +// CHECK-NOT: func @_functionalize_if_then_branch_02 diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir index c34cfdf441c..83ab0f9cd0e 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir @@ -42,10 +42,10 @@ func @layernormalizedlstmcellsimple(%arg0: tensor<1x?xf32>, %arg1: tensor<3x4xf3 // CHECK-SAME: [[VAL_0]]: tensor<1x?xf32>, [[VAL_1]]: tensor<3x4xf32>, [[VAL_3:%.*]]: tensor<2xf32>, [[VAL_4:%.*]]: tensor<1x3xf32>, [[VAL_5:%.*]]: tensor) -> tensor<1x?xf32> // CHECK-LABEL: attributes {tf._implements = "LSTMCellSimple", tf._reference = "mlir"} { -// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0]> : tensor<2xi64> -// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_1]], [[VAL_6]]) : (tensor<3x4xf32>, tensor<2xi64>) -> tensor<4x3xf32> -// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64> -// CHECK: [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<1x3xf32>, tensor<2xi64>) -> tensor<3x1xf32> +// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0]> : tensor<2xi32> +// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_1]], [[VAL_6]]) : (tensor<3x4xf32>, tensor<2xi32>) -> tensor<4x3xf32> +// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi32> +// CHECK: [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<1x3xf32>, tensor<2xi32>) -> tensor<3x1xf32> // CHECK: [[VAL_10:%.*]] = constant unit // CHECK: [[VAL_11:%.*]] = constant dense<0> : tensor<2xi64> // CHECK: [[VAL_12:%.*]] = constant dense<[1, 0]> : tensor<2xi64> @@ -94,10 +94,10 @@ func @layernormalizedlstmcellsimple(%arg0: tensor<1x?xf32>, %arg1: 
tensor<3x4xf3 // CHECK-SAME: [[VAL_0]]: tensor<1x?xf32>, [[VAL_1]]: tensor<3x4xf32>, [[VAL_3]]: tensor<2xf32>, [[VAL_4]]: tensor<1x3xf32>, [[VAL_5]]: tensor<2xf32>) -> tensor<1x?xf32> // CHECK-LABEL: attributes {tf._implements = "LayerNormalizedLstmCellSimple", tf._reference = "mlir"} { -// CHECK: [[VAL_52:%.*]] = constant dense<[1, 0]> : tensor<2xi64> -// CHECK: [[VAL_53:%.*]] = "tf.Transpose"([[VAL_1]], [[VAL_52]]) : (tensor<3x4xf32>, tensor<2xi64>) -> tensor<4x3xf32> -// CHECK: [[VAL_54:%.*]] = constant dense<[1, 0]> : tensor<2xi64> -// CHECK: [[VAL_55:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_54]]) : (tensor<1x3xf32>, tensor<2xi64>) -> tensor<3x1xf32> +// CHECK: [[VAL_52:%.*]] = constant dense<[1, 0]> : tensor<2xi32> +// CHECK: [[VAL_53:%.*]] = "tf.Transpose"([[VAL_1]], [[VAL_52]]) : (tensor<3x4xf32>, tensor<2xi32>) -> tensor<4x3xf32> +// CHECK: [[VAL_54:%.*]] = constant dense<[1, 0]> : tensor<2xi32> +// CHECK: [[VAL_55:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_54]]) : (tensor<1x3xf32>, tensor<2xi32>) -> tensor<3x1xf32> // CHECK: [[VAL_56:%.*]] = constant unit // CHECK: [[VAL_57:%.*]] = constant dense<0> : tensor<2xi64> // CHECK: [[VAL_58:%.*]] = constant dense<[1, 0]> : tensor<2xi64> @@ -165,11 +165,11 @@ func @inference_standard_lstm_time_major(%arg0: tensor, %arg1: tensor return %5, %4, %5, %5, %6 : tensor, tensor, tensor, tensor, tensor } -// CHECK: func @inference_standard_lstm_time_major([[VAL_0:%.*]]: tensor, [[VAL_1:%.*]]: tensor, [[VAL_2:%.*]]: tensor, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor, tensor, tensor, tensor, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} { -// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0]> : tensor<2xi64> -// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32> -// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64> -// CHECK: [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<10x40xf32>, tensor<2xi64>) -> tensor<40x10xf32> +// CHECK: func @inference_standard_lstm_time_major([[VAL_0:%.*]]: tensor, [[VAL_1:%.*]]: tensor, [[VAL_2:%.*]]: tensor, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<8x10xf32>, tensor, tensor, tensor, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} { +// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0]> : tensor<2xi32> +// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi32>) -> tensor<40x8xf32> +// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi32> +// CHECK: [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<10x40xf32>, tensor<2xi32>) -> tensor<40x10xf32> // CHECK: [[VAL_10:%.*]] = "tf.Const"() {value = dense<10> : 
tensor<4xi32>} : () -> tensor<4xi32> // CHECK: [[VAL_11:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor // CHECK: [[VAL_12:%.*]]:4 = "tf.SplitV"([[VAL_7]], [[VAL_10]], [[VAL_11]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>) @@ -181,11 +181,14 @@ func @inference_standard_lstm_time_major(%arg0: tensor, %arg1: tensor // CHECK: [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) // CHECK: [[VAL_19:%.*]] = constant unit // CHECK: [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor, tensor, none, none, none, none) -> tensor -// CHECK: [[VAL_21:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor -// CHECK: [[VAL_22:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor -// CHECK: [[VAL_23:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor -// CHECK: [[VAL_24:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor -// CHECK: return [[VAL_21]], [[VAL_20]], [[VAL_22]], [[VAL_23]], [[VAL_24]] : tensor, tensor, tensor, tensor, tensor +// CHECK: [[VAL_21:%.*]] = constant dense<[-1, 0, 0]> : tensor<3xi32> +// CHECK: [[VAL_22:%.*]] = constant dense<0> : tensor<3xi32> +// CHECK: [[VAL_23:%.*]] = constant dense<1> : tensor<3xi32> +// CHECK: [[VAL_24:%.*]] = "tf.StridedSlice"([[VAL_20]], [[VAL_21]], [[VAL_22]], [[VAL_23]]) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32> +// CHECK: [[VAL_25:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor +// CHECK: [[VAL_26:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor +// CHECK: [[VAL_27:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor +// CHECK: return [[VAL_24]], [[VAL_20]], [[VAL_25]], [[VAL_26]], [[VAL_27]] : tensor<8x10xf32>, tensor, tensor, tensor, tensor // CHECK: } } @@ -203,32 +206,32 @@ func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg1: te return %5, %4, %5, %5, %6 : tensor, tensor<8x8x10xf32>, tensor, tensor, tensor } -// CHECK: func @inference_standard_lstm_non_time_major([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor, [[VAL_2:%.*]]: tensor, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor, tensor<8x8x10xf32>, tensor, tensor, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", 
"tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = false} { -// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64> -// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_0]], [[VAL_6]]) : (tensor<8x8x8xf32>, tensor<3xi64>) -> tensor<8x8x8xf32> -// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64> -// CHECK: [[VAL_9:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_8]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32> -// CHECK: [[VAL_10:%.*]] = constant dense<[1, 0]> : tensor<2xi64> -// CHECK: [[VAL_11:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_10]]) : (tensor<10x40xf32>, tensor<2xi64>) -> tensor<40x10xf32> -// CHECK: [[VAL_12:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32> -// CHECK: [[VAL_13:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor -// CHECK: [[VAL_14:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_12]], [[VAL_13]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>) -// CHECK: [[VAL_15:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32> -// CHECK: [[VAL_16:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor -// CHECK: [[VAL_17:%.*]]:4 = "tf.SplitV"([[VAL_11]], [[VAL_15]], [[VAL_16]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -// CHECK: [[VAL_18:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32> -// CHECK: [[VAL_19:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor -// CHECK: [[VAL_20:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_18]], [[VAL_19]]) : (tensor<40xf32>, tensor<4xi32>, tensor) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) -// CHECK: [[VAL_21:%.*]] = constant unit -// CHECK: [[VAL_22:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor, tensor, none, none, none, none) -> tensor<8x8x10xf32> -// CHECK: [[VAL_23:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64> -// CHECK: [[VAL_24:%.*]] = "tf.Transpose"([[VAL_22]], [[VAL_23]]) : (tensor<8x8x10xf32>, tensor<3xi64>) -> tensor<8x8x10xf32> +// CHECK: func @inference_standard_lstm_non_time_major([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor, [[VAL_2:%.*]]: tensor, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<8x8x10xf32>, tensor, tensor, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", 
"tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = false} { +// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0]> : tensor<2xi32> +// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi32>) -> tensor<40x8xf32> +// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi32> +// CHECK: [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<10x40xf32>, tensor<2xi32>) -> tensor<40x10xf32> +// CHECK: [[VAL_10:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32> +// CHECK: [[VAL_11:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: [[VAL_12:%.*]]:4 = "tf.SplitV"([[VAL_7]], [[VAL_10]], [[VAL_11]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>) +// CHECK: [[VAL_13:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32> +// CHECK: [[VAL_14:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: [[VAL_15:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_13]], [[VAL_14]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) +// CHECK: [[VAL_16:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32> +// CHECK: [[VAL_17:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) +// CHECK: [[VAL_19:%.*]] = constant unit +// CHECK: [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor, tensor, none, none, none, none) -> tensor<8x8x10xf32> +// CHECK: [[VAL_21:%.*]] = constant dense<[0, -1, 0]> : tensor<3xi32> +// CHECK: [[VAL_22:%.*]] = constant dense<0> : tensor<3xi32> +// CHECK: [[VAL_23:%.*]] = constant dense<1> : tensor<3xi32> +// CHECK: [[VAL_24:%.*]] = "tf.StridedSlice"([[VAL_20]], [[VAL_21]], [[VAL_22]], [[VAL_23]]) {begin_mask = 5 : i64, ellipsis_mask = 0 : i64, end_mask = 5 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 2 : i64} : (tensor<8x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32> // CHECK: [[VAL_25:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor // CHECK: [[VAL_26:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor -// CHECK: [[VAL_27:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor -// CHECK: [[VAL_28:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor -// CHECK: return [[VAL_25]], [[VAL_24]], [[VAL_26]], [[VAL_27]], [[VAL_28]] : tensor, 
tensor<8x8x10xf32>, tensor, tensor, tensor +// CHECK: [[VAL_27:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor +// CHECK: return [[VAL_24]], [[VAL_20]], [[VAL_25]], [[VAL_26]], [[VAL_27]] : tensor<8x10xf32>, tensor<8x8x10xf32>, tensor, tensor, tensor // CHECK: } + } // ----- @@ -245,13 +248,13 @@ func @inference_standard_lstm_time_major_go_backwards(%arg0: tensor, return %5, %4, %5, %5, %6 : tensor, tensor, tensor, tensor, tensor } -// CHECK: func @inference_standard_lstm_time_major_go_backwards([[VAL_0:%.*]]: tensor, [[VAL_1:%.*]]: tensor, [[VAL_2:%.*]]: tensor, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor, tensor, tensor, tensor, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} { +// CHECK: func @inference_standard_lstm_time_major_go_backwards([[VAL_0:%.*]]: tensor, [[VAL_1:%.*]]: tensor, [[VAL_2:%.*]]: tensor, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<8x10xf32>, tensor, tensor, tensor, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} { // CHECK: [[VAL_6:%.*]] = constant dense<0> : tensor<1xi32> // CHECK: [[VAL_7:%.*]] = "tf.ReverseV2"([[VAL_0]], [[VAL_6]]) : (tensor, tensor<1xi32>) -> tensor -// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64> -// CHECK: [[VAL_9:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_8]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32> -// CHECK: [[VAL_10:%.*]] = constant dense<[1, 0]> : tensor<2xi64> -// CHECK: [[VAL_11:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_10]]) : (tensor<10x40xf32>, tensor<2xi64>) -> tensor<40x10xf32> +// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi32> +// CHECK: [[VAL_9:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_8]]) : (tensor<8x40xf32>, tensor<2xi32>) -> tensor<40x8xf32> +// CHECK: [[VAL_10:%.*]] = constant dense<[1, 0]> : tensor<2xi32> +// CHECK: [[VAL_11:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_10]]) : (tensor<10x40xf32>, tensor<2xi32>) -> tensor<40x10xf32> // CHECK: [[VAL_12:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32> // CHECK: [[VAL_13:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor // CHECK: [[VAL_14:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_12]], [[VAL_13]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>) @@ -262,12 +265,15 @@ func @inference_standard_lstm_time_major_go_backwards(%arg0: tensor, // CHECK: [[VAL_19:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor // CHECK: [[VAL_20:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_18]], [[VAL_19]]) : (tensor<40xf32>, tensor<4xi32>, tensor) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) // CHECK: [[VAL_21:%.*]] = 
constant unit -// CHECK: [[VAL_22:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor, tensor, none, none, none, none) -> tensor -// CHECK: [[VAL_23:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor -// CHECK: [[VAL_24:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor -// CHECK: [[VAL_25:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor -// CHECK: [[VAL_26:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor -// CHECK: return [[VAL_23]], [[VAL_22]], [[VAL_24]], [[VAL_25]], [[VAL_26]] : tensor, tensor, tensor, tensor, tensor +// CHECK: [[VAL_22:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_7]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor, tensor, none, none, none, none) -> tensor +// CHECK: [[VAL_23:%.*]] = constant dense<[-1, 0, 0]> : tensor<3xi32> +// CHECK: [[VAL_24:%.*]] = constant dense<0> : tensor<3xi32> +// CHECK: [[VAL_25:%.*]] = constant dense<1> : tensor<3xi32> +// CHECK: [[VAL_26:%.*]] = "tf.StridedSlice"([[VAL_22]], [[VAL_23]], [[VAL_24]], [[VAL_25]]) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32> +// CHECK: [[VAL_27:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor +// CHECK: [[VAL_28:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor +// CHECK: [[VAL_29:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor +// CHECK: return [[VAL_26]], [[VAL_22]], [[VAL_27]], [[VAL_28]], [[VAL_29]] : tensor<8x10xf32>, tensor, tensor, tensor, tensor // CHECK: } } @@ -286,33 +292,32 @@ func @inference_standard_lstm_non_time_major_go_backwards(%arg0: tensor<8x8x8xf3 return %5, %4, %5, %5, %6 : tensor, tensor<8x8x10xf32>, tensor, tensor, tensor } -// CHECK: func @inference_standard_lstm_non_time_major_go_backwards([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor, [[VAL_2:%.*]]: tensor, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor, tensor<8x8x10xf32>, tensor, tensor, tensor) 
attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} { -// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64> -// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_0]], [[VAL_6]]) : (tensor<8x8x8xf32>, tensor<3xi64>) -> tensor<8x8x8xf32> -// CHECK: [[VAL_8:%.*]] = constant dense<0> : tensor<1xi32> -// CHECK: [[VAL_9:%.*]] = "tf.ReverseV2"([[VAL_7]], [[VAL_8]]) : (tensor<8x8x8xf32>, tensor<1xi32>) -> tensor<8x8x8xf32> -// CHECK: [[VAL_10:%.*]] = constant dense<[1, 0]> : tensor<2xi64> -// CHECK: [[VAL_11:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_10]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32> -// CHECK: [[VAL_12:%.*]] = constant dense<[1, 0]> : tensor<2xi64> -// CHECK: [[VAL_13:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_12]]) : (tensor<10x40xf32>, tensor<2xi64>) -> tensor<40x10xf32> -// CHECK: [[VAL_14:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32> -// CHECK: [[VAL_15:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor -// CHECK: [[VAL_16:%.*]]:4 = "tf.SplitV"([[VAL_11]], [[VAL_14]], [[VAL_15]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>) -// CHECK: [[VAL_17:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32> -// CHECK: [[VAL_18:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor -// CHECK: [[VAL_19:%.*]]:4 = "tf.SplitV"([[VAL_13]], [[VAL_17]], [[VAL_18]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -// CHECK: [[VAL_20:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32> -// CHECK: [[VAL_21:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor -// CHECK: [[VAL_22:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_20]], [[VAL_21]]) : (tensor<40xf32>, tensor<4xi32>, tensor) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) -// CHECK: [[VAL_23:%.*]] = constant unit -// CHECK: [[VAL_24:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_16]]#0, [[VAL_16]]#1, [[VAL_16]]#2, [[VAL_16]]#3, [[VAL_19]]#0, [[VAL_19]]#1, [[VAL_19]]#2, [[VAL_19]]#3, [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_22]]#0, [[VAL_22]]#1, [[VAL_22]]#2, [[VAL_22]]#3, [[VAL_23]], [[VAL_23]], [[VAL_1]], [[VAL_2]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor, tensor, none, none, none, none) -> tensor<8x8x10xf32> -// CHECK: [[VAL_25:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64> -// CHECK: [[VAL_26:%.*]] = "tf.Transpose"([[VAL_24]], [[VAL_25]]) : (tensor<8x8x10xf32>, tensor<3xi64>) -> tensor<8x8x10xf32> +// CHECK: func @inference_standard_lstm_non_time_major_go_backwards([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor, [[VAL_2:%.*]]: tensor, [[VAL_3:%.*]]: tensor<8x40xf32>, 
[[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<8x8x10xf32>, tensor, tensor, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} { +// CHECK: [[VAL_6:%.*]] = constant dense<1> : tensor<1xi32> +// CHECK: [[VAL_7:%.*]] = "tf.ReverseV2"([[VAL_0]], [[VAL_6]]) : (tensor<8x8x8xf32>, tensor<1xi32>) -> tensor<8x8x8xf32> +// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi32> +// CHECK: [[VAL_9:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_8]]) : (tensor<8x40xf32>, tensor<2xi32>) -> tensor<40x8xf32> +// CHECK: [[VAL_10:%.*]] = constant dense<[1, 0]> : tensor<2xi32> +// CHECK: [[VAL_11:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_10]]) : (tensor<10x40xf32>, tensor<2xi32>) -> tensor<40x10xf32> +// CHECK: [[VAL_12:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32> +// CHECK: [[VAL_13:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: [[VAL_14:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_12]], [[VAL_13]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>) +// CHECK: [[VAL_15:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32> +// CHECK: [[VAL_16:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: [[VAL_17:%.*]]:4 = "tf.SplitV"([[VAL_11]], [[VAL_15]], [[VAL_16]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) +// CHECK: [[VAL_18:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32> +// CHECK: [[VAL_19:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: [[VAL_20:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_18]], [[VAL_19]]) : (tensor<40xf32>, tensor<4xi32>, tensor) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) +// CHECK: [[VAL_21:%.*]] = constant unit +// CHECK: [[VAL_22:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_7]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor, tensor, none, none, none, none) -> tensor<8x8x10xf32> +// CHECK: [[VAL_23:%.*]] = constant dense<[0, -1, 0]> : tensor<3xi32> +// CHECK: [[VAL_24:%.*]] = constant dense<0> : tensor<3xi32> +// CHECK: [[VAL_25:%.*]] = constant dense<1> : tensor<3xi32> +// CHECK: [[VAL_26:%.*]] = "tf.StridedSlice"([[VAL_22]], [[VAL_23]], [[VAL_24]], [[VAL_25]]) {begin_mask = 5 : i64, ellipsis_mask = 0 : i64, end_mask = 5 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 2 : i64} : (tensor<8x8x10xf32>, 
tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32> // CHECK: [[VAL_27:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor // CHECK: [[VAL_28:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor -// CHECK: [[VAL_29:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor -// CHECK: [[VAL_30:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor -// CHECK: return [[VAL_27]], [[VAL_26]], [[VAL_28]], [[VAL_29]], [[VAL_30]] : tensor, tensor<8x8x10xf32>, tensor, tensor, tensor +// CHECK: [[VAL_29:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor +// CHECK: return [[VAL_26]], [[VAL_22]], [[VAL_27]], [[VAL_28]], [[VAL_29]] : tensor<8x10xf32>, tensor<8x8x10xf32>, tensor, tensor, tensor // CHECK: } } @@ -338,11 +343,11 @@ func @inference_standard_lstm_time_major_can_fuse(%arg0: tensor, %arg return %5, %4, %5, %5, %6 : tensor, tensor, tensor, tensor, tensor } -// CHECK: func @inference_standard_lstm_time_major_can_fuse([[VAL_0:%.*]]: tensor, [[VAL_1:%.*]]: tensor, [[VAL_2:%.*]]: tensor, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor, tensor, tensor, tensor, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} { -// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0]> : tensor<2xi64> -// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32> -// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64> -// CHECK: [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<10x40xf32>, tensor<2xi64>) -> tensor<40x10xf32> +// CHECK: func @inference_standard_lstm_time_major_can_fuse([[VAL_0:%.*]]: tensor, [[VAL_1:%.*]]: tensor, [[VAL_2:%.*]]: tensor, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<8x10xf32>, tensor, tensor, tensor, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} { +// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0]> : tensor<2xi32> +// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi32>) -> tensor<40x8xf32> +// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi32> +// CHECK: [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<10x40xf32>, tensor<2xi32>) -> tensor<40x10xf32> // CHECK: [[VAL_10:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32> // CHECK: [[VAL_11:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor // CHECK: [[VAL_12:%.*]]:4 = "tf.SplitV"([[VAL_7]], [[VAL_10]], [[VAL_11]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>) @@ -354,11 
+359,65 @@ func @inference_standard_lstm_time_major_can_fuse(%arg0: tensor, %arg // CHECK: [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) // CHECK: [[VAL_19:%.*]] = constant unit // CHECK: [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor, tensor, none, none, none, none) -> tensor -// CHECK: [[VAL_21:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor -// CHECK: [[VAL_22:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor -// CHECK: [[VAL_23:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor -// CHECK: [[VAL_24:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor -// CHECK: return [[VAL_21]], [[VAL_20]], [[VAL_22]], [[VAL_23]], [[VAL_24]] : tensor, tensor, tensor, tensor, tensor +// CHECK: [[VAL_21:%.*]] = constant dense<[-1, 0, 0]> : tensor<3xi32> +// CHECK: [[VAL_22:%.*]] = constant dense<0> : tensor<3xi32> +// CHECK: [[VAL_23:%.*]] = constant dense<1> : tensor<3xi32> +// CHECK: [[VAL_24:%.*]] = "tf.StridedSlice"([[VAL_20]], [[VAL_21]], [[VAL_22]], [[VAL_23]]) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32> +// CHECK: [[VAL_25:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor +// CHECK: [[VAL_26:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor +// CHECK: [[VAL_27:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor +// CHECK: return [[VAL_24]], [[VAL_20]], [[VAL_25]], [[VAL_26]], [[VAL_27]] : tensor<8x10xf32>, tensor, tensor, tensor, tensor +// CHECK: } + +} + +// ----- + +module { +func @inference_can_fuse_last_output(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) { + %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = f32, value = dense<0.000000e+00> : tensor} : () -> tensor + %1:5 = "tf.PartitionedCall"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {Tin = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], _output_shapes = ["tfshape$dim { size: 9 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 9 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$"], _read_only_resource_inputs = [], config = "", config_proto = "\0A\07\0A\03CPU\10\01\0A\07\0A\03GPU\10\002\02J\008\01", device = "", 
executor_type = "", f = @inference_standard_lstm_time_major_can_fuse_last_output} : (tensor, tensor, tensor, tensor<8x40xf32>, tensor<10x40xf32>, tensor<40xf32>) -> (tensor<8x10xf32>, tensor, tensor, tensor, tensor) + %2 = "tf.Add"(%0, %1#0) : (tensor, tensor<8x10xf32>) -> tensor<8x10xf32> + return +} + +func @inference_standard_lstm_time_major_can_fuse_last_output(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<8x10xf32>, tensor, tensor, tensor, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} { + %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor, tensor<8x40xf32>) -> tensor + %1 = "tf.Add"(%0, %arg5) : (tensor, tensor<40xf32>) -> tensor + %2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor, tensor<10x40xf32>) -> tensor + %3 = "tf.Add"(%2, %arg1) : (tensor, tensor) -> tensor + %4 = "tf.Add"(%2, %arg2) : (tensor, tensor) -> tensor + %5 = "tf.Add"(%arg1, %arg2) : (tensor, tensor) -> tensor + %6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor} : () -> tensor + %7 = "tf.Add"(%arg1, %arg2) : (tensor, tensor) -> tensor<8x10xf32> + return %7, %4, %5, %5, %6 : tensor<8x10xf32>, tensor, tensor, tensor, tensor +} + +// CHECK: func @inference_standard_lstm_time_major_can_fuse_last_output([[VAL_0:%.*]]: tensor, [[VAL_1:%.*]]: tensor, [[VAL_2:%.*]]: tensor, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<8x10xf32>, tensor, tensor, tensor, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} { +// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0]> : tensor<2xi32> +// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi32>) -> tensor<40x8xf32> +// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi32> +// CHECK: [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<10x40xf32>, tensor<2xi32>) -> tensor<40x10xf32> +// CHECK: [[VAL_10:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32> +// CHECK: [[VAL_11:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: [[VAL_12:%.*]]:4 = "tf.SplitV"([[VAL_7]], [[VAL_10]], [[VAL_11]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>) +// CHECK: [[VAL_13:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32> +// CHECK: [[VAL_14:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: [[VAL_15:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_13]], [[VAL_14]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, 
tensor<10x10xf32>) +// CHECK: [[VAL_16:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32> +// CHECK: [[VAL_17:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) +// CHECK: [[VAL_19:%.*]] = constant unit +// CHECK: [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor, tensor, none, none, none, none) -> tensor +// CHECK: [[VAL_21:%.*]] = constant dense<[-1, 0, 0]> : tensor<3xi32> +// CHECK: [[VAL_22:%.*]] = constant dense<0> : tensor<3xi32> +// CHECK: [[VAL_23:%.*]] = constant dense<1> : tensor<3xi32> +// CHECK: [[VAL_24:%.*]] = "tf.StridedSlice"([[VAL_20]], [[VAL_21]], [[VAL_22]], [[VAL_23]]) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32> +// CHECK: [[VAL_25:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor +// CHECK: [[VAL_26:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor +// CHECK: [[VAL_27:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor +// CHECK: return [[VAL_24]], [[VAL_20]], [[VAL_25]], [[VAL_26]], [[VAL_27]] : tensor<8x10xf32>, tensor, tensor, tensor, tensor +// CHECK: } } diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir index 1aa1311318a..5e456b1a7e5 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir @@ -377,6 +377,32 @@ func @perChannelFakeQuantWithDepthwiseConv2D(tensor<256x32x32x3xf32>) -> (tensor // CHECK: return %[[CONV]] } +// CHECK-LABEL: perChannelFakeQuantWithDepthwiseConv2DWithReshape +func @perChannelFakeQuantWithDepthwiseConv2DWithReshape(%arg: tensor<1x160x160x48xf32>) -> (tensor<1x160x160x48xf32>) { + %in = constant dense<0.0> : tensor<3x3x48x1xf32> + %min = constant dense<0.0> : tensor<48xf32> + %max = constant dense<255.0> : tensor<48xf32> + %mini = "tf.Identity"(%min) : (tensor<48xf32>) -> tensor<48xf32> + %maxi = "tf.Identity"(%max) : (tensor<48xf32>) -> tensor<48xf32> + %s1 = constant dense<[3, 3, 48]> : tensor<3xi32> + %s2 = constant dense<[3, 3, 48, 1]> : tensor<4xi32> + %r1 = "tf.Reshape"(%in, %s1) {T = f32, Tshape = i32, device = ""} : (tensor<3x3x48x1xf32>, tensor<3xi32>) -> tensor<3x3x48xf32> + %fq = "tf.FakeQuantWithMinMaxVarsPerChannel"(%r1, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x48xf32>, tensor<48xf32>, tensor<48xf32>) -> tensor<3x3x48xf32> + %r2 = "tf.Reshape"(%fq, %s2) {T = f32, Tshape = i32, device = ""} : (tensor<3x3x48xf32>, tensor<4xi32>) 
-> tensor<3x3x48x1xf32> + %rst = "tf.DepthwiseConv2dNative"(%arg, %r2) {T = f32, data_format = "NHWC", dilations = [1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x160x160x48xf32>, tensor<3x3x48x1xf32>) -> tensor<1x160x160x48xf32> + return %rst : tensor<1x160x160x48xf32> + +// CHECK: %[[CONSTANT:.*]] = constant dense<0.000000e+00> : tensor<48xf32> +// CHECK: %[[CONSTANT0:.*]] = constant dense<0.000000e+00> : tensor<1x3x3x48xf32> +// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT0]]) {qtype = tensor<1x3x3x48x!quant.uniform>} +// CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) +// CHECK: %[[CONV:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[DEQUANTIZE]], %[[CONSTANT]]) +// CHECK: return %[[CONV]] +} + func @identity(%arg0: tensor<10xi32>, %arg1: tensor<20xi32>, %arg2: tensor<30xi32>) -> (tensor<10xi32>, tensor<20xi32>, tensor<30xi32>) { %0 = "tf.Identity"(%arg0) : (tensor<10xi32>) -> tensor<10xi32> %1:2 = "tf.IdentityN"(%arg1,%arg2) : (tensor<20xi32>, tensor<30xi32>) -> (tensor<20xi32>, tensor<30xi32>) diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index b000de17020..a80a1612488 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -84,11 +84,6 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass()); } - // Enable fusing composite ops that can be lowered to built-in TFLite ops. - if (pass_config.emit_builtin_tflite_ops) { - pass_manager->addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass()); - } - // The ophint extractions happen before lots of other passes: // The assumption of ophint-extraction is each ophinted region is a black-box // and nodes within this black-box is NOT connected to the nodes OUTSIDE the @@ -104,6 +99,27 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, pass_manager->addPass(mlir::TFL::CreateLegalizeOphintFuncOpPass()); } + // This decomposes resource ops like ResourceGather into read-variable op + // followed by gather. This is used when the saved model import path is used, + // during which resources don't get frozen in the Python layer. + pass_manager->addNestedPass( + mlir::TFDevice::CreateDecomposeResourceOpsPass()); + + // This pass does resource analysis of saved model global tensors and marks + // those deemed read-only as immutable. + pass_manager->addPass( + mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass()); + // This pass marks non-exported functions as having symbol visibility + // 'private'. + pass_manager->addPass( + mlir::tf_saved_model:: + CreateMarkFunctionVisibilityUsingSavedModelLinkagePass()); + + // Enable fusing composite ops that can be lowered to built-in TFLite ops. + if (pass_config.emit_builtin_tflite_ops) { + pass_manager->addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass()); + } + // Legalize while early to allow further constant folding. // TODO(jpienaar): This may not actually matter as we do canonicalization // after the legalize below, for now it needs to be below the above passes @@ -114,6 +130,10 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, mlir::TFL::CreateLegalizeTFWhilePass()); } + // Add function inlining pass. Both TF and TFLite dialects are opted into + // the function inliner interface. + pass_manager->addPass(mlir::createInlinerPass()); + // TODO(jpienaar): Revise post dialect constants.
pass_manager->addPass(mlir::TF::CreateDecodeConstantPass()); // Canonicalization includes const folding, which is utilized here to optimize @@ -121,9 +141,15 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, // tf.Conv2D is split into tf.Transpose and tfl.Conv2D. pass_manager->addNestedPass(mlir::createCanonicalizerPass()); pass_manager->addNestedPass(mlir::createCSEPass()); + // This pass does dead code elimination based on symbol visibility. + pass_manager->addPass(mlir::createSymbolDCEPass()); + // This pass 'freezes' immutable global tensors and inlines them as tf + // constant ops. + pass_manager->addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass()); - if (pass_config.inline_functions) { - pass_manager->addPass(mlir::createInlinerPass()); + if (pass_config.shape_inference) { + // Add a shape inference pass to optimize away the unnecessary casts. + pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass()); } // The below passes only make sense if Builtin TFLite ops are enabled @@ -160,3 +186,85 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, } } // namespace tensorflow + +namespace mlir { +namespace TFL { + +struct StandardPipelineOptions + : public PassPipelineOptions { + // TODO(b/150915052): All the tf_tfl_translate_cl flags should + // move inside this. +}; + +// NOLINTNEXTLINE +// This creates the standard pass pipeline for TF->TFLite. This +// represents a std configuration for TFLite, for use with APIs like +// tensorflow/python/pywrap_mlir.py::experimental_run_pass_pipeline +// This does not yet include quantization passes. +void CreateTFLStandardPipeline(OpPassManager& pm, + const StandardPipelineOptions& options) { + OpPassManager& func_pm = pm.nest(); + + // tf_executor dialect passes - Cleaning up the IR. + func_pm.addPass(tf_executor::CreateSwitchFoldPass()); + func_pm.addPass(tf_executor::CreateTFExecutorGraphPruningPass()); + func_pm.addPass(tf_executor::CreateTFExecutorIslandCoarseningPass()); + + // more cleanup of executor dialect and raise to control flow. + pm.addPass(mlir::CreateTFExecutorToControlDialectConversion()); + pm.addPass(mlir::TFControlFlow::CreateRaiseTFControlFlowPass()); + + // This is needed for control flow support with TF TensorList. + pm.addPass(mlir::TFL::CreateLowerStaticTensorListPass()); + + // Saved model pass to mark global tensors immutable. + pm.addPass(mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass()); + // Used to mark non-exported functions in saved model private. + pm.addPass(mlir::tf_saved_model:: + CreateMarkFunctionVisibilityUsingSavedModelLinkagePass()); + // Op fusion pass. + pm.addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass()); + + pm.addNestedPass(mlir::TFL::CreateLegalizeTFWhilePass()); + + pm.addPass(mlir::createInlinerPass()); + + // Canonicalize, CSE etc. + pm.addPass(mlir::TF::CreateDecodeConstantPass()); + pm.addNestedPass(mlir::createCanonicalizerPass()); + pm.addNestedPass(mlir::createCSEPass()); + // DCE for private symbols. + pm.addPass(mlir::createSymbolDCEPass()); + + // freeze global tensors. + pm.addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass()); + + // TFLite dialect passes. + pm.addPass(mlir::TFL::CreatePrepareTFPass(true)); + pm.addNestedPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::TFL::CreateLegalizeTFPass()); + pm.addPass(mlir::TFL::CreateOptimizePass()); + pm.addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass()); + + // Canonicalize, CSE etc. 
+ pm.addNestedPass(mlir::createCanonicalizerPass()); + pm.addNestedPass(mlir::createCSEPass()); + + // Pass for stateful operands like LSTM. + pm.addPass(mlir::TFL::CreateSplitMergedOperandsPass()); + + pm.addPass(mlir::TFL::CreateWhileOutlinePass()); + + pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass()); +} + +// Registers a pass pipeline for the standard TFL passes. +static mlir::PassPipelineRegistration pipeline( + "tfl-standard-pipeline", + "Run the standard passes involved in transforming/optimizing the TF " + "program to TFLite after " + "importing into MLIR.", + CreateTFLStandardPipeline); + +} // namespace TFL +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc index 7f8ce4cf3d4..74e48cd6d91 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc @@ -137,13 +137,14 @@ int main(int argc, char **argv) { // TODO(b/147435528): We need to test the e2e behavior once the graph freezing // inside mlir is done. - if (import_saved_model || import_saved_model_v1) { + if (import_saved_model_object_graph || import_saved_model_signature_defs) { if (input_mlir) module = tensorflow::errors::InvalidArgument( "Importing saved model should not have input_mlir set"); - module = tensorflow::ImportSavedModel( - import_saved_model, import_saved_model_v1, input_file_name, - saved_model_tags, saved_model_exported_names, &context); + module = tensorflow::ImportSavedModel(import_saved_model_object_graph, + import_saved_model_signature_defs, + input_file_name, saved_model_tags, + saved_model_exported_names, &context); } else { module = tensorflow::LoadFromGraphdefOrMlirSource( input_file_name, input_mlir, use_splatted_constant, custom_opdefs, @@ -194,9 +195,18 @@ int main(int argc, char **argv) { mlir::TFL::PassConfig pass_config(quant_specs); pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops; pass_config.lower_tensor_list_ops = lower_tensor_list_ops; - pass_config.inline_functions = inline_functions; + + // Currently we only do shape inference for saved model import. + if (import_saved_model_object_graph || import_saved_model_signature_defs) { + pass_config.shape_inference = true; + } tensorflow::AddTFToTFLConversionPasses(pass_config, &pm); + // TODO(b/150901738): Move those into tf_tfl_translate.cc. + // Convert back to outlined while format for export back to flatbuffer. 
+ if (pass_config.legalize_tf_while) { + pm.addPass(mlir::TFL::CreateWhileOutlinePass()); + } pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass()); std::string result; diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.cc index de569a3496c..e4687c515ac 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.cc @@ -24,14 +24,14 @@ opt input_file_name(llvm::cl::Positional, llvm::cl::init("-")); // NOLINTNEXTLINE -opt import_saved_model( - "savedmodel-to-mlir", +opt import_saved_model_object_graph( + "savedmodel-objectgraph-to-mlir", llvm::cl::desc("Import a saved model to its MLIR representation"), llvm::cl::value_desc("dir")); // NOLINTNEXTLINE -opt import_saved_model_v1( - "savedmodel-v1-to-mlir", +opt import_saved_model_signature_defs( + "savedmodel-signaturedefs-to-mlir", llvm::cl::desc("Import a saved model V1 to its MLIR representation"), llvm::cl::value_desc("dir")); @@ -104,13 +104,6 @@ opt quant_stats_file_name("quant-stats", llvm::cl::value_desc("filename"), llvm::cl::init("")); -// NOLINTNEXTLINE -opt inline_functions( - "inline", - llvm::cl::desc("Inline function calls within the main function " - "before legalization to TFLite."), - llvm::cl::init(true)); - // NOLINTNEXTLINE opt legalize_while( "legalize-tf-while", diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h b/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h index d7e54d70b81..b42160a4a2a 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h @@ -35,14 +35,13 @@ extern llvm::cl::opt output_file_name; extern llvm::cl::opt use_splatted_constant; extern llvm::cl::opt input_mlir; extern llvm::cl::opt output_mlir; -extern llvm::cl::opt inline_functions; extern llvm::cl::list custom_opdefs; extern llvm::cl::opt emit_quant_adaptor_ops; extern llvm::cl::opt quant_stats_file_name; // Import saved model. 
-extern llvm::cl::opt import_saved_model; -extern llvm::cl::opt import_saved_model_v1; +extern llvm::cl::opt import_saved_model_object_graph; +extern llvm::cl::opt import_saved_model_signature_defs; extern llvm::cl::opt saved_model_tags; extern llvm::cl::opt saved_model_exported_names; #endif // TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_TRANSLATE_CL_H_ diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index f5097e1c01b..b05dcaadab2 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -169,7 +169,7 @@ StatusOr ImportSavedModel( std::vector exported_names = absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty()); - auto module = tensorflow::SavedModelToMlirImport( + auto module = tensorflow::SavedModelObjectGraphToMlirImport( input_filename, tags, absl::Span(exported_names), context); if (!module) return tensorflow::errors::InvalidArgument("fail to open input file"); @@ -179,8 +179,8 @@ StatusOr ImportSavedModel( std::unordered_set tags = absl::StrSplit(saved_model_tags, ','); - auto module = - tensorflow::SavedModelV1ToMlirImport(input_filename, tags, context); + auto module = tensorflow::SavedModelSignatureDefsToMlirImport( + input_filename, tags, context); if (!module) return tensorflow::errors::InvalidArgument("fail to open input file"); diff --git a/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc b/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc index 3582046f13f..5893d4f3779 100644 --- a/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc +++ b/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc @@ -426,7 +426,9 @@ void PreprocessTopoSortGraph( } bool IsSideEffectOp(Operation* op) { - if (op->hasNoSideEffect()) return false; + // TODO(riverriddle) Properly handle region side effects. + if (MemoryEffectOpInterface::hasNoEffect(op) && op->getNumRegions() == 0) + return false; // Identity op has no side effect. // Check the OperationName maybe more elegant here. diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td index 683905d06c7..586ddf6211f 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td @@ -206,32 +206,12 @@ def : Pat<(TF_LogicalAndOp $l, $r), (TFL_LogicalAndOp $l, $r)>; def : Pat<(TF_LogicalOrOp $l, $r), (TFL_LogicalOrOp $l, $r)>; -// Multi-pattern consisting of matching stand-alone op or op followed by relu. -// TODO(karimnosseir): Can the activation part here be removed by modifying the -// very similar pass in optimize_patterns.td? -multiclass FusedBinaryActivationFuncOpPat { - def : Pat<(FromOp AnyTensor:$l, AnyTensor:$r), - (ToOp $l, $r, TFL_AF_None)>; - foreach actFnPair = [[TF_ReluOp, TFL_AF_Relu], - [TF_Relu6Op, TFL_AF_Relu6]] in { - def : Pat<(actFnPair[0] (FromOp:$bin_out $lhs, $rhs)), - (ToOp $lhs, $rhs, actFnPair[1]), - [(HasOneUse $bin_out)]>; - // TODO: Maybe move these below to general pass? - def : Pat<(actFnPair[0] (ToOp:$bin_out $lhs, $rhs, TFL_AF_None)), - (ToOp $lhs, $rhs, actFnPair[1]), - [(HasOneUse $bin_out)]>; - } -} - -// Instantiated FusedBinary patterns for the from-to pairs of ops. 
-foreach fromToPair = [[TF_AddOp, TFL_AddOp], - [TF_AddV2Op, TFL_AddOp], - [TF_DivOp, TFL_DivOp], - [TF_MulOp, TFL_MulOp], - [TF_RealDivOp, TFL_DivOp], - [TF_SubOp, TFL_SubOp]] in - defm : FusedBinaryActivationFuncOpPat; +def : Pat<(TF_AddOp $lhs, $rhs), (TFL_AddOp $lhs, $rhs, TFL_AF_None)>; +def : Pat<(TF_AddV2Op $lhs, $rhs), (TFL_AddOp $lhs, $rhs, TFL_AF_None)>; +def : Pat<(TF_SubOp $lhs, $rhs), (TFL_SubOp $lhs, $rhs, TFL_AF_None)>; +def : Pat<(TF_MulOp $lhs, $rhs), (TFL_MulOp $lhs, $rhs, TFL_AF_None)>; +def : Pat<(TF_RealDivOp $lhs, $rhs), (TFL_DivOp $lhs, $rhs, TFL_AF_None)>; +def : Pat<(TF_DivOp $lhs, $rhs), (TFL_DivOp $lhs, $rhs, TFL_AF_None)>; def : Pat<(TF_BiasAddOp F32Tensor:$l, F32Tensor:$r, IsDataFormatNHWC:$data_format), diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index cf24ed7e0f4..d2001db8b40 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -38,6 +38,7 @@ limitations under the License. #include "mlir/Pass/Pass.h" // TF:llvm-project #include "mlir/Support/Functional.h" // TF:llvm-project #include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" @@ -122,6 +123,7 @@ DECL_CONVERT_OP(StridedSlice); DECL_CONVERT_OP(Unpack); DECL_CONVERT_OP(Reciprocal); DECL_CONVERT_OP(RandomUniform); +DECL_CONVERT_OP(BroadcastTo); #undef DECL_CONVERT_OP @@ -464,8 +466,7 @@ PatternMatchResult ConvertTFMatrixDiagV3Op::matchAndRewrite( // TF Lite doesn't support Assert, we just drop the assert from the graph. 
PatternMatchResult ConvertTFAssertOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { - op->dropAllReferences(); - op->erase(); + rewriter.eraseOp(op); return matchSuccess(); } @@ -474,8 +475,7 @@ StatusOr CreateConstOpWithSingleValue(PatternRewriter* rewriter, ShapedType shaped_type, int value) { Type element_type = shaped_type.getElementType(); - ShapedType ranked_tensor_type = RankedTensorType::get({1}, element_type); - Type type = ranked_tensor_type; + ShapedType scalar_type = RankedTensorType::get({}, element_type); Attribute attr; switch (element_type.getKind()) { case mlir::StandardTypes::F16: { @@ -483,12 +483,12 @@ StatusOr CreateConstOpWithSingleValue(PatternRewriter* rewriter, auto floatAttr = mlir::FloatAttr::get(floatType, static_cast(value)); std::vector floatValues({floatAttr}); - attr = DenseElementsAttr::get(ranked_tensor_type, floatValues); + attr = DenseElementsAttr::get(scalar_type, floatValues); break; } case mlir::StandardTypes::F32: { - attr = DenseElementsAttr::get(ranked_tensor_type, - static_cast(value)); + attr = + DenseElementsAttr::get(scalar_type, static_cast(value)); break; } case mlir::StandardTypes::Complex: { @@ -509,8 +509,7 @@ StatusOr CreateConstOpWithSingleValue(PatternRewriter* rewriter, repr.set_tensor_content(content); std::string mangled = tensorflow::mangling_util::MangleTensor(repr); - attr = - mlir::OpaqueElementsAttr::get(dialect, ranked_tensor_type, mangled); + attr = mlir::OpaqueElementsAttr::get(dialect, scalar_type, mangled); break; } return Status(tensorflow::error::INVALID_ARGUMENT, "Unsupported type"); @@ -519,19 +518,19 @@ StatusOr CreateConstOpWithSingleValue(PatternRewriter* rewriter, const auto& itype = element_type.cast(); switch (itype.getWidth()) { case 8: - attr = DenseElementsAttr::get(ranked_tensor_type, + attr = DenseElementsAttr::get(scalar_type, static_cast(value)); break; case 16: - attr = DenseElementsAttr::get(ranked_tensor_type, + attr = DenseElementsAttr::get(scalar_type, static_cast(value)); break; case 32: - attr = DenseElementsAttr::get(ranked_tensor_type, + attr = DenseElementsAttr::get(scalar_type, static_cast(value)); break; case 64: - attr = DenseElementsAttr::get(ranked_tensor_type, + attr = DenseElementsAttr::get(scalar_type, static_cast(value)); break; default: @@ -543,7 +542,7 @@ StatusOr CreateConstOpWithSingleValue(PatternRewriter* rewriter, default: return Status(tensorflow::error::INVALID_ARGUMENT, "Unsupported type"); } - return rewriter->create(loc, type, attr); + return rewriter->create(loc, scalar_type, attr); } PatternMatchResult ConvertTFReciprocalOp::matchAndRewrite( @@ -566,6 +565,31 @@ PatternMatchResult ConvertTFReciprocalOp::matchAndRewrite( return matchSuccess(); } +PatternMatchResult ConvertTFBroadcastToOp::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + auto tf_broadcast_to_op = cast(op); + auto element_type = tf_broadcast_to_op.input().getType().cast(); + auto output_type = tf_broadcast_to_op.output().getType(); + + auto status_or_const_op = + CreateConstOpWithSingleValue(&rewriter, op->getLoc(), element_type, 1); + if (!status_or_const_op.ok()) { + return matchFailure(); + } + + auto tfl_fill_op = rewriter.create( + op->getLoc(), output_type, tf_broadcast_to_op.shape(), + status_or_const_op.ValueOrDie()); + + StringAttr fused_activation_function = + StringAttr::get("NONE", rewriter.getContext()); + + rewriter.replaceOpWithNewOp( + op, output_type, tf_broadcast_to_op.input(), tfl_fill_op, + fused_activation_function); + return matchSuccess(); +} + 
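The ConvertTFBroadcastToOp pattern above lowers tf.BroadcastTo by filling a tensor of ones with the requested output shape and multiplying the original input against it, letting the multiply's broadcasting rules carry the broadcast. A minimal sketch of the rewrite in MLIR follows; the fill and multiply op classes are assumptions (the diff's template arguments were stripped, but the shape/value operands and the fused_activation_function attribute point at the TFL fill and mul ops), and the shapes are purely illustrative, not taken from the patch:

  // Before (hypothetical shapes):
  %y = "tf.BroadcastTo"(%x, %shape) : (tensor<3xf32>, tensor<2xi32>) -> tensor<2x3xf32>

  // After: a scalar one, a fill to the requested shape, and a broadcasting
  // multiply with no fused activation.
  %one  = constant dense<1.000000e+00> : tensor<f32>
  %ones = "tfl.fill"(%shape, %one) : (tensor<2xi32>, tensor<f32>) -> tensor<2x3xf32>
  %y2   = "tfl.mul"(%x, %ones) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>

Multiplying by a ones tensor reuses the element-wise broadcasting already supported by the TFLite dialect instead of requiring a dedicated broadcast kernel.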
// Legalize unidirectional sequence lstm. struct LegalizeUnidirectionalSequenceLstm : public RewritePattern { explicit LegalizeUnidirectionalSequenceLstm(MLIRContext* context) @@ -616,7 +640,7 @@ struct LegalizeUnidirectionalSequenceLstm : public RewritePattern { rewriter.getStringAttr("TANH"))); // cell_clip. attributes.push_back( - rewriter.getNamedAttr("cell_clip", rewriter.getF32FloatAttr(10.0))); + rewriter.getNamedAttr("cell_clip", rewriter.getF32FloatAttr(0.0))); // proj_clip. attributes.push_back( rewriter.getNamedAttr("proj_clip", rewriter.getF32FloatAttr(0.0))); @@ -629,7 +653,7 @@ struct LegalizeUnidirectionalSequenceLstm : public RewritePattern { // Rewire the output. op->getResult(2).replaceAllUsesWith(lstm_op.getResult()); - op->erase(); + rewriter.eraseOp(op); return matchSuccess(); } }; @@ -688,7 +712,7 @@ struct LegalizeUnidirectionalSequenceRnn : public RewritePattern { // Rewire the output. op->getResult(1).replaceAllUsesWith(rnn_op.getResult()); - op->erase(); + rewriter.eraseOp(op); return matchSuccess(); } @@ -696,22 +720,44 @@ struct LegalizeUnidirectionalSequenceRnn : public RewritePattern { void LegalizeTF::runOnFunction() { OwningRewritePatternList patterns; - auto* ctx = &getContext(); + auto* context = &getContext(); auto func = getFunction(); // Add the generated patterns to the list. - populateWithGenerated(ctx, &patterns); - patterns - .insert(ctx); + populateWithGenerated(context, &patterns); + patterns.insert(context); // Ophint python converter converted tf node pattern. patterns.insert(ctx); - applyPatternsGreedily(func, patterns); + LegalizeUnidirectionalSequenceRnn>(context); + + ConversionTarget target(*context); + // It is legal to have TF ops in the graph still which can be + // used later or in the case of SELECT where we allow TF ops in the final + // graph. + target.addLegalOp(); + target.addLegalOp(); + target.addDynamicallyLegalDialect( + Optional([](Operation* op) { + auto tfl_op = dyn_cast_or_null(op); + if (!tfl_op) return false; + return succeeded(tfl_op.VerifyTflRuntimeTypes(tfl_op.getOperation())); + })); + // Keep trying to convert. + // TODO(karimnosseir): This is similar to what applying greedy patterns does. + // Look if there is a function that tries until it converges. + // Currently unit-test doesn't do multiple tries, so we need this.
+ const int max_iterations = 15; + for (int i = 0; i < max_iterations; ++i) { + if (failed(applyPartialConversion(func, target, patterns))) { + return; + } + } } } // namespace diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc index 754333b175f..a13490ddb9f 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc @@ -864,13 +864,12 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction( target.addLegalOp(); OwningRewritePatternList patterns; - patterns - .insert( - context); + populateWithGenerated(context, &patterns); + patterns.insert(context); return applyFullConversion(func, target, patterns); } diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index dbc12a85b67..bc39c0cf74b 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -83,6 +83,17 @@ bool IsBroadcastableElementsAttrAndType(Type a, Type b) { return OpTrait::util::getBroadcastedType(a, b) != Type(); } +// Returns whether the resultant type of any broadcastable operation with +// operands `a` and `b` matches `expected_output`. Returns false if `a` is not +// broadcast-compatible with `b`. +bool OperandsBroadcastToOutputType(Type a, Type b, Type expected_output) { + Type output_element_type = + expected_output.cast().getElementType(); + Type broadcasted_type = + OpTrait::util::getBroadcastedType(a, b, output_element_type); + return broadcasted_type != Type() && broadcasted_type == expected_output; +} + // Returns whether if `type1` dimensions are the same as the ending dimensions // of `type2`. This is more restricted than broadcastable. bool IsTailOfShape(Type type1, Type type2) { diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc index cde253b6ebc..83ecf0be820 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc @@ -58,12 +58,14 @@ static void UpdateFuncType(FuncOp func) { // TODO(jpienaar): Remove when recursive side-effect modeling is added. 
static bool IsSideEffectFree(FuncOp func) { - return func.getBody() - .walk([&](Operation* op) { - if (!op->hasNoSideEffect()) return WalkResult::interrupt(); - return WalkResult::advance(); - }) - .wasInterrupted(); + return !func.getBody() + .walk([&](Operation* op) { + if (!MemoryEffectOpInterface::hasNoEffect(op) && + !op->isKnownTerminator()) + return WalkResult::interrupt(); + return WalkResult::advance(); + }) + .wasInterrupted(); } // Folds TensorFlow If op with constant conditional operand by inlining the diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 0ad5be055dc..144227b06af 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -285,6 +285,10 @@ foreach L2NormalizePairs = [[TFL_MulOp, TFL_RsqrtOp], [TFL_DivOp, TFL_SqrtOp]] def AreBroadcastableTypes : Constraint>; +def OperandsBroadcastToOutputType : Constraint>; + def IsTailOfShape : Constraint>; @@ -293,15 +297,15 @@ def HaveSameType : Constraint>; // Pattern for skipping Tile if it is mainly for broadcasting and the // Op is already supporting broadcasting. multiclass FuseTileBroadcastIntoFollowingBinary { - def : Pat<(BinaryOp (TFL_TileOp $input, (ConstantOp $tile)), + def : Pat<(BinaryOp:$result (TFL_TileOp $input, (ConstantOp $tile)), $operand, $act_func), (BinaryOp $input, $operand, $act_func), - [(AreBroadcastableTypes $input, $operand)]>; + [(OperandsBroadcastToOutputType $input, $operand, $result)]>; - def : Pat<(BinaryOp $operand, + def : Pat<(BinaryOp:$result $operand, (TFL_TileOp $input, (ConstantOp $tile)), $act_func), (BinaryOp $operand, $input, $act_func), - [(AreBroadcastableTypes $operand, $input)]>; + [(OperandsBroadcastToOutputType $operand, $input, $result)]>; } // Multi-pattern consisting of matching stand-alone op or op followed by relu. diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc index 6cccdf5aa8d..b2cc58b863a 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc @@ -23,7 +23,6 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Analysis/CallInterfaces.h" // TF:llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project @@ -35,6 +34,7 @@ limitations under the License. #include "mlir/IR/Operation.h" // TF:llvm-project #include "mlir/IR/StandardTypes.h" // TF:llvm-project #include "mlir/IR/SymbolTable.h" // TF:llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // TF:llvm-project #include "mlir/Pass/Pass.h" // TF:llvm-project #include "mlir/Support/LLVM.h" // TF:llvm-project #include "mlir/Support/LogicalResult.h" // TF:llvm-project @@ -141,10 +141,7 @@ LogicalResult CheckOutputConsumer( for (int i = 0; i < expected_num_outputs; ++i) { auto it = expected_consumer_indices.find(i); - if (it != expected_consumer_indices.end()) { - // Expected consumer. - if (call_op->getResult(i).use_empty()) return failure(); - } else { + if (it == expected_consumer_indices.end()) { // Unexpected consumer. 
if (!call_op->getResult(i).use_empty()) return failure(); } @@ -160,8 +157,9 @@ LogicalResult CheckFusableKerasLstm(FuncOp lstm_func, ModuleOp module) { if (call_op && op->getAttrOfType("f").getRootReference() == lstm_func.getName()) { // Keras LSTM have 5 outputs. - // We should make sure only the second output is consumed. - if (failed(CheckOutputConsumer(call_op, 5, {1}))) check_failed = true; + // We should make sure only the first or the second output are consumed. + if (failed(CheckOutputConsumer(call_op, 5, {0, 1}))) + check_failed = true; } }); } diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index ef6fd1899d2..7592f462f6b 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -56,6 +56,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/lite/utils/validators.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h" #define DEBUG_TYPE "tf-tfl-legalization" @@ -195,7 +196,7 @@ using PreparePerChannelFakeQuant = template struct ConvertTFConvOp : public RewritePattern { // Transient state for preserving data from match to rewrite - struct ConvertTFConvOpMatchState : public PatternState { + struct ConvertTFConvOpMatchState { IntegerAttr dilation_height_factor; IntegerAttr dilation_width_factor; StringAttr padding; @@ -207,7 +208,8 @@ struct ConvertTFConvOp : public RewritePattern { : RewritePattern(TFConvOpType::getOperationName(), 1, context), intAttrOne(Builder(context).getI32IntegerAttr(1)) {} - PatternMatchResult match(Operation *op) const override { + PatternMatchResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { // Assumes TensorFlow convolution op is already verified to be // in valid form. @@ -226,38 +228,29 @@ struct ConvertTFConvOp : public RewritePattern { IntegerAttr height, width; if (!TFIntListIs1XY1(op, "strides", &height, &width)) return matchFailure(); - auto state = std::make_unique(); - - state->stride_height = height; - state->stride_width = width; + ConvertTFConvOpMatchState state; + state.stride_height = height; + state.stride_width = width; if (TFIntListIs1XY1(op, "dilations", &height, &width)) { - state->dilation_height_factor = height; - state->dilation_width_factor = width; + state.dilation_height_factor = height; + state.dilation_width_factor = width; } else { // If the 'dilations' attribute is missing, we use the default value (1) // for both dilation height and width factor. - state->dilation_height_factor = intAttrOne; - state->dilation_width_factor = intAttrOne; + state.dilation_height_factor = intAttrOne; + state.dilation_width_factor = intAttrOne; } - StringAttr padding_attr; - if (!TFPaddingIsSameOrValid(op, &padding_attr)) return matchFailure(); - state->padding = padding_attr; + if (!TFPaddingIsSameOrValid(op, &state.padding)) return matchFailure(); // Additionally, we require the filter operand to be of 4-D tensor type so // that we can extract info from the shape (e.g., for constructing bias // tensor, for setting depth_multiplier attribute, etc.). 
- auto filter_type = - tf_op.filter().getType().template dyn_cast(); - if (filter_type && filter_type.getRank() == 4) - return matchSuccess(std::move(state)); + auto filter = tf_op.filter(); + auto filter_type = filter.getType().template dyn_cast(); + if (!filter_type || filter_type.getRank() != 4) return matchFailure(); - return matchFailure(); - } - - void rewrite(Operation *op, std::unique_ptr state, - PatternRewriter &rewriter) const override { // TensorFlow convolution op only has two inputs, while the TFLite one has // three, with the bias vector marked as optional. However, TOCO has a // dedicated pass, EnsureBiasVectors, to create default bias vectors for all @@ -267,11 +260,7 @@ struct ConvertTFConvOp : public RewritePattern { // TODO(antiagainst): also handle the case of tf.Add(tf.[op], ) - TFConvOpType tf_op = cast(op); - // Get a splat zero tensor with the expected dimension for the bias tensor - auto filter = tf_op.filter(); - auto filter_type = filter.getType().template cast(); auto elem_type = filter_type.getElementType(); auto bias_dim = static_cast(this)->getBiasDim( filter_type.getShape()); @@ -280,12 +269,12 @@ struct ConvertTFConvOp : public RewritePattern { auto bias = rewriter.create(op->getLoc(), bias_type, bias_attr); - auto *conv_state = static_cast(state.get()); auto conv_op = static_cast(this)->createTFLOp( - conv_state, rewriter, op->getLoc(), tf_op.getType(), tf_op.input(), - filter, bias); + &state, rewriter, op->getLoc(), tf_op.getType(), tf_op.input(), filter, + bias); rewriter.replaceOp(op, conv_op.getResult()); + return matchSuccess(); } const IntegerAttr intAttrOne; @@ -655,8 +644,8 @@ void PrepareTFPass::runOnFunction() { patterns.insert, TF::ConvertTFBatchMatMulOp>(ctx); } - patterns.insert(ctx); + patterns.insert(ctx); applyPatternsGreedily(func, patterns); } diff --git a/tensorflow/compiler/mlir/lite/transforms/tensorlist_patterns.td b/tensorflow/compiler/mlir/lite/transforms/tensorlist_patterns.td index b0435b7cf4c..6943b9c03e4 100644 --- a/tensorflow/compiler/mlir/lite/transforms/tensorlist_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/tensorlist_patterns.td @@ -26,3 +26,8 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" def ConvertTensorListFromTensor : Pat< (TF_TensorListFromTensorOp $tensor, $element_shape), (replaceWithValue $tensor)>; + +// This pattern is in PrepareTF pass and added here temporary +// TODO(karimnosseir): Move away from here after looking in ordering +// the passes. +def : Pat<(TF_StopGradientOp $arg), (TF_IdentityOp $arg)>; diff --git a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc index 8ed5b0e0341..be024eccd45 100644 --- a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc +++ b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc @@ -98,7 +98,6 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) { extern_values.insert(extern_value); continue; } - assert(extern_value.getDefiningOp()->hasNoSideEffect()); if (!const_none) { // Add constant at start of region. 
auto const_builder = diff --git a/tensorflow/compiler/mlir/lite/utils/convert_type.cc b/tensorflow/compiler/mlir/lite/utils/convert_type.cc index 85bd6a18764..7158d634a89 100644 --- a/tensorflow/compiler/mlir/lite/utils/convert_type.cc +++ b/tensorflow/compiler/mlir/lite/utils/convert_type.cc @@ -38,7 +38,7 @@ mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) { case tflite::TensorType_INT32: return builder.getIntegerType(32); case tflite::TensorType_UINT8: - return mlir::TF::Uint8Type::get(builder.getContext()); + return builder.getIntegerType(8, /*isSigned=*/false); case tflite::TensorType_INT64: return builder.getIntegerType(64); case tflite::TensorType_STRING: diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc index 398433ca996..a138812e54d 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc @@ -73,21 +73,29 @@ Value CreateI64DenseConst(OpBuilder* builder, ArrayRef shape, return builder->create(location, type, attr); } +Value CreateI32DenseConst(OpBuilder* builder, ArrayRef values, + mlir::Location location) { + auto type = RankedTensorType::get(static_cast(values.size()), + builder->getIntegerType(32)); + auto attr = DenseElementsAttr::get(type, values); + return builder->create(location, type, attr); +} + Value CreateNoneValue(OpBuilder* builder, mlir::Location location) { return builder->create(location, builder->getNoneType(), builder->getUnitAttr()); } Value Transpose(OpBuilder* builder, Value value_to_transpose, - SmallVector perm, RankedTensorType original_type, + SmallVector perm, RankedTensorType original_type, mlir::Location location) { // Create a constant op for transpose permutation. - auto perm_op = CreateI64DenseConst(builder, perm, perm, location); + auto perm_op = CreateI32DenseConst(builder, perm, location); // Create tensor type for the transpose result. auto transpose_type = original_type; auto transpose_shape = functional::map( - [transpose_type](int64_t dim) { return transpose_type.getDimSize(dim); }, + [transpose_type](int32_t dim) { return transpose_type.getDimSize(dim); }, perm); auto elem_type = transpose_type.getElementType(); auto result_type = RankedTensorType::get(transpose_shape, elem_type); @@ -99,7 +107,7 @@ Value Transpose(OpBuilder* builder, Value value_to_transpose, Value Transpose2D(OpBuilder* builder, Value value_to_transpose, RankedTensorType type, mlir::Location location) { // Create a constant op for transpose permutation. 
- SmallVector perm = {1, 0}; + SmallVector perm = {1, 0}; return Transpose(builder, value_to_transpose, perm, type, location); } @@ -148,6 +156,27 @@ Value SliceRankedTensor(OpBuilder* builder, Value input, input, slice_i2c_begin, slice_i2c_size); } +Value CreateStridedSliceOp(mlir::Location loc, ArrayRef output_shape, + Value input, ArrayRef begin, + ArrayRef end, ArrayRef strides, + int64_t begin_mask, int64_t end_mask, + int64_t ellipsis_mask, int64_t new_axis_mask, + int64_t shrink_axis_mask, OpBuilder* builder) { + auto output_type = RankedTensorType::get( + output_shape, input.getType().cast().getElementType()); + auto begin_tensor = CreateI32DenseConst(builder, begin, loc); + auto end_tensor = CreateI32DenseConst(builder, end, loc); + auto strides_tensor = CreateI32DenseConst(builder, strides, loc); + + return builder->create( + loc, output_type, input, begin_tensor, end_tensor, strides_tensor, + builder->getI64IntegerAttr(begin_mask), + builder->getI64IntegerAttr(end_mask), + builder->getI64IntegerAttr(ellipsis_mask), + builder->getI64IntegerAttr(new_axis_mask), + builder->getI64IntegerAttr(shrink_axis_mask)); +} + } // namespace void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToCellGate() { @@ -386,7 +415,12 @@ LogicalResult ConvertLSTMCellSimpleToFusedLSTM::RewriteFunc() { forget_layer_norm_coefficients_, cell_layer_norm_coefficients_, output_layer_norm_coefficients_, builder_.getStringAttr("TANH"), builder_.getF32FloatAttr(10.0), builder_.getF32FloatAttr(0.0), - builder_.getStringAttr("FULL")); + builder_.getStringAttr("FULL"), + /*input_to_input_intermediate=*/mlir::TypeAttr(), + /*input_to_forget_intermediate=*/mlir::TypeAttr(), + /*input_to_cell_intermediate=*/mlir::TypeAttr(), + /*input_to_output_intermediate=*/mlir::TypeAttr(), + /*effective_hidden_scale_intermediate=*/mlir::TypeAttr()); // Cast the static shaped lstm result to FuncOp's signature - // Ranked but unknown 2nd dimension to support stacking these. @@ -588,16 +622,6 @@ LogicalResult CreateEqualSizeSplitVOp(Value input, int axis, int splits, return success(); } -void UpdateFuncSignature(int batch, int time, int output, - mlir::FuncOp* func_op) { - SmallVector output_shape{batch, time, output}; - auto input_types = func_op->getType().getInputs(); - auto element_type = input_types[0].cast().getElementType(); - auto output_type = mlir::RankedTensorType::get(output_shape, element_type); - func_op->setType( - mlir::FunctionType::get(input_types, output_type, func_op->getContext())); -} - // TODO(b/147436982): Consider refactor this to be more general. LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) { // For argument order, please check out standard_lstm under @@ -626,26 +650,21 @@ LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) { auto final_inputs = input; auto final_input_type = input_type; - // We will transpose the inputs. - if (!time_majored) { - SmallVector perm = {1, 0, 2}; - final_inputs = - Transpose(builder, final_inputs, perm, input_type, func_op.getLoc()); - final_input_type = final_inputs.getType().dyn_cast(); - } // Handle go_backwards: // LSTM in Keras semantic will reverse the input sequence if it's go_backwards auto go_backwards_attr = func_op.getAttrOfType("tf.go_backwards"); if (go_backwards_attr != nullptr && go_backwards_attr.getValue()) { - // We assume input is already in {time, batch, size} layout. - final_inputs = - Reverse(builder, final_inputs, 0, final_input_type, func_op.getLoc()); + int time_dim = time_majored ? 
0 : 1; + final_inputs = Reverse(builder, final_inputs, time_dim, final_input_type, + func_op.getLoc()); } - int batch = final_input_type.getDimSize(1); - int time = final_input_type.getDimSize(0); + int batch = time_majored ? final_input_type.getDimSize(1) + : final_input_type.getDimSize(0); + int time = time_majored ? final_input_type.getDimSize(0) + : final_input_type.getDimSize(1); // Setup correct weights. RankedTensorType weight_type = @@ -686,14 +705,20 @@ LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) { return failure(); // Build the lstm op. - SmallVector output_shape = {time, batch, n_output}; + SmallVector output_shape; + if (time_majored) { + output_shape = {time, batch, n_output}; + } else { + output_shape = {batch, time, n_output}; + } auto result_type = mlir::RankedTensorType::get( - output_shape, input.getType().cast().getElementType()); + output_shape, + final_inputs.getType().cast().getElementType()); Value none = builder->create( func_op.getLoc(), builder->getNoneType(), builder->getUnitAttr()); auto lstm = builder->create( - func_op.getLoc(), result_type, /*input=*/input, + func_op.getLoc(), result_type, /*input=*/final_inputs, /*input_to_input_weights=*/weights_array->getResult(0), /*input_to_forget_weights=*/weights_array->getResult(1), /*input_to_cell_weights=*/weights_array->getResult(2), @@ -718,29 +743,80 @@ LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) { /*cell_layer_norm_coefficients=*/none, /*output_layer_norm_coefficients=*/none, builder->getStringAttr("TANH"), builder->getF32FloatAttr(10.0), builder->getF32FloatAttr(0.0), - builder->getBoolAttr(true)); + builder->getBoolAttr(time_majored)); - auto final_output = lstm.getResult(); - if (!time_majored) { - SmallVector perm = {1, 0, 2}; - final_output = - Transpose(builder, final_output, perm, result_type, func_op.getLoc()); + auto final_output_full_sequences = lstm.getResult(); + + // Populate the last output: last output is sliced from the full sequences. + // If time_major: last_output = outputs[-1, :, :] + // else: last_output = outputs[:, -1, :] + // + // As we are creating the strided_slice op, we need to populate the following + // fields: + // end: should always be (0, 0, 0) + // strides: should always be (1, 1, 1) + // begin: should be (0, -1, 0) or (-1, 0, 0) if it's time-majored. + // new_axis_mask: should always be 0. + // ellipsis_mask: should always be 0. + // begin_mask & end_mask: should be 0b101 = 5 or 0b110 = 4 if it's + // time-majored. shrink_axis_mask: should be 0b010 = 2 or 0b001 = 1 if it's + // time-majored. + SmallVector last_output_shape({batch, n_output}); + + SmallVector end({0, 0, 0}); + SmallVector strides({1, 1, 1}); + SmallVector begin; + + int64_t new_axis_mask = 0; + int64_t ellipsis_mask = 0; + int64_t begin_mask; + int64_t end_mask; + int64_t shrink_axis_mask; + if (time_majored) { + begin_mask = 6; + end_mask = 6; + shrink_axis_mask = 1; + begin = {-1, 0, 0}; + } else { + begin_mask = 5; + end_mask = 5; + shrink_axis_mask = 2; + begin = {0, -1, 0}; } + auto last_output = CreateStridedSliceOp( + func_op.getLoc(), last_output_shape, final_output_full_sequences, begin, + end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, + shrink_axis_mask, builder); + SmallVector outputs; + SmallVector output_types; - for (int i = 0; i < 5; ++i) { - if (i == 1) { - // only this one is the real output. 
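// Aside (worked example, not part of the patch): the non-time-major call to
// CreateStridedSliceOp above, spelled out with literal mask values.
// begin_mask = end_mask = 0b101 (= 5) leaves the batch and feature dims
// unclipped, shrink_axis_mask = 0b010 (= 2) drops the size-1 time dim, so the
// result is outputs[:, -1, :]. The time-major case uses 0b110 (= 6) for
// begin_mask/end_mask and 0b001 (= 1) for shrink_axis_mask, matching the code.
auto example_last_output = CreateStridedSliceOp(
    func_op.getLoc(), /*output_shape=*/{batch, n_output},
    /*input=*/final_output_full_sequences,
    /*begin=*/{0, -1, 0}, /*end=*/{0, 0, 0}, /*strides=*/{1, 1, 1},
    /*begin_mask=*/5, /*end_mask=*/5, /*ellipsis_mask=*/0,
    /*new_axis_mask=*/0, /*shrink_axis_mask=*/2, builder);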
- outputs.push_back(final_output); - } else { - auto result_type = - func_op.getCallableResults()[i].dyn_cast(); - outputs.push_back(CreatTfF32ConstOp(builder, result_type.getShape(), 0.0f, - func_op.getLoc())); - } + // Due to the existence of the while loop, the timestamp may be unknown + // for the signature, for us, since we know the inputs, we can infer the time + // steps. + + // Last output. + outputs.push_back(last_output); + output_types.push_back(last_output.getType()); + + // Full sequences. + outputs.push_back(final_output_full_sequences); + output_types.push_back(final_output_full_sequences.getType()); + + // All the rest: states, device. + for (int i = 2; i < 5; ++i) { + auto result_type = + func_op.getCallableResults()[i].dyn_cast(); + outputs.push_back(CreatTfF32ConstOp(builder, result_type.getShape(), 0.0f, + func_op.getLoc())); + output_types.push_back(result_type); } + // Update function signatures. + func_op.setType(mlir::FunctionType::get(func_op.getType().getInputs(), + output_types, func_op.getContext())); + builder->create(func_op.getLoc(), outputs); return success(); } diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc new file mode 100644 index 00000000000..e554686531a --- /dev/null +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc @@ -0,0 +1,215 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h" + +#include + +#include "absl/container/flat_hash_set.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/raw_os_ostream.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { + +static inline absl::string_view StringRefToView(llvm::StringRef ref) { + return {ref.data(), ref.size()}; +} + +// Dumps the MLIR module to disk. +// This require the TF_DUMP_GRAPH_PREFIX to be set to a path that exist (or can +// be created). 
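// Usage note (not part of the patch): DumpModule below writes files only when
// TF_DUMP_GRAPH_PREFIX points at a directory that exists or can be created,
// and the callers gate dumping on VLOG(1), e.g.
//   TF_DUMP_GRAPH_PREFIX=/tmp/mlir_dumps TF_CPP_MIN_VLOG_LEVEL=1 <your program>
// (the VLOG knob shown is an assumption; any way of enabling VLOG(1) works).
// Each registered pass then dumps mlir_<pass-name>_before_*.mlir and
// mlir_<pass-name>_after_*.mlir snapshots of the module.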
+static void DumpModule(mlir::ModuleOp module, std::string file_prefix) { + std::string prefix = GetDumpDirFromEnvVar(); + if (prefix.empty()) return; + + auto* env = tensorflow::Env::Default(); + auto status = env->RecursivelyCreateDir(prefix); + if (!status.ok()) { + LOG(WARNING) << "cannot create directory '" + prefix + + "': " + status.error_message(); + return; + } + + prefix += "/" + file_prefix; + if (!tensorflow::Env::Default()->CreateUniqueFileName(&prefix, ".mlir")) { + LOG(WARNING) << "cannot create unique filename, won't dump MLIR module."; + return; + } + + std::unique_ptr file_writer; + status = env->NewWritableFile(prefix, &file_writer); + if (!status.ok()) { + LOG(WARNING) << "cannot open file '" + prefix + + "': " + status.error_message(); + return; + } + + // Print the module to a string before writing to the file. + std::string txt_module; + { + llvm::raw_string_ostream os(txt_module); + module.print(os); + } + + status = file_writer->Append(txt_module); + if (!status.ok()) { + LOG(WARNING) << "error writing to file '" + prefix + + "': " + status.error_message(); + return; + } + (void)file_writer->Close(); + VLOG(1) << "Dumped MLIR module to " << prefix; +} + +MlirOptimizationPassRegistry& MlirOptimizationPassRegistry::Global() { + static auto* global = new MlirOptimizationPassRegistry(); + return *global; +} + +Status MlirFunctionOptimizationPass::Run( + const DeviceSet& device_set, const ConfigProto& config_proto, + std::unique_ptr* graph, FunctionLibraryDefinition* flib_def, + std::vector* control_ret_node_names, + bool* control_rets_updated) { + // Skip conversion from Graph to MLIR if none of the passes are enabled. + const bool is_enabled = + llvm::any_of(registry_->passes(), [&](auto& pass_registration) -> bool { + return pass_registration.pass->IsEnabled(config_proto); + }); + + if (!is_enabled) { + VLOG(1) << "None of the MLIR optimization passes are enabled " + << "(registered " << registry_->passes().size() << ")"; + return Status::OK(); + } + + VLOG(1) << "Running MLIR Graph Optimization Passes " + << "(registered " << registry_->passes().size() << " passes)"; + + GraphDebugInfo debug_info; + mlir::MLIRContext context; + GraphImportConfig import_config; + import_config.graph_as_function = true; + import_config.control_outputs = *control_ret_node_names; + TF_ASSIGN_OR_RETURN(auto module_ref, + ConvertGraphToMlir(**graph, debug_info, *flib_def, + import_config, &context)); + + AddDevicesToOp(*module_ref, &device_set); + + for (auto& pass_registration : registry_->passes()) { + llvm::StringRef name = pass_registration.pass->name(); + VLOG(2) << "Run MLIR graph optimization pass: " << StringRefToView(name); + + if (VLOG_IS_ON(1)) { + DumpModule(*module_ref, llvm::formatv("mlir_{0}_before_", name)); + } + + TF_RETURN_IF_ERROR(pass_registration.pass->Run(config_proto, *module_ref)); + + if (VLOG_IS_ON(1)) { + DumpModule(*module_ref, llvm::formatv("mlir_{0}_after_", name)); + } + } + + GraphExportConfig export_config; + export_config.graph_as_function = true; + absl::flat_hash_set control_ret_nodes; + TF_RETURN_WITH_CONTEXT_IF_ERROR( + ConvertMlirToGraph(*module_ref, export_config, graph, flib_def, + &control_ret_nodes), + "Error converting MLIR module back to graph"); + + control_ret_node_names->clear(); + control_ret_node_names->reserve(control_ret_nodes.size()); + for (const auto* node : control_ret_nodes) + control_ret_node_names->push_back(node->name()); + + *control_rets_updated = true; + + return Status::OK(); +} + +MlirV1CompatOptimizationPassRegistry& 
+MlirV1CompatOptimizationPassRegistry::Global() { + static auto* global = new MlirV1CompatOptimizationPassRegistry(); + return *global; +} + +Status MlirV1CompatGraphOptimizationPass::Run( + const GraphOptimizationPassOptions& options) { + // Skip function graphs as MlirOptimizationPassRegistry_ will be used instead. + if (options.is_function_graph) return Status::OK(); + + // Skip conversion from Graph to MLIR if none of the passes are enabled. + const bool is_enabled = + absl::c_any_of(registry_->passes(), [&](auto& pass_registration) -> bool { + return pass_registration.pass->IsEnabled( + options.session_options->config); + }); + + if (!is_enabled) { + VLOG(1) << "None of the MLIR optimization passes are enabled " + << "(registered" << registry_->passes().size() << " passes)"; + return Status::OK(); + } + + VLOG(1) << "Running MLIR Graph Optimization V1 Compat Passes " + << "(registered" << registry_->passes().size() << " passes)"; + + GraphDebugInfo debug_info; + mlir::MLIRContext context; + GraphImportConfig import_config; + import_config.upgrade_legacy = true; + TF_ASSIGN_OR_RETURN( + auto module_ref, + ConvertGraphToMlir(**options.graph, debug_info, *options.flib_def, + import_config, &context)); + + AddDevicesToOp(*module_ref, options.device_set); + + for (auto& pass_registration : registry_->passes()) { + llvm::StringRef name = pass_registration.pass->name(); + VLOG(2) << "Run MLIR graph optimization pass: " << StringRefToView(name); + + if (VLOG_IS_ON(1)) { + DumpModule(*module_ref, llvm::formatv("mlir_{0}_before_", name)); + } + + TF_RETURN_IF_ERROR(pass_registration.pass->Run(options, *module_ref)); + + if (VLOG_IS_ON(1)) { + DumpModule(*module_ref, llvm::formatv("mlir_{0}_after_", name)); + } + } + + GraphExportConfig export_config; + TF_RETURN_WITH_CONTEXT_IF_ERROR( + ConvertMlirToGraph(*module_ref, export_config, options.graph, + options.flib_def), + "Error converting MLIR module back to graph"); + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h new file mode 100644 index 00000000000..aed5307d39d --- /dev/null +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h @@ -0,0 +1,179 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_MLIR_GRAPH_OPTIMIZATION_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_MLIR_GRAPH_OPTIMIZATION_PASS_H_ + +#include "mlir/IR/Module.h" // TF:llvm-project +#include "tensorflow/core/common_runtime/function_optimization_registry.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" + +namespace tensorflow { + +// -------------------------------------------------------------------------- // +// MLIR passes running on Tensorflow function graphs (Tensorflow V2). 
+// -------------------------------------------------------------------------- // + +// An API for registering MLIR ModulePass with the Tensorflow runtime. These +// passes are running only for function graphs built by Tensorflow V2 and +// instantiated by the process_function_library_runtime (see +// FunctionOptimizationPass for details). +class MlirOptimizationPass { + public: + virtual ~MlirOptimizationPass() = default; + virtual llvm::StringRef name() const = 0; + virtual bool IsEnabled(const ConfigProto& config_proto) const = 0; + + virtual Status Run(const ConfigProto& config_proto, + mlir::ModuleOp module) = 0; +}; + +class MlirOptimizationPassRegistry { + public: + struct PassRegistration { + int priority; + std::unique_ptr pass; + }; + + struct PriorityComparator { + bool operator()(const PassRegistration& x, + const PassRegistration& y) const { + return x.priority < y.priority; + } + }; + + using Passes = std::set; + + // Returns the global registry of MLIR optimization passes. + static MlirOptimizationPassRegistry& Global(); + + void Add(int priority, std::unique_ptr pass) { + passes_.insert({priority, std::move(pass)}); + } + + const Passes& passes() const { return passes_; } + + private: + Passes passes_; +}; + +// Function optimization pass that runs all MLIR passes registered in +// MlirOptimizationPassRegistry. +class MlirFunctionOptimizationPass : public FunctionOptimizationPass { + public: + explicit MlirFunctionOptimizationPass( + const MlirOptimizationPassRegistry* registry = + &MlirOptimizationPassRegistry::Global()) + : registry_(registry) {} + + Status Run(const DeviceSet& device_set, const ConfigProto& config_proto, + std::unique_ptr* graph, FunctionLibraryDefinition* flib_def, + std::vector* control_ret_node_names, + bool* control_rets_updated) override; + + private: + const MlirOptimizationPassRegistry* registry_; +}; + +// -------------------------------------------------------------------------- // +// MLIR passes running on Tensorflow V1 graphs. +// -------------------------------------------------------------------------- // + +// An API for registering MLIR ModulePass with the Tensorflow runtime. These +// passes are running only for V1 graphs (legacy graphs) executed via Session +// runtime. Graph importer updates legacy graph behavior to V2 constructs (e.g. +// it raises control flow from Switch/Merge nodes to functional control flow +// with If/While operations). +class MlirV1CompatOptimizationPass { + public: + virtual ~MlirV1CompatOptimizationPass() = default; + virtual llvm::StringRef name() const = 0; + virtual bool IsEnabled(const ConfigProto& config_proto) const = 0; + + virtual Status Run(const GraphOptimizationPassOptions& options, + mlir::ModuleOp module) = 0; +}; + +class MlirV1CompatOptimizationPassRegistry { + public: + struct PassRegistration { + int priority; + std::unique_ptr pass; + }; + + struct PriorityComparator { + bool operator()(const PassRegistration& x, + const PassRegistration& y) const { + return x.priority < y.priority; + } + }; + + using Passes = std::set; + + // Returns the global registry of MLIR optimization passes. 
+ static MlirV1CompatOptimizationPassRegistry& Global(); + + void Add(int priority, std::unique_ptr pass) { + passes_.insert({priority, std::move(pass)}); + } + + const Passes& passes() const { return passes_; } + + private: + Passes passes_; +}; + +class MlirV1CompatGraphOptimizationPass : public GraphOptimizationPass { + public: + explicit MlirV1CompatGraphOptimizationPass( + const MlirV1CompatOptimizationPassRegistry* registry = + &MlirV1CompatOptimizationPassRegistry::Global()) + : registry_(registry) {} + + Status Run(const GraphOptimizationPassOptions& options) override; + + private: + const MlirV1CompatOptimizationPassRegistry* registry_; +}; + +// -------------------------------------------------------------------------- // +// Helper classes for static registration of MLIR (V1 Compat) passes in the +// corresponding registry. +// -------------------------------------------------------------------------- // + +namespace mlir_pass_registration { + +class MlirOptimizationPassRegistration { + public: + explicit MlirOptimizationPassRegistration( + int priority, std::unique_ptr pass) { + MlirOptimizationPassRegistry::Global().Add(priority, std::move(pass)); + } +}; + +class MlirV1CompatOptimizationPassRegistration { + public: + explicit MlirV1CompatOptimizationPassRegistration( + int priority, std::unique_ptr pass) { + MlirV1CompatOptimizationPassRegistry::Global().Add(priority, + std::move(pass)); + } +}; + +} // namespace mlir_pass_registration + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_MLIR_GRAPH_OPTIMIZATION_PASS_H_ diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass_registration.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass_registration.cc new file mode 100644 index 00000000000..8155af6505e --- /dev/null +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass_registration.cc @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h" +#include "tensorflow/core/common_runtime/function_optimization_registry.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" + +namespace tensorflow { + +static function_optimization_registration::FunctionOptimizationPassRegistration + register_mlir_passes(std::make_unique()); + +REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 0, + MlirV1CompatGraphOptimizationPass); + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/runlit.cfg.py b/tensorflow/compiler/mlir/runlit.cfg.py index 3a308e2e9d2..67533197f3e 100644 --- a/tensorflow/compiler/mlir/runlit.cfg.py +++ b/tensorflow/compiler/mlir/runlit.cfg.py @@ -59,6 +59,9 @@ if platform.system() == 'Windows': else: llvm_config.use_default_substitutions() +llvm_config.config.substitutions.append( + ('%tfrt_bindir', 'tensorflow/compiler/aot')) + # Tweak the PATH to include the tools dir. 
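// Aside (sketch, not part of the patch): how a pass can plug into the V2
// registry declared in mlir_graph_optimization_pass.h. The pass name, priority
// and enabling logic here are made up for illustration.
#include <memory>

#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"

namespace tensorflow {
namespace {

class ExampleMlirPass : public MlirOptimizationPass {
 public:
  llvm::StringRef name() const override { return "example-mlir-pass"; }

  bool IsEnabled(const ConfigProto& config_proto) const override {
    return true;  // A real pass would typically check an experimental flag.
  }

  Status Run(const ConfigProto& config_proto, mlir::ModuleOp module) override {
    // Inspect or rewrite `module` in place; returning OK keeps the pipeline going.
    return Status::OK();
  }
};

// Mirrors mlir_graph_optimization_pass_registration.cc; passes with lower
// priority values run first because the registry orders by priority.
static mlir_pass_registration::MlirOptimizationPassRegistration
    register_example_pass(/*priority=*/10, std::make_unique<ExampleMlirPass>());

}  // namespace
}  // namespace tensorflow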
llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True) @@ -68,7 +71,7 @@ tool_dirs = config.mlir_tf_tools_dirs + [ tool_names = [ 'mlir-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate', 'flatbuffer_to_string', 'flatbuffer_translate', 'tf-mlir-translate', - 'mlir-tflite-runner' + 'mlir-tflite-runner', 'tfcompile' ] tools = [ToolSubst(s, unresolved='ignore') for s in tool_names] llvm_config.add_tool_substitutions(tools, tool_dirs) diff --git a/tensorflow/compiler/mlir/runlit.site.cfg.py b/tensorflow/compiler/mlir/runlit.site.cfg.py index b324386662e..6c369a5a24c 100644 --- a/tensorflow/compiler/mlir/runlit.site.cfg.py +++ b/tensorflow/compiler/mlir/runlit.site.cfg.py @@ -45,6 +45,7 @@ mlir_tf_tools_dirs = [ 'tensorflow/compiler/mlir/lite', 'tensorflow/compiler/mlir/tensorflow', 'tensorflow/compiler/mlir/xla', + 'tensorflow/compiler/aot' ] config.mlir_tf_tools_dirs = [ os.path.join(real_test_srcdir, os.environ['TEST_WORKSPACE'], s) diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index d52fd0c3b72..e2aae0ec52e 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -10,6 +10,7 @@ package_group( name = "friends", includes = ["//third_party/mlir:subpackages"], packages = [ + "//learning/pathways/data_parallel/tf2xla/...", "//tensorflow/compiler/...", "//tensorflow/lite/experimental/tf_runtime/...", "//tensorflow/python/...", @@ -24,7 +25,8 @@ filegroup( "ir/tf_op_interfaces.td", "ir/tf_ops.td", "@llvm-project//mlir:OpBaseTdFiles", - "@llvm-project//mlir:include/mlir/Analysis/CallInterfaces.td", + "@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td", + "@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td", ], ) @@ -62,6 +64,14 @@ gentbl( "-gen-op-doc", "g3doc/tf_ops.md", ), + ( + "-gen-struct-attr-decls", + "ir/tf_structs.h.inc", + ), + ( + "-gen-struct-attr-defs", + "ir/tf_structs.cc.inc", + ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tf_ops.td", @@ -186,6 +196,7 @@ cc_library( "ir/tf_ops.cc.inc", "ir/tf_ops.h.inc", "ir/tf_saved_model.cc", + "ir/tf_structs.cc", "ir/tf_verifiers.cc", ], hdrs = [ @@ -194,12 +205,14 @@ cc_library( "ir/tf_executor.h", "ir/tf_ops.h", "ir/tf_saved_model.h", + "ir/tf_structs.h", "ir/tf_traits.h", "ir/tf_verifiers.h", "transforms/bridge.h", + "transforms/einsum.h", "transforms/passes.h", "transforms/unroll_batch_matmul.h", - "@llvm-project//mlir:include/mlir/Analysis/CallInterfaces.h", + "@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.h", "@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h", ], includes = ["include"], @@ -219,10 +232,12 @@ cc_library( "@llvm-project//llvm:support", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:CallOpInterfacesIncGen", + "@llvm-project//mlir:DerivedAttributeOpInterface", "@llvm-project//mlir:Dialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SideEffects", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", @@ -264,23 +279,49 @@ cc_library( ], ) +cc_library( + name = "unroll_batch_matmul_pass", + srcs = [ + "transforms/unroll_batch_matmul.cc", + ], + hdrs = [ + "transforms/unroll_batch_matmul.h", + ], + deps = [ + ":tensorflow", + "//tensorflow/core:framework", + "@com_google_absl//absl/memory", + "@llvm-project//llvm:support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + 
"@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "tensorflow_passes", srcs = [ "transforms/annotate_parameter_replication.cc", + "transforms/batchmatmul_to_einsum.cc", "transforms/bridge.cc", "transforms/bridge_pass.cc", "transforms/cluster_formation.cc", "transforms/cluster_outlining.cc", + "transforms/collection_ops_util.cc", "transforms/decompose_resource_ops_pass.cc", + "transforms/einsum.cc", "transforms/executor_island_coarsening.cc", "transforms/executor_tpuv1_inline_tpu_island.cc", "transforms/executor_tpuv1_island_coarsening.cc", "transforms/executor_tpuv1_outline_tpu_island.cc", "transforms/fold_switch.cc", + "transforms/freeze_global_tensors.cc", "transforms/functional_control_flow_to_cfg.cc", "transforms/generated_canonicalize.inc", "transforms/generated_optimize.inc", + "transforms/gpu_fusion.cc", "transforms/graph_pruning.cc", "transforms/launch_to_device_attribute.cc", "transforms/layout_optimization.cc", @@ -299,6 +340,7 @@ cc_library( "transforms/shape_inference_pass.cc", "transforms/sink_constant.cc", "transforms/stack_ops_decomposition.cc", + "transforms/tensor_list_ops_decomposition.cc", "transforms/test_side_effect_analysis.cc", "transforms/tf_device_assignment.cc", "transforms/tpu_cluster_formation.cc", @@ -308,17 +350,18 @@ cc_library( "transforms/tpu_rewrite_pass.cc", "transforms/tpu_sharding_identification_pass.cc", "transforms/tpu_variable_runtime_reformatting.cc", - "transforms/unroll_batch_matmul.cc", "translate/breakup-islands.cc", "translate/control_to_executor_dialect.cc", "translate/executor_to_control_dialect.cc", "translate/tf_functional_to_executor.cc", ], hdrs = [ + "transforms/batchmatmul_to_einsum.h", "transforms/bridge.h", + "transforms/collection_ops_util.h", + "transforms/einsum.h", "transforms/passes.h", "transforms/shape_inference.h", - "transforms/unroll_batch_matmul.h", ], includes = ["include"], deps = [ @@ -337,6 +380,8 @@ cc_library( ":tensorflow_types", ":tpu_rewrite_device_util", ":translate_utils", + ":unroll_batch_matmul_pass", + ":xla_sharding_util", "//tensorflow/compiler/mlir/lite:validators", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla:xla_proto_cc", @@ -355,7 +400,6 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", - "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", @@ -378,13 +422,34 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "graph_optimization_pass", + srcs = ["transforms/graph_optimization_pass.cc"], + hdrs = ["transforms/graph_optimization_pass.h"], + deps = [ + ":tensorflow_passes", + "//tensorflow/compiler/mlir:mlir_graph_optimization_pass", + ], + alwayslink = 1, +) + +cc_library( + name = "graph_optimization_pass_registration", + srcs = ["transforms/graph_optimization_pass_registration.cc"], + deps = [ + ":graph_optimization_pass", + "//tensorflow/compiler/mlir:mlir_graph_optimization_pass", + "//tensorflow/compiler/mlir:mlir_graph_optimization_pass_registration", + ], + alwayslink = 1, +) + # Library with TensorFlow dialect static initialization. 
cc_library( name = "tensorflow_dialect_registration", srcs = ["ir/dialect_registration.cc"], deps = [ ":tensorflow", - "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", "@llvm-project//mlir:IR", "@llvm-project//mlir:LoopOpsTransforms", ], @@ -703,18 +768,34 @@ cc_library( ) cc_library( - name = "tf_dialect_passes", + name = "decode_constant_pass", srcs = [ - "transforms/constant_fold.cc", "transforms/decode_constant.cc", - "transforms/dialect_hooks.cc", ], hdrs = [ - "transforms/constant_fold.h", "transforms/decode_constant.h", ], deps = [ ":convert_tensor", + ":tensorflow", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + ], + alwayslink = 1, +) + +cc_library( + name = "tf_dialect_passes", + srcs = [ + "transforms/constant_fold.cc", + "transforms/dialect_hooks.cc", + ], + hdrs = [ + "transforms/constant_fold.h", + ], + deps = [ + ":convert_tensor", + ":decode_constant_pass", ":eval_util", ":tensorflow", ":tensorflow_types", @@ -725,9 +806,8 @@ cc_library( "//tensorflow/stream_executor", "//tensorflow/stream_executor/lib", "@llvm-project//llvm:support", - "@llvm-project//mlir:Analysis", "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SideEffects", "@llvm-project//mlir:Support", ], alwayslink = 1, @@ -738,6 +818,7 @@ cc_library( deps = [ ":tensorflow_dialect_registration", ":tf_dialect_passes", + "@llvm-project//mlir:AllPassesAndDialects", ], ) @@ -880,7 +961,8 @@ tf_native_cc_binary( genrule( name = "derived_attr_populator_inc", srcs = [ - "@llvm-project//mlir:include/mlir/Analysis/CallInterfaces.td", + "@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td", + "@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td", "@llvm-project//mlir:include/mlir/IR/OpBase.td", "ir/tf_generated_ops.td", "ir/tf_op_base.td", @@ -931,7 +1013,6 @@ cc_library( ":error_util", ":tensorflow_dialect_registration", ":tensorflow_passes", - ":tf_dialect_passes", ":translate_utils", "//tensorflow/compiler/mlir/xla:hlo", "//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo", @@ -943,6 +1024,7 @@ cc_library( "//tensorflow/core/platform:logging", "//tensorflow/stream_executor/lib", "@llvm-project//llvm:support", + "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", @@ -1043,7 +1125,7 @@ cc_library( srcs = ["utils/tpu_rewrite_device_util.cc"], hdrs = ["utils/tpu_rewrite_device_util.h"], deps = [ - "//tensorflow/compiler/xla:array3d", + "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/core:framework", @@ -1074,6 +1156,7 @@ cc_library( srcs = ["utils/device_util.cc"], hdrs = ["utils/device_util.h"], deps = [ + ":tensorflow", "//tensorflow/core:core_cpu_lib", "//tensorflow/core:framework", "@llvm-project//llvm:support", @@ -1154,3 +1237,20 @@ cc_library( "@llvm-project//mlir:Support", ], ) + +cc_library( + name = "xla_sharding_util", + srcs = [ + "utils/xla_sharding_util.cc", + ], + hdrs = [ + "utils/xla_sharding_util.h", + ], + deps = [ + ":tensorflow", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc index 84c3cd64a5f..931f24b9606 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc +++ 
b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc @@ -321,7 +321,13 @@ bool OpIsDeclaration(Operation* op, // Returns if `op` is know to not have any side effect. bool OpIsKnownToHaveNoSideEffect(Operation* op) { - if (op->hasNoSideEffect()) return true; + // TODO(riverriddle) We shouldn't treat all terminator operations as having + // side effects, this should be relaxed. + // TODO(riverriddle) Properly handle region side effects. + if (MemoryEffectOpInterface::hasNoEffect(op) && op->isKnownNonTerminator() && + op->getNumRegions() == 0) { + return true; + } if (auto if_op = llvm::dyn_cast(op)) { return if_op.is_stateless(); } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h index 59a1cc21b28..0156d7e7e9d 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h @@ -26,6 +26,7 @@ limitations under the License. #include "mlir/IR/Dialect.h" // TF:llvm-project #include "mlir/IR/OpDefinition.h" // TF:llvm-project #include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/Interfaces/SideEffects.h" // TF:llvm-project namespace mlir { namespace TFControlFlow { @@ -84,7 +85,7 @@ class TFControlType : public Type::TypeBase { // Note: Additional result corresponds to the control output. class EnterOp : public Op::Impl, - OpTrait::NResults<2>::Impl, OpTrait::HasNoSideEffect> { + OpTrait::NResults<2>::Impl, MemoryEffectOpInterface::Trait> { public: using Op::Op; @@ -94,6 +95,9 @@ class EnterOp void setData(Value value) { setOperand(0, value); } LogicalResult verify(); + + // EnterOp has no side-effects. + void getEffects(SmallVectorImpl &) {} }; // The "_tf.Merge" operation takes a list of input operands and returns a value @@ -197,7 +201,7 @@ class NextIterationSinkOp // Note: Additional result corresponds to the control output. class LoopCondOp : public Op::Impl, - OpTrait::NResults<2>::Impl, OpTrait::HasNoSideEffect> { + OpTrait::NResults<2>::Impl, MemoryEffectOpInterface::Trait> { public: using Op::Op; static StringRef getOperationName() { return "_tf.LoopCond"; } @@ -206,6 +210,9 @@ class LoopCondOp void setData(Value value) { setOperand(0, value); } LogicalResult verify(); + + // LoopCondOp has no side-effects. + void getEffects(SmallVectorImpl &) {} }; // The "_tf.Switch" operation takes a data operand and a boolean predicate @@ -260,8 +267,9 @@ class SwitchOp : public Op::Impl, // (tensor<*xi32>, !_tf.control) // // Note: Additional result corresponds to the control output. -class ExitOp : public Op::Impl, - OpTrait::NResults<2>::Impl, OpTrait::HasNoSideEffect> { +class ExitOp + : public Op::Impl, + OpTrait::NResults<2>::Impl, MemoryEffectOpInterface::Trait> { public: using Op::Op; static StringRef getOperationName() { return "_tf.Exit"; } @@ -270,6 +278,9 @@ class ExitOp : public Op::Impl, void setData(Value value) { setOperand(0, value); } LogicalResult verify(); + + // ExitOp has no side-effects. + void getEffects(SmallVectorImpl &) {} }; } // namespace TFControlFlow diff --git a/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc b/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc index ccab3d9c6e7..ac468d9810c 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "mlir/InitAllDialects.h" // TF:llvm-project -#include "mlir/InitAllPasses.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" @@ -23,13 +21,6 @@ limitations under the License. namespace mlir { -static bool auto_init = []() { - registerAllDialects(); - registerAllPasses(); - - return true; -}(); - // Static initialization for TF dialect registration. static DialectRegistration tf_control_flow_ops; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc index 5071910031f..38fb3154c48 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc @@ -85,6 +85,20 @@ struct TFInlinerInterface : public DialectInlinerInterface { /*truncate=*/builder.getBoolAttr(false)); } }; + +// Checks if a block wraps a single operation and the single operation results +// are perfectly forwarded to the block's terminator. +bool BlockWrapsSingleOp(Block* block) { + auto body = block->without_terminator(); + if (!has_single_element(body)) return false; + + Operation& wrapped_op = *body.begin(); + Operation* terminator = block->getTerminator(); + return wrapped_op.getNumResults() == terminator->getNumOperands() && + std::equal(wrapped_op.getResults().begin(), + wrapped_op.getResults().end(), + terminator->getOperands().begin()); +} } // end anonymous namespace TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context) @@ -105,17 +119,7 @@ TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context) // Checks if a tf_device.launch wraps a single operation and the single // operation results are perfectly forwarded to the launch return. 
-bool LaunchOp::WrapsSingleOp() { - auto body = GetBody().without_terminator(); - if (!has_single_element(body)) return false; - - Operation& wrapped_op = *body.begin(); - Operation* terminator = GetBody().getTerminator(); - return wrapped_op.getNumResults() == terminator->getNumOperands() && - std::equal(wrapped_op.getResults().begin(), - wrapped_op.getResults().end(), - terminator->getOperands().begin()); -} +bool LaunchOp::WrapsSingleOp() { return BlockWrapsSingleOp(&GetBody()); } //===----------------------------------------------------------------------===// // tf_device.return @@ -210,30 +214,32 @@ void ParallelExecuteOp::build(Builder* builder, OperationState& state, state.addTypes(output_types); } -std::vector ParallelExecuteOp::GetRegionOutputs( - unsigned region_index) { - int num_region_results = - GetRegionBlockWithIndex(region_index).getTerminator()->getNumResults(); - std::vector results; - results.reserve(num_region_results); - - int return_value_offset = 0; - for (int region_id = 0; region_id < region_index; ++region_id) - return_value_offset += - GetRegionBlockWithIndex(region_id).getTerminator()->getNumResults(); - - for (int i = 0; i < num_region_results; ++i) - results.emplace_back(getOperation()->getOpResult(return_value_offset + i)); - - return results; -} - LogicalResult ParallelExecuteOp::verify() { return Verify(*this); } Block& ParallelExecuteOp::GetRegionBlockWithIndex(unsigned index) { return getOperation()->getRegion(index).front(); } +Operation::result_range ParallelExecuteOp::GetRegionOutputs( + unsigned region_index) { + int num_region_results = + GetRegionBlockWithIndex(region_index).getTerminator()->getNumOperands(); + + int return_value_offset = 0; + for (int region_id = 0; region_id < region_index; ++region_id) + return_value_offset += + GetRegionBlockWithIndex(region_id).getTerminator()->getNumOperands(); + + Operation::result_range region_results(getOperation(), + /*startIndex=*/return_value_offset, + /*count=*/num_region_results); + return region_results; +} + +bool ParallelExecuteOp::RegionWrapsSingleOp(unsigned index) { + return BlockWrapsSingleOp(&GetRegionBlockWithIndex(index)); +} + //===----------------------------------------------------------------------===// // tf_device.replicate //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.h index 0cb26bbfe65..1b20120cc2e 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.h @@ -74,9 +74,14 @@ class ParallelExecuteOp static StringRef getOperationName() { return "tf_device.parallel_execute"; } - std::vector GetRegionOutputs(unsigned region_index); LogicalResult verify(); Block& GetRegionBlockWithIndex(unsigned index); + Operation::result_range GetRegionOutputs(unsigned region_index); + + // Checks if a tf_device.parallel_execute index'th region block wraps a single + // operation and the single operation results are perfectly forwarded to the + // region block's return. 
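// Aside (illustrative usage, not part of the patch): GetRegionOutputs now
// returns an Operation::result_range, so callers can iterate a region's
// results directly instead of building a std::vector, e.g.
//   for (Value output : parallel_execute.GetRegionOutputs(/*region_index=*/1))
//     region_outputs.push_back(output);
// where `parallel_execute` and `region_outputs` are hypothetical locals.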
+ bool RegionWrapsSingleOp(unsigned index); }; } // namespace tf_device diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td index 38f72f24bd1..3c47ef1117d 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td @@ -128,7 +128,7 @@ def TfExecutor_GraphOp : TfExecutor_Op<"graph", def TfExecutor_FetchOp : TfExecutor_Op<"fetch", [Terminator, ControlOperandsAfterAllData, HasParent<"GraphOp">]> { let summary = [{ - The `tf_executor.fetch` operation terminates the graph and returns values"; + The `tf_executor.fetch` operation terminates the graph and returns values; }]; let description = [{ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 440aeaa49dc..d2bbbd32b7c 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -319,6 +319,32 @@ this value or a subsequent newer value of the variable. TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<1>; } +def TF_Atan2Op : TF_Op<"Atan2", [NoSideEffect, ResultsBroadcastableShape]>, + WithBroadcastableBinOpBuilder { + let summary = [{ +Computes arctangent of `y/x` element-wise, respecting signs of the arguments. + }]; + + let description = [{ +This is the angle \( \theta \in [-\pi, \pi] \) such that +\[ x = r \cos(\theta) \] +and +\[ y = r \sin(\theta) \] +where \(r = \sqrt(x^2 + y^2) \). + }]; + + let arguments = (ins + TF_FpTensor:$y, + TF_FpTensor:$x + ); + + let results = (outs + TF_FpTensor:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_AvgPoolOp : TF_Op<"AvgPool", [NoSideEffect]> { let summary = "Performs average pooling on the input."; @@ -426,6 +452,10 @@ about broadcasting TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + let verifier = [{ + return Verify(*this); + }]; + let hasCanonicalizer = 1; } @@ -481,7 +511,7 @@ reverse of SpaceToBatch. See below for a precise description. TF_DerivedOperandTypeAttr Tblock_shape = TF_DerivedOperandTypeAttr<1>; } -def TF_BiasAddOp : TF_Op<"BiasAdd", [NoSideEffect, TF_LayoutSensitiveInterface]> { +def TF_BiasAddOp : TF_Op<"BiasAdd", [NoSideEffect]> { let summary = "Adds `bias` to `value`."; let description = [{ @@ -505,13 +535,6 @@ Broadcasting is supported, so `value` may have any number of dimensions. let verifier = [{ return Verify(*this); }]; - - let extraClassDeclaration = [{ - // TF_LayoutSensitiveInterface: - SmallVector GetLayoutDependentArgs() { return {0}; } - SmallVector GetLayoutDependentResults() { return {0}; } - LogicalResult UpdateDataFormat(StringRef data_format); - }]; } def TF_BiasAddGradOp : TF_Op<"BiasAddGrad", [NoSideEffect]> { @@ -1036,6 +1059,7 @@ horizontal and vertices strides, `strides = [1, stride, stride, 1]`. // TF_LayoutSensitiveInterface: SmallVector GetLayoutDependentArgs() { return {0}; } SmallVector GetLayoutDependentResults() { return {0}; } + StringRef GetOptimalLayout(const RuntimeDevices& devices); LogicalResult UpdateDataFormat(StringRef data_format); }]; } @@ -1184,6 +1208,56 @@ and `B, D, F, H` as group 1. 
Thus we get the outputs: TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_CumsumOp : TF_Op<"Cumsum", [NoSideEffect]> { + let summary = "Compute the cumulative sum of the tensor `x` along `axis`."; + + let description = [{ +By default, this op performs an inclusive cumsum, which means that the first +element of the input is identical to the first element of the output: + +```python +tf.cumsum([a, b, c]) # => [a, a + b, a + b + c] +``` + +By setting the `exclusive` kwarg to `True`, an exclusive cumsum is +performed instead: + +```python +tf.cumsum([a, b, c], exclusive=True) # => [0, a, a + b] +``` + +By setting the `reverse` kwarg to `True`, the cumsum is performed in the +opposite direction: + +```python +tf.cumsum([a, b, c], reverse=True) # => [a + b + c, b + c, c] +``` + +This is more efficient than using separate `tf.reverse` ops. + +The `reverse` and `exclusive` kwargs can also be combined: + +```python +tf.cumsum([a, b, c], exclusive=True, reverse=True) # => [b + c, c, 0] +``` + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x, + TF_I32OrI64Tensor:$axis, + + DefaultValuedAttr:$exclusive, + DefaultValuedAttr:$reverse + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$out + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; +} + def TF_DepthToSpaceOp : TF_Op<"DepthToSpace", [NoSideEffect]> { let summary = "DepthToSpace for tensors of type T."; @@ -4870,6 +4944,8 @@ operation. ); TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; + + let hasCanonicalizer = 1; } def TF_RealOp : TF_Op<"Real", [NoSideEffect, SameOperandsAndResultShape]> { @@ -5120,7 +5196,7 @@ Resize `images` to `size` using nearest neighbor interpolation. }]; let arguments = (ins - TensorOf<[F16, F32, F64, I16, I32, I64, I8, TF_Uint16, TF_Uint8]>:$images, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Uint16, TF_Uint8]>:$images, I32Tensor:$size, DefaultValuedAttr:$align_corners, @@ -5128,7 +5204,7 @@ Resize `images` to `size` using nearest neighbor interpolation. ); let results = (outs - TensorOf<[F16, F32, F64, I16, I32, I64, I8, TF_Uint16, TF_Uint8]>:$resized_images + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Uint16, TF_Uint8]>:$resized_images ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -6659,9 +6735,9 @@ receive 0, 0, and 1, respectively. The appropriate bits in `begin_mask` and // `begin_indices`, `end_indices`, and `strides` with their canonical // values, respectively. bool GetSlicedBoundRanges( - ::llvm::SmallVectorImpl *begin_indices, - ::llvm::SmallVectorImpl *end_indices, - ::llvm::SmallVectorImpl *strides); + ::llvm::SmallVectorImpl *slice_begin, + ::llvm::SmallVectorImpl *slice_end, + ::llvm::SmallVectorImpl *slice_stride); }]; } @@ -6708,10 +6784,10 @@ shape of `StridedSlice`'s `input`. // `begin_indices`, `end_indices`, and `strides` with their canonical // values, respectively. 
bool GetSlicedShapeAndBoundRanges( - ::llvm::SmallVectorImpl *shape, - ::llvm::SmallVectorImpl *begin_indices, - ::llvm::SmallVectorImpl *end_indices, - ::llvm::SmallVectorImpl *strides); + ::llvm::SmallVectorImpl *input_shape, + ::llvm::SmallVectorImpl *slice_begin, + ::llvm::SmallVectorImpl *slice_end, + ::llvm::SmallVectorImpl *slice_stride); }]; } @@ -7020,6 +7096,275 @@ is the corresponding input gradient. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_TensorArrayCloseV3Op : TF_Op<"TensorArrayCloseV3", []> { + let summary = "Delete the TensorArray from its resource container."; + + let description = [{ +This enables the user to close and release the resource in the middle +of a step/run. + }]; + + let arguments = (ins + TF_ResourceTensor:$handle + ); + + let results = (outs); +} + +def TF_TensorArrayConcatV3Op : TF_Op<"TensorArrayConcatV3", []> { + let summary = "Concat the elements from the TensorArray into value `value`."; + + let description = [{ +Takes `T` elements of shapes + + ``` + (n0 x d0 x d1 x ...), (n1 x d0 x d1 x ...), ..., (n(T-1) x d0 x d1 x ...) + ``` + +and concatenates them into a Tensor of shape: + + ```(n0 + n1 + ... + n(T-1) x d0 x d1 x ...)``` + +All elements must have the same shape (excepting the first dimension). + }]; + + let arguments = (ins + TF_ResourceTensor:$handle, + F32Tensor:$flow_in, + + DefaultValuedAttr:$element_shape_except0 + ); + + let results = (outs + TF_Tensor:$value, + I64Tensor:$lengths + ); + + TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; +} + +def TF_TensorArrayGatherV3Op : TF_Op<"TensorArrayGatherV3", []> { + let summary = [{ +Gather specific elements from the TensorArray into output `value`. + }]; + + let description = [{ +All elements selected by `indices` must have the same shape. + }]; + + let arguments = (ins + TF_ResourceTensor:$handle, + I32Tensor:$indices, + F32Tensor:$flow_in, + + DefaultValuedAttr:$element_shape + ); + + let results = (outs + TF_Tensor:$value + ); + + TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; +} + +def TF_TensorArrayGradV3Op : TF_Op<"TensorArrayGradV3", []> { + let summary = [{ +Creates a TensorArray for storing the gradients of values in the given handle. + }]; + + let description = [{ +If the given TensorArray gradient already exists, returns a reference to it. + +Locks the size of the original TensorArray by disabling its dynamic size flag. + +**A note about the input flow_in:** + +The handle flow_in forces the execution of the gradient lookup to occur +only after certain other operations have occurred. For example, when +the forward TensorArray is dynamically sized, writes to this TensorArray +may resize the object. The gradient TensorArray is statically sized based +on the size of the forward TensorArray when this operation executes. +Furthermore, the size of the forward TensorArray is frozen by this call. +As a result, the flow is used to ensure that the call to generate the gradient +TensorArray only happens after all writes are executed. + +In the case of dynamically sized TensorArrays, gradient computation should +only be performed on read operations that have themselves been chained via +flow to occur only after all writes have executed. That way the final size +of the forward TensorArray is known when this operation is called. + +**A note about the source attribute:** + +TensorArray gradient calls use an accumulator TensorArray object. 
If +multiple gradients are calculated and run in the same session, the multiple +gradient nodes may accidentally flow through the same accumulator TensorArray. +This double counts and generally breaks the TensorArray gradient flow. + +The solution is to identify which gradient call this particular +TensorArray gradient is being called in. This is performed by identifying +a unique string (e.g. "gradients", "gradients_1", ...) from the input +gradient Tensor's name. This string is used as a suffix when creating +the TensorArray gradient object here (the attribute `source`). + +The attribute `source` is added as a suffix to the forward TensorArray's +name when performing the creation / lookup, so that each separate gradient +calculation gets its own TensorArray accumulator. + }]; + + let arguments = (ins + TF_ResourceTensor:$handle, + F32Tensor:$flow_in, + + StrAttr:$source + ); + + let results = (outs + TF_ResourceTensor:$grad_handle, + F32Tensor:$flow_out + ); +} + +def TF_TensorArrayReadV3Op : TF_Op<"TensorArrayReadV3", []> { + let summary = "Read an element from the TensorArray into output `value`."; + + let description = [{ + }]; + + let arguments = (ins + TF_ResourceTensor:$handle, + I32Tensor:$index, + F32Tensor:$flow_in + ); + + let results = (outs + TF_Tensor:$value + ); + + TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; +} + +def TF_TensorArrayScatterV3Op : TF_Op<"TensorArrayScatterV3", []> { + let summary = [{ +Scatter the data from the input value into specific TensorArray elements. + }]; + + let description = [{ +`indices` must be a vector, its length must match the first dim of `value`. + }]; + + let arguments = (ins + TF_ResourceTensor:$handle, + I32Tensor:$indices, + TF_Tensor:$value, + F32Tensor:$flow_in + ); + + let results = (outs + F32Tensor:$flow_out + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>; +} + +def TF_TensorArraySizeV3Op : TF_Op<"TensorArraySizeV3", []> { + let summary = "Get the current size of the TensorArray."; + + let description = [{ + }]; + + let arguments = (ins + TF_ResourceTensor:$handle, + F32Tensor:$flow_in + ); + + let results = (outs + I32Tensor:$size + ); +} + +def TF_TensorArraySplitV3Op : TF_Op<"TensorArraySplitV3", []> { + let summary = [{ +Split the data from the input value into TensorArray elements. + }]; + + let description = [{ +Assuming that `lengths` takes on values + + ```(n0, n1, ..., n(T-1))``` + +and that `value` has shape + + ```(n0 + n1 + ... + n(T-1) x d0 x d1 x ...)```, + +this splits values into a TensorArray with T tensors. + +TensorArray index t will be the subtensor of values with starting position + + ```(n0 + n1 + ... + n(t-1), 0, 0, ...)``` + +and having size + + ```nt x d0 x d1 x ...``` + }]; + + let arguments = (ins + TF_ResourceTensor:$handle, + TF_Tensor:$value, + I64Tensor:$lengths, + F32Tensor:$flow_in + ); + + let results = (outs + F32Tensor:$flow_out + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; +} + +def TF_TensorArrayV3Op : TF_Op<"TensorArrayV3", []> { + let summary = "An array of Tensors of given size."; + + let description = [{ +Write data via Write and read via Read or Pack. 
+ }]; + + let arguments = (ins + I32Tensor:$size, + + TypeAttr:$dtype, + DefaultValuedAttr:$element_shape, + DefaultValuedAttr:$dynamic_size, + DefaultValuedAttr:$clear_after_read, + DefaultValuedAttr:$identical_element_shapes, + StrAttr:$tensor_array_name + ); + + let results = (outs + TF_ResourceTensor:$handle, + F32Tensor:$flow + ); +} + +def TF_TensorArrayWriteV3Op : TF_Op<"TensorArrayWriteV3", []> { + let summary = "Push an element onto the tensor_array."; + + let description = [{ + }]; + + let arguments = (ins + TF_ResourceTensor:$handle, + I32Tensor:$index, + TF_Tensor:$value, + F32Tensor:$flow_in + ); + + let results = (outs + F32Tensor:$flow_out + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>; +} + def TF_TensorListConcatV2Op : TF_Op<"TensorListConcatV2", [NoSideEffect]> { let summary = "Concats all tensors in the list along the 0th dimension."; @@ -7052,6 +7397,27 @@ lengths: Output tensor containing sizes of the 0th dimension of tensors in the l TF_DerivedResultTypeAttr element_dtype = TF_DerivedResultTypeAttr<0>; } +def TF_TensorListElementShapeOp : TF_Op<"TensorListElementShape", [NoSideEffect]> { + let summary = "The shape of the elements of the given list, as a tensor."; + + let description = [{ +input_handle: the list + element_shape: the shape of elements of the list + }]; + + let arguments = (ins + TF_VariantTensor:$input_handle + ); + + let results = (outs + TF_I32OrI64Tensor:$element_shape + ); + + TF_DerivedResultTypeAttr shape_type = TF_DerivedResultTypeAttr<0>; + + let hasFolder = 1; +} + def TF_TensorListFromTensorOp : TF_Op<"TensorListFromTensor", [NoSideEffect]> { let summary = [{ Creates a TensorList which, when stacked, has the value of `tensor`. @@ -7820,6 +8186,8 @@ shape(t) ==> [2, 2, 3] let verifier = [{ return Verify(*this); }]; + + let hasFolder = 1; } def TF_WhereOp : TF_Op<"Where", [NoSideEffect]> { @@ -8013,8 +8381,9 @@ used to look up the program in the compilation cache. let results = (outs TF_StrTensor:$compilation_status, - TF_StrTensor:$program + Variadic:$program ); + TF_DerivedResultSizeAttr num_computations = TF_DerivedResultSizeAttr<1>; TF_DerivedOperandSizeAttr NumDynamicShapes = TF_DerivedOperandSizeAttr<0>; } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td index 92e6d522125..773025c58df 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td @@ -23,6 +23,7 @@ limitations under the License. 
#define TF_OP_BASE include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffects.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td" //===----------------------------------------------------------------------===// @@ -86,6 +87,7 @@ class TF_TensorFlowType : // Any tensor element type allowed in TensorFlow ops def TF_ElementType : Type, "tf.dtype">; @@ -96,20 +98,20 @@ def TF_Tensor : TensorOf<[TF_ElementType]>; //===----------------------------------------------------------------------===// // Integer types -def TF_I32Or64 : IntOfWidths<[32, 64]>; +def TF_I32Or64 : SignlessIntOfWidths<[32, 64]>; def TF_I32OrI64Tensor : TensorOf<[TF_I32Or64]>; -def TF_Uint8 : TF_TensorFlowType<"Uint8", "uint8">; -def TF_Uint16 : TF_TensorFlowType<"Uint16", "uint16">; -def TF_Uint32 : TF_TensorFlowType<"Uint32", "uint32">; -def TF_Uint64 : TF_TensorFlowType<"Uint64", "uint64">; +def TF_Uint8 : UI<8>; +def TF_Uint16 : UI<16>; +def TF_Uint32 : UI<32>; +def TF_Uint64 : UI<64>; // Any unsigned integer type -def TF_UInt : AnyTypeOf<[TF_Uint8, TF_Uint16, TF_Uint32, TF_Uint64]>; +def TF_UInt : UnsignedIntOfWidths<[8, 16, 32, 64]>; // Any signed integer type -def TF_SInt : IntOfWidths<[8, 16, 32, 64]>; +def TF_SInt : SignlessIntOfWidths<[8, 16, 32, 64]>; // Any integer type def TF_Int : AnyTypeOf<[TF_SInt, TF_UInt]>; @@ -192,6 +194,16 @@ def TF_NumberOrStrTensor : TensorOf<[TF_NumberOrStr]>; // TensorFlow attribute definitions //===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// Tensorflow devices metadata + +// Tensorflow GPU device metadata. +def TF_GpuDeviceMetadata : StructAttr<"GpuDeviceMetadata", TF_Dialect, [ + // GPU device compute capability: major:minor. + StructFieldAttr<"cc_major", I32Attr>, + StructFieldAttr<"cc_minor", I32Attr> +]>; + //===----------------------------------------------------------------------===// // String attribute constraints diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td index cc0819d71c9..3743bdda043 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td @@ -50,11 +50,14 @@ def TF_LayoutSensitiveInterface : OpInterface<"LayoutSensitiveInterface"> { [{Returns indices of layout dependent results.}], "SmallVector", "GetLayoutDependentResults", (ins) >, + InterfaceMethod< + [{Returns the optimal data layout based on the available devices.}], + "StringRef", "GetOptimalLayout", (ins "const RuntimeDevices&":$devices) + >, InterfaceMethod< [{Updates operation attributes and operands to account for the updated data format. If data format is not supported, must return failure.}], - "LogicalResult", "UpdateDataFormat", - (ins "StringRef":$data_format) + "LogicalResult", "UpdateDataFormat", (ins "StringRef":$data_format) >, ]; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index 8d4c284bcf8..9cec3641d0a 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -57,6 +57,7 @@ limitations under the License. 
#include "mlir/Support/LogicalResult.h" // TF:llvm-project #include "mlir/Support/STLExtras.h" // TF:llvm-project #include "mlir/Transforms/InliningUtils.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/tensor_format.h" @@ -292,6 +293,51 @@ static LogicalResult VerifyTypesCompatibility( return success(); } +//===----------------------------------------------------------------------===// +// Helper functions detect device capabilities from RuntimeDevices. +//===----------------------------------------------------------------------===// + +namespace { +using DeviceNameUtils = ::tensorflow::DeviceNameUtils; +using ParsedName = ::tensorflow::DeviceNameUtils::ParsedName; + +bool IsGpuDevice(const DeviceNameUtils::ParsedName &device) { + return device.type == ::tensorflow::DEVICE_GPU; +} + +} // namespace + +// Returns true if at least one GPU device is available at runtime. +bool CanUseGpuDevice(const RuntimeDevices &devices) { + return llvm::any_of(devices.device_names(), IsGpuDevice); +} + +// Returns true if all of the GPUs available at runtime support TensorCores +// (NVIDIA compute capability >= 7.0). +bool CanUseTensorCores(const RuntimeDevices &devices) { + auto has_tensor_cores = [&](const DeviceNameUtils::ParsedName &device) { + auto md = devices.GetGpuDeviceMetadata(device); + return md ? md->cc_major().getInt() >= 7 : false; + }; + return llvm::all_of( + llvm::make_filter_range(devices.device_names(), IsGpuDevice), + has_tensor_cores); +} + +// Returns true if operation does not have explicit device placement that would +// prevent it from running on GPU device. +bool CanUseGpuDevice(Operation *op) { + auto device_attr = op->getAttrOfType("device"); + if (!device_attr || device_attr.getValue().empty()) return true; + + DeviceNameUtils::ParsedName device; + if (!DeviceNameUtils::ParseFullName(device_attr.getValue().str(), &device)) + return false; + + // We can't use GPU if operation explicitly placed on non-GPU device. + return !device.has_type || device.type == ::tensorflow::DEVICE_GPU; +} + //===----------------------------------------------------------------------===// // TF op helper functions to work with layout transformation. //===----------------------------------------------------------------------===// @@ -566,6 +612,16 @@ void BatchMatMulOp::getCanonicalizationPatterns( // BatchMatMulV2Op //===----------------------------------------------------------------------===// +static LogicalResult Verify(BatchMatMulV2Op op) { + if (!HasRankAtLeast(op.x(), 2)) { + return op.emitOpError("requires lhs operand to have rank at least two"); + } + if (!HasRankAtLeast(op.y(), 2)) { + return op.emitOpError("requires rhs operand to have rank at least two"); + } + return success(); +} + void BatchMatMulV2Op::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); @@ -617,15 +673,6 @@ static LogicalResult Verify(BiasAddOp op) { return success(); } -// TODO(ezhulenev): BiasAddOp is not really layout sensitive, it must only -// support folding operand transposes. 
-LogicalResult BiasAddOp::UpdateDataFormat(StringRef data_format) {
-  auto ranked = value().getType().dyn_cast();
-  if (!ranked || ranked.getRank() != 4) return failure();
-
-  return ::mlir::TF::UpdateDataFormat(data_format, this);
-}
-
 //===----------------------------------------------------------------------===//
 // BiasAddGradOp
 //===----------------------------------------------------------------------===//
@@ -999,6 +1046,59 @@ LogicalResult Conv2DOp::UpdateDataFormat(StringRef data_format) { return success(); }
+StringRef Conv2DOp::GetOptimalLayout(const RuntimeDevices &devices) {
+  // Keep the current data format if no GPUs are available or if explicit
+  // placement does not allow using the GPU for this operation.
+  if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
+    return data_format();
+
+  // Input must be a tensor.
+  auto input_ty = input().getType().dyn_cast();
+  if (!input_ty) return data_format();
+
+  // For the f16 data type on devices with Tensor Core support, the NHWC data
+  // format is up to ~2x faster.
+  const bool is_f16 = input_ty.getElementType().isF16();
+  if (is_f16 && CanUseTensorCores(devices)) return "NHWC";
+
+  // For f32/f16 data types the decision depends on the filter size in the
+  // spatial dimensions; for other data types we keep the current data format.
+  if (!input_ty.getElementType().isF32() && !input_ty.getElementType().isF16())
+    return data_format();
+
+  // Keep the current data format if the filter rank is unknown or not equal to 4.
+  auto filter_ty = filter().getType().dyn_cast();
+  if (!filter_ty || filter_ty.getRank() != 4) return data_format();
+
+  const int64_t d0 = filter_ty.getDimSize(0);
+  const int64_t d1 = filter_ty.getDimSize(1);
+
+  auto all_ones = [](ArrayAttr arr) -> bool {
+    return llvm::all_of(arr, [](Attribute attr) -> bool {
+      return attr.cast().getInt() == 1;
+    });
+  };
+
+  // Convolutions with a 1x1 filter and with strides and dilations all ones can
+  // be computed as a GEMM in NHWC data format, and can be up to ~2x faster
+  // than a convolution in NCHW.
+  const bool one_by_one = d0 == 1 && d1 == 1;
+  const bool trivial_strides = all_ones(strides());
+  const bool trivial_dilations = all_ones(dilations());
+
+  // TODO(ezhulenev): This might lead to excessive transposes in the final IR,
+  // if the ratio of 1x1 convolutions to regular convolutions is close to 1:1.
+  // Also FusedBatchNorm in training mode prefers NCHW data format. Check if all
+  // users can efficiently use the NHWC data format?
+  if (one_by_one && trivial_strides && trivial_dilations) {
+    return "NHWC";
+  }
+
+  // If the filter spatial dimensions are unknown or not 1x1 we prefer NCHW,
+  // because it's the fastest option on NVIDIA GPUs with cuDNN library support.
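+
+  // Decision summary (descriptive only, mirrors the checks above):
+  //   - f16 input and every available GPU has Tensor Cores      -> "NHWC"
+  //   - f32/f16 input, 1x1 filter, unit strides and dilations   -> "NHWC"
+  //   - no usable GPU, other dtypes, or unknown filter rank     -> keep current
+  //   - otherwise                                               -> "NCHW"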
+ return "NCHW"; +} + //===----------------------------------------------------------------------===// // Conv2dBackpropInputOp //===----------------------------------------------------------------------===// @@ -1495,6 +1595,15 @@ void LogOp::getCanonicalizationPatterns(OwningRewritePatternList &results, results.insert(context); } +//===----------------------------------------------------------------------===// +// ReadVariableOp +//===----------------------------------------------------------------------===// + +void ReadVariableOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // LogicalNotOp //===----------------------------------------------------------------------===// @@ -2612,9 +2721,8 @@ constexpr void CopyBit(const T &src, unsigned src_index, T &dst, // dimensions. For example, sparse spec for foo[..., 3:10] for foo of shape (2, // 4, 8) would have dims = 2. struct SparseSliceSpec { - const int64_t dims; - const uint64_t begin_mask, end_mask, ellipsis_mask, new_axis_mask, - shrink_axis_mask; + int64_t dims; + int32_t begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask; const ArrayRef &begin; const ArrayRef &end; const ArrayRef &strides; @@ -2625,7 +2733,7 @@ struct SparseSliceSpec { // in operand tensor. struct DenseSliceSpec { int64_t dims; - uint64_t begin_mask, end_mask, shrink_axis_mask; + int32_t begin_mask, end_mask, shrink_axis_mask; SmallVectorImpl &begin; SmallVectorImpl &end; SmallVectorImpl &strides; @@ -2638,8 +2746,8 @@ struct DenseSliceSpec { // For example suppose foo[...,3:, 2] on foo.shape=(2,2,3,4) then // we need to produce the missing begin_mask, end_mask for the first two // dimensions i.e. foo[:, :, 3:, 2]. -static LogicalResult BuildDenseSliceSpec(const SparseSliceSpec &sparse, - DenseSliceSpec *dense) { +static void BuildDenseSliceSpec(const SparseSliceSpec &sparse, + DenseSliceSpec *dense) { // Build expanded dense begin, end, strides, begin_mask, end_mask, and // shrink_axis_mask. dense->begin.resize(dense->dims); @@ -2689,7 +2797,6 @@ static LogicalResult BuildDenseSliceSpec(const SparseSliceSpec &sparse, dense_index); dense_index++; } - return success(); } // For the given `input_shape`, calculates the sliced shape using the given @@ -2699,7 +2806,7 @@ static LogicalResult BuildDenseSliceSpec(const SparseSliceSpec &sparse, // dimensions in `input_shape`; it will turn them into 1s. At the same time, // canonicalizes `begin`, `end`, and `strides. The calculation follows // tf.StridedSlice op semantics. -static void CalculateSlicedShapeAndBoundRanges( +static void CalculateSlicedShapeFromDenseIndices( MutableArrayRef input_shape, int32_t begin_mask, int32_t end_mask, int32_t shrink_axis_mask, MutableArrayRef begin, MutableArrayRef end, MutableArrayRef stride) { @@ -2759,21 +2866,59 @@ static void CalculateSlicedShapeAndBoundRanges( } } +// For the given `input_shape`, calculates the sliced shape using the given +// `sparse_begin`, `sparse_end`, and `sparse_strides` ranges and `begin_mask`, +// `end_mask`, `ellipsis_mask` , `new_axis_mask` and `shrink_axis_mask` masks. +// Updates the result back to `input_shape`. 
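+// Worked example (descriptive only): for `input_shape` = [4, 8] and the slice
+// foo[2:4], the single sparse dimension gets an implicit trailing ellipsis,
+// which effectively densifies to begin = [2, 0], end = [4, 8],
+// strides = [1, 1], and updates `input_shape` to [2, 8].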
+static void CalculateSlicedShapeFromSparseIndices( + MutableArrayRef input_shape, ArrayRef sparse_begin, + ArrayRef sparse_end, ArrayRef sparse_strides, + int32_t begin_mask, int32_t end_mask, int32_t ellipsis_mask, + int32_t new_axis_mask, int32_t shrink_axis_mask, + SmallVectorImpl *begin, SmallVectorImpl *end, + SmallVectorImpl *stride) { + int64_t num_sparse_indices = sparse_begin.size(); + SparseSliceSpec sparse = {num_sparse_indices, begin_mask, end_mask, + ellipsis_mask, new_axis_mask, shrink_axis_mask, + sparse_begin, sparse_end, sparse_strides}; + + // If no ellipsis_mask exists then an implicit ellipsis_mask at the end is + // inserted. This handles cases where foo[2:4] (foo.shape() = [4, 8]) yields + // a tensor of shape [2, 8], i.e., foo[2:4] is same as foo[2:4, ...]. + if (sparse.ellipsis_mask == 0) { + Set(sparse.ellipsis_mask, sparse.dims); + sparse.dims++; + } + + int64_t dims = input_shape.size(); + DenseSliceSpec dense = {dims, + /*begin_mask = */ 0, + /*end_mask = */ 0, + /*shrink_axis_mask = */ 0, + *begin, + *end, + *stride}; + + BuildDenseSliceSpec(sparse, &dense); + CalculateSlicedShapeFromDenseIndices(input_shape, dense.begin_mask, + dense.end_mask, dense.shrink_axis_mask, + *begin, *end, *stride); +} + bool StridedSliceOp::GetSlicedBoundRanges( - SmallVectorImpl *begin_indices, - SmallVectorImpl *end_indices, SmallVectorImpl *strides) { + SmallVectorImpl *slice_begin, SmallVectorImpl *slice_end, + SmallVectorImpl *slice_stride) { // TODO(hinsu): Support lowering for ops with dynamic begin and end values // when it is possible to derive indices based on mask attributes. DenseIntElementsAttr sparse_begin_attr, sparse_end_attr, sparse_strides_attr; - if (!matchPattern(this->begin(), m_Constant(&sparse_begin_attr)) || - !matchPattern(this->end(), m_Constant(&sparse_end_attr)) || - !matchPattern(this->strides(), m_Constant(&sparse_strides_attr))) + if (!matchPattern(begin(), m_Constant(&sparse_begin_attr)) || + !matchPattern(end(), m_Constant(&sparse_end_attr)) || + !matchPattern(strides(), m_Constant(&sparse_strides_attr))) return false; auto input_ty = this->input().getType().dyn_cast(); if (!input_ty || !input_ty.hasStaticShape()) return false; auto input_shape = llvm::to_vector<4>(input_ty.getShape()); - int rank = input_shape.size(); SmallVector sparse_begin, sparse_end, sparse_strides; @@ -2784,30 +2929,11 @@ bool StridedSliceOp::GetSlicedBoundRanges( for (const APInt &stride : sparse_strides_attr) sparse_strides.push_back(stride.getSExtValue()); - auto num_sparse_indices = sparse_begin_attr.getNumElements(); - SparseSliceSpec sparse = {num_sparse_indices, - this->begin_mask().getZExtValue(), - this->end_mask().getZExtValue(), - this->ellipsis_mask().getZExtValue(), - this->new_axis_mask().getZExtValue(), - this->shrink_axis_mask().getZExtValue(), - sparse_begin, - sparse_end, - sparse_strides}; - - DenseSliceSpec dense = {rank, - /*begin_mask = */ 0, - /*end_mask = */ 0, - /*shrink_axis_mask = */ 0, - *begin_indices, - *end_indices, - *strides}; - - if (failed(BuildDenseSliceSpec(sparse, &dense))) return false; - - CalculateSlicedShapeAndBoundRanges(input_shape, dense.begin_mask, - dense.end_mask, dense.shrink_axis_mask, - *begin_indices, *end_indices, *strides); + CalculateSlicedShapeFromSparseIndices( + input_shape, sparse_begin, sparse_end, sparse_strides, + begin_mask().getZExtValue(), end_mask().getZExtValue(), + ellipsis_mask().getZExtValue(), new_axis_mask().getZExtValue(), + shrink_axis_mask().getZExtValue(), slice_begin, slice_end, slice_stride); return 
true; } @@ -2830,44 +2956,38 @@ static LogicalResult Verify(StridedSliceGradOp op) { } bool StridedSliceGradOp::GetSlicedShapeAndBoundRanges( - SmallVectorImpl *shape, SmallVectorImpl *begin_indices, - SmallVectorImpl *end_indices, SmallVectorImpl *strides) { - if (this->ellipsis_mask().getZExtValue() || - this->new_axis_mask().getZExtValue() || - this->shrink_axis_mask().getZExtValue()) - return false; // TODO(b/146512589): support these masks - + SmallVectorImpl *input_shape, + SmallVectorImpl *slice_begin, SmallVectorImpl *slice_end, + SmallVectorImpl *slice_stride) { DenseIntElementsAttr shape_attr; - DenseIntElementsAttr begin_indices_attr, end_indices_attr, strides_attr; - if (!matchPattern(this->shape(), m_Constant(&shape_attr)) || - !matchPattern(this->begin(), m_Constant(&begin_indices_attr)) || - !matchPattern(this->end(), m_Constant(&end_indices_attr)) || - !matchPattern(this->strides(), m_Constant(&strides_attr))) + DenseIntElementsAttr sparse_begin_attr, sparse_end_attr, sparse_strides_attr; + if (!matchPattern(shape(), m_Constant(&shape_attr)) || + !matchPattern(begin(), m_Constant(&sparse_begin_attr)) || + !matchPattern(end(), m_Constant(&sparse_end_attr)) || + !matchPattern(strides(), m_Constant(&sparse_strides_attr))) return false; int rank = std::distance(shape_attr.begin(), shape_attr.end()); - shape->clear(); - shape->reserve(rank); - begin_indices->clear(); - begin_indices->reserve(rank); - end_indices->clear(); - end_indices->reserve(rank); - strides->clear(); - strides->reserve(rank); + input_shape->clear(); + input_shape->reserve(rank); + for (const APInt &dim : shape_attr) + input_shape->push_back(dim.getSExtValue()); - for (const APInt &dim : shape_attr) shape->push_back(dim.getSExtValue()); - for (const APInt &index : begin_indices_attr) - begin_indices->push_back(index.getSExtValue()); - for (const APInt &index : end_indices_attr) - end_indices->push_back(index.getSExtValue()); - for (const APInt &stride : strides_attr) - strides->push_back(stride.getSExtValue()); + SmallVector sparse_begin, sparse_end, sparse_strides; - CalculateSlicedShapeAndBoundRanges(*shape, this->begin_mask().getZExtValue(), - this->end_mask().getZExtValue(), - this->shrink_axis_mask().getZExtValue(), - *begin_indices, *end_indices, *strides); + for (const APInt &index : sparse_begin_attr) + sparse_begin.push_back(index.getSExtValue()); + for (const APInt &index : sparse_end_attr) + sparse_end.push_back(index.getSExtValue()); + for (const APInt &stride : sparse_strides_attr) + sparse_strides.push_back(stride.getSExtValue()); + + CalculateSlicedShapeFromSparseIndices( + *input_shape, sparse_begin, sparse_end, sparse_strides, + begin_mask().getZExtValue(), end_mask().getZExtValue(), + ellipsis_mask().getZExtValue(), new_axis_mask().getZExtValue(), + shrink_axis_mask().getZExtValue(), slice_begin, slice_end, slice_stride); return true; } @@ -2887,6 +3007,19 @@ static LogicalResult Verify(TensorListReserveOp op) { return success(); } +//===----------------------------------------------------------------------===// +// TensorListElementShapeOp +//===----------------------------------------------------------------------===// + +OpFoldResult TensorListElementShapeOp::fold(ArrayRef operands) { + int width = + getType().cast().getElementType().getIntOrFloatBitWidth(); + auto variant_type = + getElementTypeOrSelf(getOperand().getType()).cast(); + if (variant_type.getSubtypes().empty()) return {}; + return ConvertShapeToAttr(variant_type.getSubtypes()[0], width); +} + 
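+// The fold above reads the element shape from the first variant subtype. For
+// example (descriptive only), an operand whose subtype is tensor<2x4xf32>
+// folds to the constant dense<[2, 4]> : tensor<2xi32>, as exercised by the
+// constant-fold.mlir test added in this change.
+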
//===----------------------------------------------------------------------===// // TensorListStackOp //===----------------------------------------------------------------------===// @@ -3157,6 +3290,15 @@ static LogicalResult Verify(VariableShapeOp op) { } } +OpFoldResult VariableShapeOp::fold(ArrayRef operands) { + int width = + getType().cast().getElementType().getIntOrFloatBitWidth(); + auto resource_type = + getElementTypeOrSelf(getOperand().getType()).cast(); + if (resource_type.getSubtypes().empty()) return {}; + return ConvertShapeToAttr(resource_type.getSubtypes()[0], width); +} + //===----------------------------------------------------------------------===// // WhileOp //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h index 9c80c43042b..fbd1a335be1 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h @@ -19,7 +19,6 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_H_ -#include "mlir/Analysis/CallInterfaces.h" // TF:llvm-project #include "mlir/Dialect/Traits.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project @@ -29,6 +28,10 @@ limitations under the License. #include "mlir/IR/OpImplementation.h" // TF:llvm-project #include "mlir/IR/StandardTypes.h" // TF:llvm-project #include "mlir/IR/TypeUtilities.h" // TF:llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // TF:llvm-project +#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // TF:llvm-project +#include "mlir/Interfaces/SideEffects.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h" diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index e95fcbbdad3..c1c6a643ef1 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -28,7 +28,7 @@ limitations under the License. #define TF_OPS include "tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td" -include "mlir/Analysis/CallInterfaces.td" +include "mlir/Interfaces/CallInterfaces.td" include "mlir/IR/OpBase.td" class TF_TensorListInitOp : TF_Op { @@ -64,7 +64,7 @@ class TF_TensorListInitOp : TF_Op { // In MLIR, the TensorFlow tensor value is represented as an ElementsAttr, with // its type encoding the tensor's shape and data type. -def TF_ConstOp : TF_Op<"Const", [NoSideEffect]> { +def TF_ConstOp : TF_Op<"Const", [ConstantLike, NoSideEffect]> { let summary = "Constant tensor op"; let arguments = (ins @@ -550,4 +550,44 @@ Example: TF_DerivedOperandOrResultHandleShapeAttr shape = TF_DerivedOperandOrResultHandleShapeAttr<"resource">; } + +// Not generated because it begins with an underscore, which isn't allowed by +// the C++ standard. +def TF_FusedBatchNormExOp : TF_Op<"_FusedBatchNormEx", [NoSideEffect]> { + let summary = "Internal FusedBatchNorm operation: reserved for internal use"; + + let description = [{ + Do not invoke this operator directly in Python. A fusion optimization is + expected to create these operators. 
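+
+  For example, the tf-gpu-op-fusion pass (exercised by the gpu_fusion.mlir
+  test added in this change) rewrites an inference-mode FusedBatchNormV3
+  followed by an optional AddV2 side input and a Relu activation into a
+  single _FusedBatchNormEx.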
+ }]; + + let arguments = (ins + TensorOf<[F16, F32]>:$x, + F32Tensor:$scale, + F32Tensor:$offset, + F32Tensor:$mean, + F32Tensor:$variance, + Variadic>:$side_input, + + DefaultValuedAttr:$epsilon, + DefaultValuedAttr:$exponential_avg_factor, + DefaultValuedAttr:$activation_mode, + DefaultValuedAttr:$data_format, + DefaultValuedAttr:$is_training + ); + + let results = (outs + TensorOf<[F16, F32]>:$y, + F32Tensor:$batch_mean, + F32Tensor:$batch_variance, + F32Tensor:$reserve_space_1, + F32Tensor:$reserve_space_2, + F32Tensor:$reserve_space_3 + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandSizeAttr num_side_inputs = TF_DerivedOperandSizeAttr<5>; +} + #endif // TF_OPS diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.cc new file mode 100644 index 00000000000..6c5485c16dd --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.cc @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" + +namespace mlir { + +// NOLINTNEXTLINE +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.cc.inc" + +namespace TF { + +void RuntimeDevices::AddDevice(const ParsedName& device) { + device_names_.push_back(device); +} + +void RuntimeDevices::AddGpuDevice(const ParsedName& device, + const GpuDeviceMetadata& metadata) { + device_names_.push_back(device); + gpu_metadata_.insert({DeviceNameUtils::ParsedNameToString(device), metadata}); +} + +llvm::Optional RuntimeDevices::GetGpuDeviceMetadata( + const ParsedName& device) const { + auto it = gpu_metadata_.find(DeviceNameUtils::ParsedNameToString(device)); + if (it != gpu_metadata_.end()) { + return it->second; + } else { + return llvm::None; + } +} + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h new file mode 100644 index 00000000000..65887a0c960 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h @@ -0,0 +1,67 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// This file defines the types used in the standard MLIR TensorFlow dialect. + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_STRUCTS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_STRUCTS_H_ + +#include "llvm/ADT/StringMap.h" +#include "mlir/IR/Diagnostics.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project +#include "tensorflow/core/util/device_name_utils.h" + +namespace mlir { + +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h.inc" + +namespace TF { + +// Tensorflow devices available at runtime with corresponding metadata if it is +// available. It's completely valid to have a device without any metadata +// attached to it. +class RuntimeDevices { + using DeviceNameUtils = ::tensorflow::DeviceNameUtils; + using ParsedName = ::tensorflow::DeviceNameUtils::ParsedName; + + public: + // Adds a device with and empty metadata. Device can be of any type. + void AddDevice(const ParsedName& device); + + // Adds a GPU device with GPU specific metadata. + void AddGpuDevice(const ParsedName& device, + const GpuDeviceMetadata& metadata); + + llvm::ArrayRef device_names() const { return device_names_; } + size_t NumDevices() const { return device_names_.size(); } + + // Returns GPU device metadata if it is available, otherwise returns None. + llvm::Optional GetGpuDeviceMetadata( + const ParsedName& device) const; + + private: + llvm::SmallVector device_names_; + // TODO(ezhulenev): Add DenseMapInfo specialization to be able to + // use ParsedName as a key in a DenseMap. + llvm::StringMap gpu_metadata_; +}; + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_STRUCTS_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc index a3bba731581..ef97b234ef7 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc @@ -77,13 +77,17 @@ TensorFlowType TensorFlowRefType::get(Type type) { case 1: return BoolRefType::get(ctx); case 8: - return Int8RefType::get(ctx); + return itype.isUnsigned() ? TensorFlowType(Uint8RefType::get(ctx)) + : Int8RefType::get(ctx); case 16: - return Int16RefType::get(ctx); + return itype.isUnsigned() ? TensorFlowType(Uint16RefType::get(ctx)) + : Int16RefType::get(ctx); case 32: - return Int32RefType::get(ctx); + return itype.isUnsigned() ? TensorFlowType(Uint32RefType::get(ctx)) + : Int32RefType::get(ctx); case 64: - return Int64RefType::get(ctx); + return itype.isUnsigned() ? 
TensorFlowType(Uint64RefType::get(ctx)) + : Int64RefType::get(ctx); default: llvm_unreachable("unexpected integer type"); } @@ -121,6 +125,14 @@ Type TensorFlowRefType::RemoveRef() { return mlir::IntegerType::get(32, ctx); case TensorFlowTypes::INT64_REF: return mlir::IntegerType::get(64, ctx); + case TensorFlowTypes::UINT8_REF: + return mlir::IntegerType::get(8, IntegerType::Unsigned, ctx); + case TensorFlowTypes::UINT16_REF: + return mlir::IntegerType::get(16, IntegerType::Unsigned, ctx); + case TensorFlowTypes::UINT32_REF: + return mlir::IntegerType::get(32, IntegerType::Unsigned, ctx); + case TensorFlowTypes::UINT64_REF: + return mlir::IntegerType::get(64, IntegerType::Unsigned, ctx); case TensorFlowTypes::COMPLEX64_REF: return mlir::ComplexType::get(mlir::FloatType::getF32(ctx)); case TensorFlowTypes::COMPLEX128_REF: diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.def b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.def index 0f5f7c17e02..a097a3cad88 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.def +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.def @@ -19,10 +19,6 @@ limitations under the License. #ifdef HANDLE_TF_TYPE // class, enumerant, name -HANDLE_TF_TYPE(Uint8, UINT8, "uint8") -HANDLE_TF_TYPE(Uint16, UINT16, "uint16") -HANDLE_TF_TYPE(Uint32, UINT32, "uint32") -HANDLE_TF_TYPE(Uint64, UINT64, "uint64") HANDLE_TF_TYPE(Qint8, QINT8, "qint8") HANDLE_TF_TYPE(Qint16, QINT16, "qint16") HANDLE_TF_TYPE(Qint32, QINT32, "qint32") diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h index 4059aba209f..2898338f8eb 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h @@ -91,7 +91,7 @@ class TensorFlowType : public Type { // Returns true if the specified type is a valid TensorFlow element type. static inline bool IsValidTFElementType(Type type) { return type.isa() || type.isa() || - type.isSignlessInteger() || type.isa(); + type.isa() || type.isa(); } // Returns true if this is a valid TensorFlow tensor type. 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/batchmatmul_to_einsum.mlir b/tensorflow/compiler/mlir/tensorflow/tests/batchmatmul_to_einsum.mlir new file mode 100644 index 00000000000..1589de3d661 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/batchmatmul_to_einsum.mlir @@ -0,0 +1,44 @@ +// RUN: tf-opt %s -tf-batch-matmul-to-tf-einsum | FileCheck %s --dump-input-on-failure + +func @test_batch_matmul_to_einsum(%arg0: tensor<1x2x3xf32>, %arg1: tensor<3x4xf32>) -> tensor<1x2x4xf32> { + // CHECK-LABEL: test_batch_matmul_to_einsum + // CHECK: "tf.Einsum"(%arg0, %arg1) {equation = "...mk,...kn->...mn"} : (tensor<1x2x3xf32>, tensor<3x4xf32>) -> tensor<1x2x4xf32> + %0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<1x2x3xf32>, tensor<3x4xf32>) -> tensor<1x2x4xf32> + return %0: tensor<1x2x4xf32> +} + +func @test_batch_matmul_broadcast_to_einsum(%arg0: tensor<2x2x4xf32>, %arg1: tensor<2x4x2xf32>) -> tensor<2x2x2xf32> { + // CHECK-LABEL: test_batch_matmul_broadcast_to_einsum + // CHECK: "tf.Einsum"(%arg0, %arg1) {equation = "...mk,...kn->...mn"} : (tensor<2x2x4xf32>, tensor<2x4x2xf32>) -> tensor<2x2x2xf32> + %0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<2x2x4xf32>, tensor<2x4x2xf32>) -> tensor<2x2x2xf32> + return %0: tensor<2x2x2xf32> +} + +func @test_batch_matmul_dynamic_shape_both_arg_to_einsum(%arg0: tensor<1x2x?xf32>, %arg1: tensor) -> tensor<1x2x4xf32> { + // CHECK-LABEL: test_batch_matmul_dynamic_shape_both_arg_to_einsum + // CHECK: "tf.Einsum"(%arg0, %arg1) {equation = "...mk,...kn->...mn"} : (tensor<1x2x?xf32>, tensor) -> tensor<1x2x4xf32> + %0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<1x2x?xf32>, tensor) -> tensor<1x2x4xf32> + return %0: tensor<1x2x4xf32> +} + +func @test_batch_matmul_dynamic_shape_one_arg_to_einsum(%arg0: tensor<1x2x?xf32>, %arg1: tensor<3x4xf32>) -> tensor<1x2x4xf32> { + // CHECK-LABEL: test_batch_matmul_dynamic_shape_one_arg_to_einsum + // CHECK: "tf.Einsum"(%arg0, %arg1) {equation = "...mk,...kn->...mn"} : (tensor<1x2x?xf32>, tensor<3x4xf32>) -> tensor<1x2x4xf32> + %0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<1x2x?xf32>, tensor<3x4xf32>) -> tensor<1x2x4xf32> + return %0: tensor<1x2x4xf32> +} + +func @test_batch_matmul_adj_to_einsum(%arg0: tensor<1x2x3xf32>, %arg1: tensor<4x3xf32>) -> tensor<1x2x4xf32> { + // CHECK-LABEL: test_batch_matmul_adj_to_einsum + // CHECK: %[[RES_EINSUM:[0-9]*]] = "tf.Einsum"(%arg0, %arg1) {equation = "...mk,...nk->...mn"} : (tensor<1x2x3xf32>, tensor<4x3xf32>) -> tensor<1x2x4xf32> + // CHECK: return %[[RES_EINSUM]] : tensor<1x2x4xf32> + %0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = false, adj_y = true} : (tensor<1x2x3xf32>, tensor<4x3xf32>) -> tensor<1x2x4xf32> + return %0: tensor<1x2x4xf32> +} + +func @test_batch_matmulV2_adj_to_einsum(%arg0: tensor<1x3x2xf32>, %arg1: tensor<3x4xf32>) -> tensor<1x2x4xf32> { + // CHECK: %[[RES_EINSUM:[0-9]*]] = "tf.Einsum"(%arg0, %arg1) {equation = "...km,...kn->...mn"} : (tensor<1x3x2xf32>, tensor<3x4xf32>) -> tensor<1x2x4xf32> + // CHECK: return %[[RES_EINSUM]] : tensor<1x2x4xf32> + %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = false} : (tensor<1x3x2xf32>, tensor<3x4xf32>) -> tensor<1x2x4xf32> + return %0: tensor<1x2x4xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index 5bf5b0610ae..158fd3064a0 100644 --- 
a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -418,3 +418,37 @@ func @ToBool_0DScalar(%arg0: tensor) -> tensor { %0 = "tf.ToBool"(%arg0) : (tensor) -> tensor return %0 : tensor } + +// CHECK-LABEL: testReadVariableOfOfCast +func @testReadVariableOfOfCast(%arg0: tensor>>) -> tensor<8x40xf32> { + %0 = "tf.Cast"(%arg0) : (tensor>>) -> tensor<*x!tf.resource> + %1 = "tf.ReadVariableOp"(%0) : (tensor<*x!tf.resource>) -> tensor<8x40xf32> + return %1: tensor<8x40xf32> + +// CHECK: %0 = "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor<8x40xf32> +// CHECK: return %0 +} + +// CHECK-LABEL: testReadVariableOfOfCastWithTruncate +func @testReadVariableOfOfCastWithTruncate(%arg0: tensor>>) -> tensor<8x40xf32> { + %0 = "tf.Cast"(%arg0) {Truncate = true} : (tensor>>) -> tensor<*x!tf.resource> + %1 = "tf.ReadVariableOp"(%0) : (tensor<*x!tf.resource>) -> tensor<8x40xf32> + return %1: tensor<8x40xf32> + +// CHECK: %0 = "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor<8x40xf32> +// CHECK: return %0 +} + +// CHECK-LABEL: testReadVariableOfOfCastMultiUse +func @testReadVariableOfOfCastMultiUse(%arg0: tensor>>) -> tensor { + %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor>>) -> tensor<*x!tf.resource> + %1 = "tf.ReadVariableOp"(%0) : (tensor<*x!tf.resource>) -> tensor + "tf.AssignVariableOp"(%0, %1) : (tensor<*x!tf.resource>, tensor) -> () + return %1: tensor + + // CHECK: %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor>>) -> tensor<*x!tf.resource> + // CHECK: %1 = "tf.ReadVariableOp"(%0) : (tensor<*x!tf.resource>) -> tensor + // CHECK: "tf.AssignVariableOp"(%0, %1) : (tensor<*x!tf.resource>, tensor) -> () + // CHECK: return %1 +} + diff --git a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir index d9727e94bb6..411599053e5 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir @@ -213,3 +213,21 @@ func @testRemoteDevice() -> tensor<2x2xi32> { // CHECK-NEXT: return [[cst]] : tensor<2x2xi32> return %2: tensor<2x2xi32> } + +// Tests ops that variable shapes are correctly evaluated on static types. +// CHECK-LABEL: func @testVariableShape +func @testVariableShape(%arg0: tensor>>) -> tensor<2xi32> { + %0 = "tf.VariableShape"(%arg0) : (tensor>>) -> tensor<2xi32> + // CHECK: [[cst:%.*]] = "tf.Const{{.*}} dense<{{\[}}2, 4]> : tensor<2xi32> + // CHECK-NEXT: return [[cst]] : tensor<2xi32> + return %0: tensor<2xi32> +} + +// Tests ops that tensor list shapes are correctly evaluated on static types. 
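+// (The folder reads the shape from the variant subtype; see
+// TensorListElementShapeOp::fold in tf_ops.cc above.)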
+// CHECK-LABEL: func @testTensorListElementShape +func @testTensorListElementShape(%arg0: tensor>>) -> tensor<2xi32> { + %0 = "tf.TensorListElementShape"(%arg0) : (tensor>>) -> tensor<2xi32> + // CHECK: [[cst:%.*]] = "tf.Const{{.*}} dense<{{\[}}2, 4]> : tensor<2xi32> + // CHECK-NEXT: return [[cst]] : tensor<2xi32> + return %0: tensor<2xi32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir b/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir new file mode 100644 index 00000000000..3dec94a98df --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir @@ -0,0 +1,57 @@ +// RUN: tf-opt -split-input-file -verify-diagnostics -tf-einsum %s | FileCheck %s + +func @einsum_basic(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> { + %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ijk,ikm->ijm"}: (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32> + return %0 : tensor<3x4x6xf32> + // CHECK-LABEL: einsum_basic + // CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32> +} + +func @einsum_4D(%arg0: tensor<2x5x7x3xf32>, %arg1: tensor<2x4x7x3xf32>) -> tensor<2x7x5x4xf32> { + %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "bfnh,btnh->bnft"}: (tensor<2x5x7x3xf32>, tensor<2x4x7x3xf32>) -> tensor<2x7x5x4xf32> + return %0 : tensor<2x7x5x4xf32> + // CHECK-LABEL: einsum_4D + // CHECK: %[[cst:.*]] = constant dense<[0, 2, 1, 3]> : tensor<4xi32> + // CHECK: %[[cst_1:.*]] = constant dense<[0, 2, 3, 1]> : tensor<4xi32> + // CHECK: %[[v0:.*]] = "tf.Transpose"(%arg0, %[[cst]]) : (tensor<2x5x7x3xf32>, tensor<4xi32>) -> tensor<2x7x5x3xf32> + // CHECK: %[[v1:.*]] = "tf.Transpose"(%arg1, %[[cst_1]]) : (tensor<2x4x7x3xf32>, tensor<4xi32>) -> tensor<2x7x3x4xf32> + // CHECK: "tf.BatchMatMulV2"(%[[v0]], %[[v1]]) {adj_x = false, adj_y = false} : (tensor<2x7x5x3xf32>, tensor<2x7x3x4xf32>) -> tensor<2x7x5x4xf32> +} + +func @einsum_matrixdotprod(%arg0: tensor<2x5x7x3xf32>, %arg1: tensor<7x3x4xf32>) -> tensor<2x5x4xf32> { + %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "bfnd,ndh->bfh"}: (tensor<2x5x7x3xf32>, tensor<7x3x4xf32>) -> tensor<2x5x4xf32> + return %0 : tensor<2x5x4xf32> + // CHECK-LABEL: einsum_matrixdotprod + // CHECK: %[[cst:.*]] = constant dense<[2, 5, 21]> : tensor<3xi64> + // CHECK: %[[cst_1:.*]] = constant dense<[21, 4]> : tensor<2xi64> + // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x5x7x3xf32>, tensor<3xi64>) -> tensor<2x5x21xf32> + // CHECK: %[[v1:.*]] = "tf.Reshape"(%arg1, %[[cst_1]]) : (tensor<7x3x4xf32>, tensor<2xi64>) -> tensor<21x4xf32> + // CHECK: "tf.BatchMatMulV2"(%[[v0]], %[[v1]]) {adj_x = false, adj_y = false} : (tensor<2x5x21xf32>, tensor<21x4xf32>) -> tensor<2x5x4xf32> +} + +func @einsum_reshapetail(%arg0: tensor<3x4x5xf32>, %arg1: tensor<5x6x2xf32>) -> tensor<3x4x6x2xf32> { + %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "bfd,dnh->bfnh"}: (tensor<3x4x5xf32>, tensor<5x6x2xf32>) -> tensor<3x4x6x2xf32> + return %0 : tensor<3x4x6x2xf32> + // CHECK-LABEL: einsum_reshapetail + // CHECK: %[[cst:.*]] = constant dense<[5, 12]> : tensor<2xi64> + // CHECK: %[[cst_1:.*]] = constant dense<[3, 4, 6, 2]> : tensor<4xi64> + // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg1, %[[cst]]) : (tensor<5x6x2xf32>, tensor<2xi64>) -> tensor<5x12xf32> + // CHECK: %[[v1:.*]] = "tf.BatchMatMulV2"(%arg0, %[[v0]]) {adj_x = false, adj_y = false} : (tensor<3x4x5xf32>, tensor<5x12xf32>) -> 
tensor<3x4x12xf32> + // CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<3x4x12xf32>, tensor<4xi64>) -> tensor<3x4x6x2xf32> + // CHECK: return %[[v2]] : tensor<3x4x6x2xf32> +} + +func @einsum_no_match(%arg0: tensor<4x5xf32>, %arg1: tensor<5xf32>) -> tensor<4xf32> { + %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ij,j->i"}: (tensor<4x5xf32>, tensor<5xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +// CHECK-LABEL: einsum_no_match +// CHECK: %[[v0:.*]] = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ij,j->i"} : (tensor<4x5xf32>, tensor<5xf32>) -> tensor<4xf32> +// CHECK: return %[[v0]] +} +func @einsum_illegal_no_match(%arg0: tensor<4x5xf32>, %arg1: tensor<5xf32>) -> tensor<4xf32> { + %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ij,?zw->kq->i"}: (tensor<4x5xf32>, tensor<5xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +// CHECK-LABEL: einsum_illegal_no_match +// CHECK: %[[v0:.*]] = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ij,?zw->kq->i"} : (tensor<4x5xf32>, tensor<5xf32>) -> tensor<4xf32> +// CHECK: return %[[v0]] +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/gpu_fusion.mlir b/tensorflow/compiler/mlir/tensorflow/tests/gpu_fusion.mlir new file mode 100644 index 00000000000..6e507f06ef4 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/gpu_fusion.mlir @@ -0,0 +1,47 @@ +// RUN: tf-opt %s -tf-gpu-op-fusion | FileCheck %s --dump-input=fail + +// Test the op-fusion pass specific to the GPU target. + +// CHECK-LABEL: func @FusedBatchNormRelu +func @FusedBatchNormRelu(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { +// CHECK-NEXT: %[[Y:[a-z0-9]*]], {{.*}}_FusedBatchNormEx +// CHECK-NEXT: return %[[Y]] + %y:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + %relu = "tf.Relu"(%y#0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + return %relu : tensor<8x8x8x8xf32> +} + +// CHECK-LABEL: func @FusedBatchNormAddRelu +func @FusedBatchNormAddRelu(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { +// CHECK-NEXT: %[[Y:[a-z0-9]*]], {{.*}}_FusedBatchNormEx +// CHECK-NEXT: return %[[Y]] + %y:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + %add = "tf.AddV2"(%arg0, %y#0) : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + %relu = "tf.Relu"(%add) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + return %relu : tensor<8x8x8x8xf32> +} + +// CHECK-LABEL: func @FusedBatchNormAddReluTwoUses +func @FusedBatchNormAddReluTwoUses(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>) { +// Since the tf.AddV2 op has two uses, we have a _FusedBatchNormEx without the +// Relu activation and we only fuse the add. 
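+// (Folding the Relu as well would make the pre-activation sum unavailable to
+// its second use, so only the AddV2 side input is fused.)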
+// CHECK-NEXT: %[[Y:[a-z0-9]*]], {{.*}}_FusedBatchNormEx +// CHECK-NEXT: %[[relu:[a-z0-9]*]] ={{.*}}Relu"(%[[Y]] +// CHECK-NEXT: return %[[relu]] + %y:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + %add = "tf.AddV2"(%arg0, %y#0) : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + %relu = "tf.Relu"(%add) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + return %relu, %add : tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32> +} + +// CHECK-LABEL: func @TrainingFusedBatchNormRelu +func @TrainingFusedBatchNormRelu(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // We don't fuse in training right now +// CHECK-NEXT: %[[Y:[a-z0-9]*]], {{.*}}FusedBatchNorm +// CHECK-NEXT: %[[relu:[a-z0-9]*]] ={{.*}}Relu"(%[[Y]] +// CHECK-NEXT: return %[[relu]] + %y:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + %relu = "tf.Relu"(%y#0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + return %relu : tensor<8x8x8x8xf32> +} + diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/BUILD index c15aad27209..1544d27009f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/BUILD @@ -8,6 +8,9 @@ glob_lit_tests( ":test_utilities", ], driver = "@llvm-project//mlir:run_lit.sh", + tags_override = { + "error-message-with-source-info.pbtxt": ["no_oss"], # TODO(b/150946057): to be fixed on oss. 
+ }, test_file_exts = ["pbtxt"], ) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-uint8-return.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-uint8-return.pbtxt index 9ae5601fa57..bb5e02fedf2 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-uint8-return.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-uint8-return.pbtxt @@ -107,5 +107,5 @@ versions { # CHECK: "tf.PartitionedCall"() # CHECK-SAME: Tout = ["tfdtype$DT_UINT8"] # CHECK-SAME: f = @[[FUNCTION:[A-Za-z0-9_]*]] -# CHECK: func @[[FUNCTION]]() -> tensor -# CHECK: return {{.*}} : tensor +# CHECK: func @[[FUNCTION]]() -> tensor +# CHECK: return {{.*}} : tensor diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_60.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_60.mlir new file mode 100644 index 00000000000..3786a26d114 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_60.mlir @@ -0,0 +1,25 @@ +// RUN: tf-opt %s -tf-layout-assignment -verify-diagnostics | FileCheck %s --dump-input=always + +module attributes { + tf.devices = {"/device:GPU:0" = {cc_major = 6 : i32, cc_minor = 0 : i32}} +} { + +// CHECK-LABEL: func @transposeConv2D_3x3_f16 +func @transposeConv2D_3x3_f16(%input: tensor<1x28x28x64xf16>, %filter: tensor<3x3x64x64xf16>) -> tensor<1x28x28x64xf16> { + // cuDNN prefers NCHW data format for spatial convolutions in f16 before + // compute capability 7.0 (NVIDIA Tensor Cores). + + // CHECK: "tf.Conv2D"(%[[INPUT_TRANSPOSE:[0-9]*]], %arg1) + // CHECK-SAME: data_format = "NCHW" + %0 = "tf.Conv2D"(%input, %filter) + { + data_format = "NHWC", + padding = "VALID", + strides = [1, 1, 1, 1] + } : (tensor<1x28x28x64xf16>, tensor<3x3x64x64xf16>) + -> tensor<1x28x28x64xf16> + + return %0 : tensor<1x28x28x64xf16> +} + +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_70.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_70.mlir new file mode 100644 index 00000000000..0b2588c38cc --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_gpu_cc_70.mlir @@ -0,0 +1,66 @@ +// RUN: tf-opt %s -tf-layout-assignment -verify-diagnostics | FileCheck %s --dump-input=always + +module attributes { + tf.devices = {"/device:GPU:0" = {cc_major = 7 : i32, cc_minor = 0 : i32}} +} { + +// CHECK-LABEL: func @transposeConv2D_3x3_f32 +func @transposeConv2D_3x3_f32(%input: tensor<1x28x28x64xf32>, %filter: tensor<3x3x64x64xf32>) -> tensor<1x28x28x64xf32> { + // cuDNN prefers NCHW data format for spatial convolutions. + // CHECK: "tf.Conv2D"(%[[INPUT_TRANSPOSE:[0-9]*]], %arg1) + // CHECK-SAME: data_format = "NCHW" + %0 = "tf.Conv2D"(%input, %filter) + { + data_format = "NHWC", + padding = "VALID", + strides = [1, 1, 1, 1] + } : (tensor<1x28x28x64xf32>, tensor<3x3x64x64xf32>) + -> tensor<1x28x28x64xf32> + + return %0 : tensor<1x28x28x64xf32> +} + +// CHECK-LABEL: func @transposeConv2D_1x1_f32 +func @transposeConv2D_1x1_f32(%input: tensor<1x64x28x28xf32>, %filter: tensor<1x1x64x64xf32>) -> tensor<1x64x28x28xf32> { + // 1x1 convolution can be computed as a GEMM in NHWC data format. 
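+  // (This mirrors Conv2DOp::GetOptimalLayout: a 1x1 filter with unit strides
+  // and dilations prefers NHWC even when the op requested NCHW.)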
+ // CHECK: "tf.Conv2D"(%[[INPUT_TRANSPOSE:[0-9]*]], %arg1) + // CHECK-SAME: data_format = "NHWC" + %0 = "tf.Conv2D"(%input, %filter) + { + data_format = "NCHW", + padding = "VALID", + strides = [1, 1, 1, 1] + } : (tensor<1x64x28x28xf32>, tensor<1x1x64x64xf32>) + -> tensor<1x64x28x28xf32> + + // Striding in spatial dimensions does not allow to use GEMM. + // CHECK: "tf.Conv2D"(%arg0, %arg1) + // CHECK-SAME: data_format = "NCHW" + %1 = "tf.Conv2D"(%input, %filter) + { + data_format = "NCHW", + padding = "VALID", + strides = [1, 1, 2, 2] + } : (tensor<1x64x28x28xf32>, tensor<1x1x64x64xf32>) + -> tensor<1x64x14x14xf32> + + return %0 : tensor<1x64x28x28xf32> +} + +// CHECK-LABEL: func @transposeConv2D_3x3_f16 +func @transposeConv2D_3x3_f16(%input: tensor<1x64x28x28xf16>, %filter: tensor<3x3x64x64xf16>) -> tensor<1x64x28x28xf16> { + // To use Tensor Cores for f16 data type, input must be in NHWC data format. + // CHECK: "tf.Conv2D"(%[[INPUT_TRANSPOSE:[0-9]*]], %arg1) + // CHECK-SAME: data_format = "NHWC" + %0 = "tf.Conv2D"(%input, %filter) + { + data_format = "NCHW", + padding = "VALID", + strides = [1, 1, 1, 1] + } : (tensor<1x64x28x28xf16>, tensor<3x3x64x64xf16>) + -> tensor<1x64x28x28xf16> + + return %0 : tensor<1x64x28x28xf16> +} + +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir index 0610cbe8680..b66289ae34b 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir @@ -1,52 +1,12 @@ // RUN: tf-opt %s -tf-layout-assignment=force-data-format=NCHW -verify-diagnostics | FileCheck %s --dump-input=always -// CHECK-LABEL: func @transposeBiasAdd -func @transposeBiasAdd(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<8xf32>) -> tensor<1x4x4x8xf32> { - - // Check that BiasAdd was converted to forced data format, and layout - // dependent arguments and results passed through transpose nodes. 
- - // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} - // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]]) - // CHECK: %[[BIAS_ADD:[0-9]*]] = "tf.BiasAdd"(%[[ARG_TRANSPOSE]], %arg1) {data_format = "NCHW"} {{.*}} tensor<1x8x4x4xf32> - // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} - // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[BIAS_ADD]], %[[RES_PERM]]) - // CHECK: return %[[RES_TRANSPOSE]] - %0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NHWC"} : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32> - - return %0 : tensor<1x4x4x8xf32> -} - -// CHECK-LABEL: func @transposeBiasAddWithDefaultAttr -func @transposeBiasAddWithDefaultAttr(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<8xf32>) -> tensor<1x4x4x8xf32> { - - // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} - // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]]) - // CHECK: %[[BIAS_ADD:[0-9]*]] = "tf.BiasAdd"(%[[ARG_TRANSPOSE]], %arg1) {data_format = "NCHW"} {{.*}} tensor<1x8x4x4xf32> - // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} - // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[BIAS_ADD]], %[[RES_PERM]]) - // CHECK: return %[[RES_TRANSPOSE]] - %0 = "tf.BiasAdd"(%arg0, %arg1) : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32> - - return %0 : tensor<1x4x4x8xf32> -} - -// CHECK-LABEL: func @transposeBiasWithUnknownShape -func @transposeBiasWithUnknownShape(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<8xf32>) -> tensor<*xf32> { - - // CHECK: %[[BIAS_ADD:[0-9]*]] = "tf.BiasAdd"(%[[ARG_TRANSPOSE]], %arg1) {data_format = "NCHW"} {{.*}} tensor<*xf32> - %0 = "tf.BiasAdd"(%arg0, %arg1) : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<*xf32> - - return %0 : tensor<*xf32> -} +// IMPORTANT: In the following Conv2D tests tensor shapes do not match +// convolution parameters (stride, dilations, etc...). This test only verifies +// that changing convolution data layout will update all the attributes. // CHECK-LABEL: func @transposeConv2D func @transposeConv2D(%input: tensor<1x32x32x3xf32>, %filter: tensor<1x1x3x8xf32>) -> tensor<1x32x32x8xf32> { - // IMPORTANT: Tensor shapes do not match convolution parameters (stride, - // dilations, etc...). This test only verifies that changing convolution data - // layout will update all the attributes. 
- // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]]) @@ -73,3 +33,35 @@ func @transposeConv2D(%input: tensor<1x32x32x3xf32>, %filter: tensor<1x1x3x8xf32 return %0 : tensor<1x32x32x8xf32> } + +// CHECK-LABEL: func @transposeConv2DWithDefaultAttr +func @transposeConv2DWithDefaultAttr(%input: tensor<1x32x32x3xf32>, %filter: tensor<1x1x3x8xf32>) -> tensor<*xf32> +{ + + // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} + // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]]) + + // CHECK: %[[CONV2D:[0-9]*]] = "tf.Conv2D"(%[[ARG_TRANSPOSE]], %arg1) + // CHECK-SAME: data_format = "NCHW" + // CHECK-SAME: dilations = [1, 4, 2, 3] + // CHECK-SAME: explicit_paddings = [1, 2, 7, 8, 3, 4, 5, 6] + // CHECK-SAME: padding = "EXPLICIT" + // CHECK-SAME: strides = [5, 8, 6, 7] + // CHECK-SAME: (tensor<1x3x32x32xf32>, tensor<1x1x3x8xf32>) -> tensor<*xf32> + + // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} + // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[CONV2D]], %[[RES_PERM]]) + // CHECK: return %[[RES_TRANSPOSE]] + + // (1) data_format attribute has default value NHWC + // (2) result shape is unknown (check that optimizer does not fail) + %0 = "tf.Conv2D"(%input, %filter) + { + dilations = [1, 2, 3, 4], + explicit_paddings = [1, 2, 3, 4, 5, 6, 7, 8], + padding = "EXPLICIT", + strides = [5, 6, 7, 8] + } : (tensor<1x32x32x3xf32>, tensor<1x1x3x8xf32>) -> tensor<*xf32> + + return %0 : tensor<*xf32> +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_to_nchw.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_to_nchw.mlir index a2394cd93c1..ae3592b723f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_to_nchw.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_to_nchw.mlir @@ -1,24 +1,33 @@ // RUN: tf-opt %s -tf-layout-optimization=force-data-format=NCHW -verify-diagnostics | FileCheck %s --dump-input=always -// CHECK-LABEL: func @transposeBiasAdd -func @transposeBiasAdd(%arg0: tensor<1x8x4x4xf32>, %arg1: tensor<8xf32>) -> tensor<1x8x4x4xf32> { +// CHECK-LABEL: func @transposeConv2D +func @transposeConv2D(%arg0: tensor<1x3x32x32xf32>, %arg1: tensor<1x1x3x8xf32>) -> tensor<1x3x32x32xf32> { // Convert input: NCHW -> NHWC %0 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32> - %1 = "tf.Transpose"(%arg0, %0) : (tensor<1x8x4x4xf32>, tensor<4xi32>) -> tensor<1x4x4x8xf32> + %1 = "tf.Transpose"(%arg0, %0) : (tensor<1x3x32x32xf32>, tensor<4xi32>) -> tensor<1x32x32x3xf32> // Compute in NHWC - %2 = "tf.BiasAdd"(%1, %arg1) {data_format = "NHWC"} : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32> + %2 = "tf.Conv2D"(%1, %arg1) + { + data_format = "NHWC", + padding = "SAME", + strides = [1, 1, 1, 1], + dilations = [1, 1, 1, 1] + } : (tensor<1x32x32x3xf32>, tensor<1x1x3x8xf32>) -> tensor<1x32x32x3xf32> // Convert result back: NHWC -> NCHW %3 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32> - %4 = "tf.Transpose"(%2, %3) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32> + %4 = "tf.Transpose"(%2, %3) : (tensor<1x32x32x3xf32>, tensor<4xi32>) -> tensor<1x3x32x32xf32> - // Check that BiasAdd computed in NCHW format, and all redundant transpose + // Check that Conv2D computed in NCHW format, and all 
redundant transpose // operations removed from the function. - // CHECK: %[[BIAS_ADD:[0-9]*]] = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW"} {{.*}} tensor<1x8x4x4xf32> - // CHECK: return %[[BIAS_ADD]] + // CHECK: %[[CONV:[0-9]*]] = "tf.Conv2D"(%arg0, %arg1) + // CHECK-SAME: data_format = "NCHW" + // CHECK-SAME: -> tensor<1x3x32x32xf32> - return %4 : tensor<1x8x4x4xf32> + // CHECK: return %[[CONV]] + + return %4 : tensor<1x3x32x32xf32> } \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir index 7b92d0776f8..c5f87c602a3 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir @@ -392,12 +392,12 @@ func @DynamicStitch_scalar_matrix_indices(%arg0: tensor<2xf32>, %arg1: tensor<2x // Verify that custom types are lowered and have legal output. // CHECK-LABEL: func @DynamicStitch_uint8 -func @DynamicStitch_uint8(%arg0: tensor<2x2x!tf.uint8>) -> tensor<2x2x!tf.uint8> { +func @DynamicStitch_uint8(%arg0: tensor<2x2xui8>) -> tensor<2x2xui8> { // CHECK-NOT: tf.DynamicStitch %indices = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> - %0 = "tf.DynamicStitch"(%indices, %arg0) : (tensor<2xi32>, tensor<2x2x!tf.uint8>) -> tensor<2x2x!tf.uint8> - return %0 : tensor<2x2x!tf.uint8> + %0 = "tf.DynamicStitch"(%indices, %arg0) : (tensor<2xi32>, tensor<2x2xui8>) -> tensor<2x2xui8> + return %0 : tensor<2x2xui8> } // CHECK-LABEL: func @DynamicStitch_scalar_item diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir index c9db7e0a1dc..706524e39a1 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir @@ -254,4 +254,28 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { %0 = "tf.Cast"(%arg0) : (tensor<*xf32>) -> (tensor<*xf32>) return %0 : tensor<*xf32> } + + // CHECK-LABEL: func @while_variant + // CHECK-SAME: -> tensor>> + func @while_variant(%arg0: tensor>>) -> tensor { + // CHECK: tf.While + // CHECK-SAME: -> tensor>> + %0 = "tf.While"(%arg0) {cond = @variant_cond_func, body = @variant_body_func, is_stateless = true} : (tensor>>) -> tensor + // CHECK: tf.ZerosLike + // CHECK-SAME: -> tensor>> + %1 = "tf.ZerosLike"(%0) : (tensor) -> tensor + // CHECK: tf.Identity + // CHECK-SAME: -> tensor>> + %2 = "tf.Identity"(%1) : (tensor) -> tensor + return %2 : tensor + } + // CHECK-LABEL: func @variant_cond_func + func @variant_cond_func(%arg0: tensor>>) -> tensor { + %0 = "tf._SomeOp"() : () -> tensor + return %0 : tensor + } + // CHECK-LABEL: func @variant_body_func + func @variant_body_func(%arg0: tensor>>) -> tensor>> { + return %arg0 : tensor>> + } } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/stack_ops_decomposition.mlir b/tensorflow/compiler/mlir/tensorflow/tests/stack_ops_decomposition.mlir index 23b77399d4f..e8c5bb59663 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/stack_ops_decomposition.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/stack_ops_decomposition.mlir @@ -6,14 +6,14 @@ func @main() -> tensor { // CHECK-NEXT: "tf.Const"() {value = dense<10> : tensor} %max_size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor - // CHECK-NEXT: %[[BUFFER:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> - // CHECK-NEXT: %[[SIZE:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> - // CHECK-NEXT: 
%[[ZERO:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK-NEXT: "tf.AssignVariableOp"(%[[SIZE]], %[[ZERO]]) // CHECK-NEXT: %[[ZERO_SCALAR:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor // CHECK-NEXT: %[[CAST_ZERO:.*]] = "tf.Cast"(%[[ZERO_SCALAR]]) : (tensor) -> tensor // CHECK-NEXT: %[[CONST10:.*]] = "tf.Const"() {value = dense<10> : tensor<1xi32>} : () -> tensor<1xi32> // CHECK-NEXT: %[[BROADCAST:.*]] = "tf.BroadcastTo"(%[[CAST_ZERO]], %[[CONST10]]) : (tensor, tensor<1xi32>) -> tensor<10xf32> + // CHECK-NEXT: %[[BUFFER:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> + // CHECK-NEXT: %[[SIZE:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> + // CHECK-NEXT: %[[ZERO:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK-NEXT: "tf.AssignVariableOp"(%[[SIZE]], %[[ZERO]]) // CHECK-NEXT: "tf.AssignVariableOp"(%[[BUFFER]], %[[BROADCAST]]) %stack = "tf.StackV2"(%max_size) {elem_type = f32, stack_name = "s"} : (tensor) -> tensor %id = "tf.Identity"(%stack) : (tensor) -> tensor @@ -52,13 +52,13 @@ func @main() -> tensor { func @main() -> tensor<2xi32> { // CHECK-NEXT: "tf.Const"() {value = dense<10> : tensor} : () -> tensor %size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + // CHECK-NEXT: %[[ZERO_CONST:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + // CHECK-NEXT: %[[STACK_SHAPE:.*]] = "tf.Const"() {value = dense<[10, 2]> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK-NEXT: %[[BROADCAST:.*]] = "tf.BroadcastTo"(%[[ZERO_CONST]], %[[STACK_SHAPE]]) : (tensor, tensor<2xi32>) -> tensor<10x2xi32> // CHECK-NEXT: %[[BUFFER:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> // CHECK-NEXT: %[[SIZE:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> // CHECK-NEXT: %[[ZERO_SIZE:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> // CHECK-NEXT: "tf.AssignVariableOp"(%[[SIZE]], %[[ZERO_SIZE]]) : (tensor>>, tensor<1xi32>) -> () - // CHECK-NEXT: %[[ZERO_CONST:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - // CHECK-NEXT: %[[STACK_SHAPE:.*]] = "tf.Const"() {value = dense<[10, 2]> : tensor<2xi32>} : () -> tensor<2xi32> - // CHECK-NEXT: %[[BROADCAST:.*]] = "tf.BroadcastTo"(%[[ZERO_CONST]], %[[STACK_SHAPE]]) : (tensor, tensor<2xi32>) -> tensor<10x2xi32> // CHECK-NEXT: "tf.AssignVariableOp"(%[[BUFFER]], %[[BROADCAST]]) : (tensor>>, tensor<10x2xi32>) -> () %stack = "tf.StackV2"(%size) {elem_type = i32, stack_name = "s"} : (tensor) -> tensor // CHECK-NEXT: %[[PUSH_VAL:.*]] = "tf._SomeOp"() : () -> tensor<2xi32> @@ -151,10 +151,10 @@ func @if_then(%arg0: tensor) -> tensor { } // CHECK: func @if_else(%[[EARG0:.*]]: tensor>>, %[[EARG1:.*]]: tensor>>) func @if_else(%arg0: tensor) -> tensor { - // CHECK-NOT: "tf.StackPushV2" + // CHECK-NOT: "tf.StackPopV2" // CHECK: "tf.Slice" // CHECK: "tf.AssignVariableOp"(%[[EARG1:.*]], - // CHECK-NOT: "tf.StackPushV2" + // CHECK-NOT: "tf.StackPopV2" %pop = "tf.StackPopV2"(%arg0) : (tensor) -> tensor return %arg0 : tensor } @@ -204,7 +204,7 @@ func @callee(%arg0: tensor, %arg1: tensor) -> tensor) -> tensor<2xi32> { - // expected-error @+1 {{max size of stack is not a constant.}} + // expected-error @+1 {{unknown max element count}} %stack = "tf.StackV2"(%arg0) {elem_type = i32, stack_name = "s"} : (tensor) -> tensor %elem = "tf._SomeOp"() : () -> tensor<2xi32> %push = "tf.StackPushV2"(%stack, %elem) {swap_memory = false} : (tensor, tensor<2xi32>) -> tensor<2xi32> @@ -218,7 +218,7 @@ func @main(%arg0: tensor) -> tensor<2xi32> { func @main(%arg0: 
tensor) -> () { %max_size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor - // expected-error @+1 {{cannot infer element shape of stack.}} + // expected-error @+1 {{cannot infer element shape of stack}} %stack = "tf.StackV2"(%max_size) {elem_type = i32, stack_name = "s"} : (tensor) -> tensor %elem = "tf._SomeOp"() : () -> tensor<*xi32> %push = "tf.StackPushV2"(%stack, %elem) {swap_memory = false} : (tensor, tensor<*xi32>) -> tensor<*xi32> @@ -236,7 +236,7 @@ func @main(%arg0: tensor) -> () { %stack2 = "tf.StackV2"(%max_size) {elem_type = f32, stack_name = "s2"} : (tensor) -> tensor %if_op = "tf.If"(%arg0, %stack, %stack2) {then_branch = @if_then, else_branch = @if_else, is_stateless = false} : (tensor, tensor, tensor) -> tensor - // expected-error @+1 {{unknown stack.}} + // expected-error @+1 {{unknown stack}} %pop = "tf.StackPopV2"(%if_op) : (tensor) -> tensor "tf.StackCloseV2"(%stack) : (tensor) -> () // CHECK: return diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tensor_list_ops_decomposition.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tensor_list_ops_decomposition.mlir new file mode 100644 index 00000000000..9e43cea1003 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tensor_list_ops_decomposition.mlir @@ -0,0 +1,277 @@ +// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-tensor-list-ops-decomposition | FileCheck %s -dump-input-on-failure + +// Test push and pop on a tensor list which is initially empty. + +// CHECK-LABEL: func @main +func @main() -> (tensor, tensor) { + // CHECK-NEXT: "tf.Const"() {value = dense<[]> : tensor<0xi32>} + %elem_shape = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + // CHECK-NEXT: "tf.Const"() {value = dense<10> : tensor} + %max_size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + // CHECK-NEXT: %[[ZERO_SCALAR:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + // CHECK-NEXT: %[[CAST_ZERO:.*]] = "tf.Cast"(%[[ZERO_SCALAR]]) : (tensor) -> tensor + // CHECK-NEXT: %[[CONST10:.*]] = "tf.Const"() {value = dense<10> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK-NEXT: %[[BROADCAST:.*]] = "tf.BroadcastTo"(%[[CAST_ZERO]], %[[CONST10]]) : (tensor, tensor<1xi32>) -> tensor<10xf32> + // CHECK-NEXT: %[[ZERO:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> + %tl = "tf.EmptyTensorList"(%elem_shape, %max_size) : (tensor<0xi32>, tensor) -> tensor>> + %id = "tf.Identity"(%tl) : (tensor>>) -> tensor>> + // CHECK-NEXT: %[[PUSHVAL:.*]] = "tf._SomeOp"() + %elem = "tf._SomeOp"() : () -> tensor + // CHECK-NEXT: %[[UPDATE_SHAPE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK-NEXT: %[[UPDATE_SLICE:.*]] = "tf.Reshape"(%[[PUSHVAL]], %[[UPDATE_SHAPE]]) : (tensor, tensor<1xi32>) -> tensor<1xf32> + // CHECK-NEXT: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"(%[[BROADCAST]], %[[UPDATE_SLICE]], %[[ZERO]]) : (tensor<10xf32>, tensor<1xf32>, tensor<1xi32>) -> tensor<10xf32> + // CHECK-NEXT: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK-NEXT: %[[NEW_SIZE:.*]] = "tf.AddV2"(%[[ZERO]], %[[CONST1]]) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + %push = "tf.TensorListPushBack"(%id, %elem) : (tensor>>, tensor) -> tensor>> + // CHECK-NEXT: %[[COPY:.*]] = "tf.Identity"(%[[UPDATE]]) + // CHECK-NEXT: %[[CONST1_1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK-NEXT: %[[SUB:.*]] = "tf.Sub"(%[[NEW_SIZE]], %[[CONST1_1]]) + // CHECK-NEXT: 
%[[SLICE_SIZE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK-NEXT: %[[SLICE:.*]] = "tf.Slice"(%[[COPY]], %[[SUB]], %[[SLICE_SIZE]]) : (tensor<10xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xf32> + // CHECK-NEXT: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + // CHECK-NEXT: %[[ELEM:.*]] = "tf.Reshape"(%[[SLICE]], %[[ELEM_SHAPE]]) : (tensor<1xf32>, tensor<0xi32>) -> tensor + %pop:2 = "tf.TensorListPopBack"(%push, %elem_shape) : (tensor>>, tensor<0xi32>) -> (tensor>>, tensor) + // CHECK-NEXT: %[[SCALAR_SHAPE:.*]] = "tf.Const"() {value = dense<[]> : tensor<0xi32>} + // CHECK-NEXT: %[[LENGTH:.*]] = "tf.Reshape"(%[[NEW_SIZE]], %[[SCALAR_SHAPE]]) + %length = "tf.TensorListLength"(%push) : (tensor>>) -> tensor + // CHECK-NEXT: return %[[ELEM]], %[[LENGTH]] : tensor, tensor + return %pop#1, %length: tensor, tensor +} + +// ----- + +// Test get and set, and other operations on a tensor list which has reserved +// initial size. + +// CHECK-LABEL: func @main +// CHECK-SAME: (%[[ARG0:.*]]: tensor) -> (tensor, tensor<10xf32>, tensor) +func @main(%arg0: tensor) -> (tensor, tensor<10xf32>, tensor) { + // CHECK-NEXT: "tf.Const"() {value = dense<[]> : tensor<0xi32>} + %elem_shape = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + // CHECK-NEXT: %[[NUM:.*]] = "tf.Const"() {value = dense<10> : tensor} + %num = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + // CHECK-NEXT: %[[ZERO_SCALAR:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + // CHECK-NEXT: %[[CAST_ZERO:.*]] = "tf.Cast"(%[[ZERO_SCALAR]]) : (tensor) -> tensor + // CHECK-NEXT: %[[CONST10:.*]] = "tf.Const"() {value = dense<10> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK-NEXT: %[[BROADCAST:.*]] = "tf.BroadcastTo"(%[[CAST_ZERO]], %[[CONST10]]) : (tensor, tensor<1xi32>) -> tensor<10xf32> + // CHECK-NEXT: %[[SIZE_SHAPE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} + // CHECK-NEXT: %[[SIZE:.*]] = "tf.Reshape"(%[[NUM]], %[[SIZE_SHAPE]]) + %tl = "tf.TensorListReserve"(%elem_shape, %num) : (tensor<0xi32>, tensor) -> tensor>> + // CHECK-NEXT: %[[SETVAL:.*]] = "tf._SomeOp"() + %elem = "tf._SomeOp"() : () -> tensor + // CHECK-NEXT: %[[SIZE_SHAPE1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} + // CHECK-NEXT: %[[SET_INDEX:.*]] = "tf.Reshape"(%[[ARG0]], %[[SIZE_SHAPE1]]) : (tensor, tensor<1xi32>) -> tensor<1xi32> + // CHECK-NEXT: %[[UPDATE_SHAPE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK-NEXT: %[[UPDATE_SLICE:.*]] = "tf.Reshape"(%[[SETVAL]], %[[UPDATE_SHAPE]]) : (tensor, tensor<1xi32>) -> tensor<1xf32> + // CHECK-NEXT: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"(%[[BROADCAST]], %[[UPDATE_SLICE]], %[[SET_INDEX]]) : (tensor<10xf32>, tensor<1xf32>, tensor<1xi32>) -> tensor<10xf32> + %set = "tf.TensorListSetItem"(%tl, %arg0, %elem) : (tensor>>, tensor, tensor) -> tensor>> + // CHECK-NEXT: %[[SIZE_SHAPE2:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} + // CHECK-NEXT: %[[GET_INDEX:.*]] = "tf.Reshape"(%[[ARG0]], %[[SIZE_SHAPE2]]) : (tensor, tensor<1xi32>) -> tensor<1xi32> + // CHECK-NEXT: %[[SLICE_SIZE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK-NEXT: %[[SLICE:.*]] = "tf.Slice"(%[[UPDATE]], %[[GET_INDEX]], %[[SLICE_SIZE]]) : (tensor<10xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xf32> + // CHECK-NEXT: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + // CHECK-NEXT: 
%[[ELEM:.*]] = "tf.Reshape"(%[[SLICE]], %[[ELEM_SHAPE]]) : (tensor<1xf32>, tensor<0xi32>) -> tensor + %get = "tf.TensorListGetItem"(%set, %arg0, %elem_shape) : (tensor>>, tensor, tensor<0xi32>) -> tensor + // CHECK-NEXT: %[[ADDN:.*]] = "tf.AddN"(%[[UPDATE]], %[[BROADCAST]]) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> + %addn = "tf.AddN"(%set, %tl) : (tensor>>, tensor>>) -> tensor>> + // CHECK-NEXT: %[[ZEROS_LIKE:.*]] = "tf.ZerosLike"(%[[ADDN]]) : (tensor<10xf32>) -> tensor<10xf32> + %zeros-like = "tf.ZerosLike"(%addn) : (tensor>>) -> tensor>> + // CHECK-NEXT: %[[ADDN2:.*]] = "tf.AddN"(%[[ADDN]], %[[ZEROS_LIKE]]) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> + %addn2 = "tf.AddN"(%addn, %zeros-like) : (tensor>>, tensor>>) -> tensor>> + %stack = "tf.TensorListStack"(%addn2, %elem_shape) : (tensor>>, tensor<0xi32>) -> tensor<10xf32> + // CHECK-NEXT: %[[LEN:.*]] = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + %length = "tf.TensorListLength"(%addn2) : (tensor>>) -> tensor + // CHECK-NEXT: return %[[ELEM]], %[[ADDN2]], %[[LEN]] : tensor, tensor<10xf32>, tensor + return %get, %stack, %length : tensor, tensor<10xf32>, tensor +} + +// ----- + +// Test get on a tensor list created from a tensor. + +// CHECK-LABEL: func @main +// CHECK-SAME: (%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor<10xf32>) -> tensor +func @main(%arg0: tensor, %arg1: tensor<10xf32>) -> tensor { + // CHECK-NEXT: "tf.Const"() {value = dense<[]> : tensor<0xi32>} + %elem_shape = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + // CHECK-NEXT: %[[BUFFER:.*]] = "tf.Identity"(%arg1) : (tensor<10xf32>) -> tensor<10xf32> + // CHECK-NEXT: %[[SIZE:.*]] = "tf.Const"() {value = dense<10> : tensor<1xi32>} : () -> tensor<1xi32> + %tl = "tf.TensorListFromTensor"(%arg1, %elem_shape) : (tensor<10xf32>, tensor<0xi32>) -> tensor>> + // CHECK-NEXT: %[[SIZE_SHAPE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} + // CHECK-NEXT: %[[GET_INDEX:.*]] = "tf.Reshape"(%[[ARG0]], %[[SIZE_SHAPE]]) : (tensor, tensor<1xi32>) -> tensor<1xi32> + // CHECK-NEXT: %[[SLICE_SIZE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK-NEXT: %[[SLICE:.*]] = "tf.Slice"(%[[BUFFER]], %[[GET_INDEX]], %[[SLICE_SIZE]]) : (tensor<10xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xf32> + // CHECK-NEXT: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + // CHECK-NEXT: %[[ELEM:.*]] = "tf.Reshape"(%[[SLICE]], %[[ELEM_SHAPE]]) : (tensor<1xf32>, tensor<0xi32>) -> tensor + %get = "tf.TensorListGetItem"(%tl, %arg0, %elem_shape) : (tensor>>, tensor, tensor<0xi32>) -> tensor + // CHECK-NEXT: return %[[ELEM]] : tensor + return %get: tensor +} + +// ----- + +// Tests while loop. 
+ +// CHECK-LABEL: func @main +func @main() -> () { + %elem_shape = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + %max_size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + // CHECK-NOT: tf.EmptyTensorList + %tl = "tf.EmptyTensorList"(%elem_shape, %max_size) : (tensor<0xi32>, tensor) -> tensor>> + %1:2 = "tf.While"(%tl, %max_size) { + body = @while_body, cond = @while_cond, device = "", is_stateless = false} + : (tensor>>, tensor) -> (tensor>>, tensor) + // CHECK: "tf.Slice" + %pop:2 = "tf.TensorListPopBack"(%1#0, %elem_shape) : (tensor>>, tensor<0xi32>) -> (tensor>>, tensor) + // CHECK-NOT: tf.EmptyTensorList + // CHECK: return + return +} +// CHECK: func @while_body(%[[BARG0:.*]]: tensor<10xf32>, %[[BARG1:.*]]: tensor, %[[BARG2:.*]]: tensor<1xi32>) +func @while_body(%arg0: tensor>>, %arg1: tensor) -> (tensor>>, tensor) { + // CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %const1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + // CHECK: %[[SUB:.*]] = "tf.Sub"(%[[BARG1]], %[[CONST1]]) + %sub = "tf.Sub"(%arg1, %const1) : (tensor, tensor) -> tensor + %elem = "tf._SomeOp"() : () -> tensor + // CHECK-NOT: "tf.TensorListPushBack" + // CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice" + // CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[BARG2]], %[[CONST1]]) + // CHECK-NOT: "tf.TensorListPushBack" + %push = "tf.TensorListPushBack"(%arg0, %elem) : (tensor>>, tensor) -> tensor>> + // CHECK: return %[[UPDATE]], %[[SUB]], %[[ADD]] + return %push, %sub : tensor>>, tensor +} +// CHECK: func @while_cond(%[[CARG0:.*]]: tensor<10xf32>, %[[CARG1:.*]]: tensor, %[[CARG2:.*]]: tensor<1xi32>) +func @while_cond(%arg0: tensor>>, %arg1: tensor) -> tensor { + // CHECK-NEXT: return %[[CARG1]] + return %arg1 : tensor +} + +// ----- + +// Tests IfOp. 
+ +// CHECK-LABEL: func @main +func @main(%arg0: tensor) -> () { + %elem_shape = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + %max_size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + // CHECK-NOT: tf.EmptyTensorList + %tl = "tf.EmptyTensorList"(%elem_shape, %max_size) : (tensor<0xi32>, tensor) -> tensor>> + %if_op = "tf.If"(%arg0, %tl) {then_branch = @if_then, else_branch = @if_else, is_stateless = false} + : (tensor, tensor>>) -> tensor>> + // CHECK: "tf.Slice" + %pop:2 = "tf.TensorListPopBack"(%if_op, %elem_shape) : (tensor>>, tensor<0xi32>) -> (tensor>>, tensor) + // CHECK-NOT: tf.TensorListPopBack + // CHECK: return + return +} +// CHECK: func @if_then(%[[TARG0:.*]]: tensor<10xf32>, %[[TARG1:.*]]: tensor<1xi32>) -> (tensor<10xf32>, tensor<1xi32>) +func @if_then(%arg0: tensor>>) -> tensor>> { + %elem = "tf._SomeOp"() : () -> tensor + // CHECK-NOT: "tf.TensorListPushBack" + // CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice" + // CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[TARG1]], %[[CONST1]]) + // CHECK-NOT: "tf.TensorListPushBack" + %push = "tf.TensorListPushBack"(%arg0, %elem) : (tensor>>, tensor) -> tensor>> + // CHECK: return %[[UPDATE]], %[[ADD]] + return %push : tensor>> +} +// CHECK: func @if_else(%[[EARG0:.*]]: tensor<10xf32>, %[[EARG1:.*]]: tensor<1xi32>) -> (tensor<10xf32>, tensor<1xi32>) +func @if_else(%arg0: tensor>>) -> tensor>> { + %elem_shape = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + // CHECK-NOT: "tf.TensorListPopBack" + // CHECK: %[[COPY:.*]] = "tf.Identity"(%[[EARG0]]) + // CHECK: %[[CONST1_1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: %[[SUB:.*]] = "tf.Sub"(%[[EARG1]], %[[CONST1_1]]) + // CHECK: %[[SLICE_SIZE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: %[[SLICE:.*]] = "tf.Slice"(%[[COPY]], %[[SUB]], %[[SLICE_SIZE]]) : (tensor<10xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xf32> + // CHECK: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + // CHECK: %[[ELEM:.*]] = "tf.Reshape"(%[[SLICE]], %[[ELEM_SHAPE]]) : (tensor<1xf32>, tensor<0xi32>) -> tensor + // CHECK-NOT: "tf.TensorListPopBack" + %pop:2 = "tf.TensorListPopBack"(%arg0, %elem_shape) : (tensor>>, tensor<0xi32>) -> (tensor>>, tensor) + // CHECK: return %[[COPY]], %[[SUB]] + return %pop#0 : tensor>> +} + +// ----- + +// Tests PartitionedCall/StatefulPartitionedCall. 
+ +// CHECK-LABEL: func @main +func @main(%arg0: tensor) -> () { + %elem_shape = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + %max_size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + // CHECK-NOT: tf.EmptyTensorList + // CHECK: %[[INIT:.*]] = "tf.BroadcastTo" + %tl = "tf.EmptyTensorList"(%elem_shape, %max_size) : (tensor<0xi32>, tensor) -> tensor>> + // CHECK: "tf.StatefulPartitionedCall"(%[[INIT]], + // CHECK-SAME: f = @callee_tensorlist_decomposed + %call = "tf.StatefulPartitionedCall"(%tl, %arg0) {f = @callee, config = "", config_proto = "", executor_type = ""} + : (tensor>>, tensor) -> tensor>> + // CHECK: %[[CALL2:.*]]:2 = "tf.PartitionedCall"(%[[INIT]], + // CHECK-SAME: f = @callee_tensorlist_decomposed + %call2 = "tf.PartitionedCall"(%tl, %arg0) {f = @callee, config = "", config_proto = "", executor_type = ""} + : (tensor>>, tensor) -> tensor>> + // CHECK: %[[COPY:.*]] = "tf.Identity"(%[[CALL2]]#0) + // CHECK: "tf.Slice"(%[[COPY]], + %pop:2 = "tf.TensorListPopBack"(%call2, %elem_shape) : (tensor>>, tensor<0xi32>) -> (tensor>>, tensor) + // CHECK-NOT: tf.TensorListPopBack + // CHECK: return + return +} + +// CHECK: func @callee(%[[AARG0:.*]]: tensor>>, %[[AARG1:.*]]: tensor) -> tensor>> +func @callee(%arg0: tensor>>, %arg1: tensor) -> tensor>> { + %elem = "tf._SomeOp"(%arg1) : (tensor) -> tensor + // CHECK: "tf.TensorListPushBack" + %push = "tf.TensorListPushBack"(%arg0, %elem) : (tensor>>, tensor) -> tensor>> + return %push : tensor>> +} + +// CHECK: func @callee_tensorlist_decomposed(%[[ARG0:.*]]: tensor<10xf32>, %[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor<1xi32>) -> (tensor<10xf32>, tensor<1xi32>) +// CHECK-NOT: "tf.TensorListPushBack" +// CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice" +// CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> +// CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[ARG2]], %[[CONST1]]) +// CHECK-NOT: "tf.TensorListPushBack" +// CHECK: return %[[UPDATE]], %[[ADD]] + +// ----- + +// Tests that the pass reports error on unknown maximum size. + +func @main(%arg0: tensor) -> () { + %elem_shape = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + // expected-error @+1 {{unknown max element count}} + %tl = "tf.EmptyTensorList"(%elem_shape, %arg0) : (tensor<0xi32>, tensor) -> tensor>> + return +} + +// ----- + +// Tests that the pass reports error on unknown element shape. + +func @main(%arg0: tensor<*xi32>) -> () { + %max_size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + // expected-error @+1 {{unknown tensor list element shape}} + %tl = "tf.EmptyTensorList"(%arg0, %max_size) : (tensor<*xi32>, tensor) -> tensor>> + return +} + +// ----- + +// Tests that the pass reports error on pushing elements to a fixed-size tensor +// list. 
+ +func @main(%arg0: tensor<*xi32>) -> () { + %elem_shape = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + %num = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + %tl = "tf.TensorListReserve"(%elem_shape, %num) : (tensor<0xi32>, tensor) -> tensor>> + %elem = "tf._SomeOp"() : () -> tensor + // expected-error @+1 {{cannot push on a fixed-size tensor list}} + %push = "tf.TensorListPushBack"(%tl, %elem) : (tensor>>, tensor) -> tensor>> + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir index 319660ae4bb..b7d1f3a7104 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir @@ -66,17 +66,17 @@ func @testIdentity(%arg0: tensor<4x2x!tf.stringref>) -> tensor<4x2x!tf.string> { // ----- // CHECK-LABEL: func @testBitcast -func @testBitcast(%arg0: tensor<3x4x!tf.uint16>) -> tensor<3x4x!tf.quint16> { - %0 = "tf.Bitcast"(%arg0) : (tensor<3x4x!tf.uint16>) -> tensor<3x4x!tf.quint16> +func @testBitcast(%arg0: tensor<3x4xui16>) -> tensor<3x4x!tf.quint16> { + %0 = "tf.Bitcast"(%arg0) : (tensor<3x4xui16>) -> tensor<3x4x!tf.quint16> return %0 : tensor<3x4x!tf.quint16> } // ----- // CHECK-LABEL: func @testReverseV2 -func @testReverseV2(%arg0: tensor<2x4x3x!tf.uint8>, %arg1: tensor<1xi32>) -> tensor<2x4x3x!tf.uint8> { - %0 = "tf.ReverseV2"(%arg0, %arg1) : (tensor<2x4x3x!tf.uint8>, tensor<1xi32>) -> tensor<2x4x3x!tf.uint8> - return %0 : tensor<2x4x3x!tf.uint8> +func @testReverseV2(%arg0: tensor<2x4x3xui8>, %arg1: tensor<1xi32>) -> tensor<2x4x3xui8> { + %0 = "tf.ReverseV2"(%arg0, %arg1) : (tensor<2x4x3xui8>, tensor<1xi32>) -> tensor<2x4x3xui8> + return %0 : tensor<2x4x3xui8> } // ----- @@ -210,9 +210,9 @@ func @testLeakyWrongAlphaType(tensor<16xf32>) -> tensor<16xf32> { // ----- // CHECK-LABEL: func @testMul -func @testMul(%arg0: tensor<2x!tf.uint16>) -> (tensor<2x!tf.uint16>) { - %0 = "tf.Mul"(%arg0, %arg0) {T = "tfdtype$DT_UINT16", device = "/device:CPU:0", name = "Mul"} : (tensor<2x!tf.uint16>, tensor<2x!tf.uint16>) -> tensor<2x!tf.uint16> - return %0 : tensor<2x!tf.uint16> +func @testMul(%arg0: tensor<2xui16>) -> (tensor<2xui16>) { + %0 = "tf.Mul"(%arg0, %arg0) {T = "tfdtype$DT_UINT16", device = "/device:CPU:0", name = "Mul"} : (tensor<2xui16>, tensor<2xui16>) -> tensor<2xui16> + return %0 : tensor<2xui16> } // ----- @@ -236,7 +236,7 @@ func @testReshape(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<1000 func @testReshape(tensor<*xf32>, tensor<*xf32>) -> (tensor<100x100xf32>) { ^bb0(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>): %shape1 = constant dense<100.> : tensor<2xf32> - // expected-error @+1 {{must be tensor of 32/64-bit integer values}} + // expected-error @+1 {{must be tensor of 32/64-bit signless integer values}} %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<*xf32>, tensor<2xf32>) -> (tensor<100x100xf32>) return %r1 : tensor<100x100xf32> } @@ -1290,7 +1290,7 @@ func @testValidShape(tensor<1x32x32x16xf32>, tensor<*xf32>) -> (tensor<4xi32>, t // ----- func @testShapeWrongResultElemType(%arg0: tensor<1x32x32x16xf32>) -> tensor<4xf32> { - // expected-error @+1 {{result #0 must be tensor of 32/64-bit integer values}} + // expected-error @+1 {{result #0 must be tensor of 32/64-bit signless integer values}} %0 = "tf.Shape"(%arg0) : (tensor<1x32x32x16xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -1334,7 +1334,7 @@ func @testValidShapeN(%arg0 : tensor<1x32x32x16xf32>, %arg1 : tensor<*xf32>) -> // ----- 
func @testShapeNWrongResultElemType(%arg0: tensor<1x32x32x16xf32>) -> tensor<4xf32> { - // expected-error @+1 {{result #1 must be tensor of 32/64-bit integer values}} + // expected-error @+1 {{result #1 must be tensor of 32/64-bit signless integer values}} %0:2 = "tf.ShapeN"(%arg0, %arg0) : (tensor<1x32x32x16xf32>, tensor<1x32x32x16xf32>) -> (tensor<4xi32>, tensor<4xf32>) return %0#1 : tensor<4xf32> } @@ -1395,7 +1395,7 @@ func @testVariableShapeMultipleSubtypes(%arg0: tensor<*x!tf.resource>>) -> tensor { - // expected-error @+1 {{result #0 must be tensor of 32/64-bit integer values}} + // expected-error @+1 {{result #0 must be tensor of 32/64-bit signless integer values}} %0 = "tf.VariableShape"(%arg0) : (tensor<*x!tf.resource>>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -1457,7 +1457,7 @@ func @testTranspose(tensor<2x3xf32>) -> tensor<3x2xf32> { // Test invalid tf.Less func @testLess(tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> { ^bb0(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>): - // expected-error @+1 {{op result #0 must be tensor of 1-bit integer values}} + // expected-error @+1 {{op result #0 must be tensor of 1-bit signless integer values}} %0 = "tf.Less"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> return %0 : tensor<4xi32> } @@ -1474,7 +1474,7 @@ func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor) -> tensor // tf.ConcatV2 with wrong 'axis' element type func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor) -> tensor { - // expected-error @+1 {{operand #2 must be tensor of 32/64-bit integer values}} + // expected-error @+1 {{operand #2 must be tensor of 32/64-bit signless integer values}} %0 = "tf.ConcatV2"(%arg, %arg, %axis) : (tensor<8x16xf32>, tensor<8x16xf32>, tensor) -> tensor return %0 : tensor } @@ -1507,7 +1507,7 @@ func @testAll64(%arg0: tensor<2x2xi1>, %arg1: tensor) -> tensor { // ----- func @testAllFloat(%arg0: tensor<2x2xi1>, %arg1: tensor) -> tensor { - // expected-error @+1 {{'tf.All' op operand #1 must be tensor of 32/64-bit integer values}} + // expected-error @+1 {{'tf.All' op operand #1 must be tensor of 32/64-bit signless integer values}} %0 = "tf.All"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor) -> tensor return %0 : tensor } @@ -1515,7 +1515,7 @@ func @testAllFloat(%arg0: tensor<2x2xi1>, %arg1: tensor) -> tensor { // ----- func @testAllI32(%arg0: tensor<2x2xi32>, %arg1: tensor) -> tensor { - // expected-error @+1 {{'tf.All' op operand #0 must be tensor of 1-bit integer values}} + // expected-error @+1 {{'tf.All' op operand #0 must be tensor of 1-bit signless integer values}} %0 = "tf.All"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi32>, tensor) -> tensor return %0 : tensor } @@ -2449,3 +2449,17 @@ func @testParseExampleV2RaggedMismatchedOutputLengths(%serialized: tensor<32x!tf %result:3 = "tf.ParseExampleV2"(%serialized, %names, %empty_str_vector, %empty_str_vector, %ragged_keys) {dense_shapes = [], num_sparse = 0 : i64, result_segment_sizes = dense<[0, 0, 0, 0, 2, 1]> : vector<6xi32>} : (tensor<32x!tf.string>, tensor<32x!tf.string>, tensor<0x!tf.string>, tensor<0x!tf.string>, tensor<2x!tf.string>) -> (tensor, tensor, tensor) return %result#0 : tensor } + +// ----- + +func @testBatchMatMulV2(%lhs: tensor, %rhs: tensor<10x10xf32>) { + // expected-error @+1 {{requires lhs operand to have rank at least two}} + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor, tensor<10x10xf32>) -> tensor<10x10xf32> +} + +// ----- + +func @testBatchMatMulV2(%lhs: tensor<10x10xf32>, %rhs: tensor) { + // expected-error @+1 
{{requires rhs operand to have rank at least two}} + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x10xf32>, tensor) -> tensor<10x10xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_device_ops_invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_device_ops_invalid.mlir index b1e2bc36900..07863e3e806 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_device_ops_invalid.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_device_ops_invalid.mlir @@ -94,7 +94,7 @@ func @verifier_replicate_terminator() { // Check that a replicate with 'n' attribute that is less than 2 is invalid. func @verifier_replicate_n() { "tf_device.replicate" () ({ -// expected-error@-1 {{'tf_device.replicate' op attribute 'n' failed to satisfy constraint: 32-bit integer attribute whose minimum value is 2}} +// expected-error@-1 {{'tf_device.replicate' op attribute 'n' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 2}} ^entry: tf_device.return }) {n = 1 : i32} : () -> () diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir index d0aa1414723..b9ec020ff59 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir @@ -348,7 +348,7 @@ func @invalid_switch(%arg0: tensor<*xf32>) { func @invalid_switch(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*xf32> { %result = tf_executor.graph { %true, %false, %ctlSwitch = "tf_executor.Switch"(%arg0, %arg0) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>, !tf_executor.control) -// expected-error@-1 {{'tf_executor.Switch' op operand #1 must be tensor of 1-bit integer values}} +// expected-error@-1 {{'tf_executor.Switch' op operand #1 must be tensor of 1-bit signless integer values}} tf_executor.fetch %true : tensor<*xf32> } return %result : tensor<*xf32> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir new file mode 100644 index 00000000000..ce44a562aca --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir @@ -0,0 +1,88 @@ +// RUN: tf-opt -verify-diagnostics -tf-saved-model-freeze-global-tensors -split-input-file %s | FileCheck %s --dump-input=fail + +module attributes {tf_saved_model.semantics} { + + // Test case: Basic freezing. + + // CHECK-NOT: tf_saved_model.global_tensor + "tf_saved_model.global_tensor"() {sym_name = "v", type = tensor, value = dense<1.0> : tensor } : () -> () + + // CHECK: func @f() + func @f(%arg0: tensor>> {tf_saved_model.bound_input = @v}) + attributes {tf_saved_model.exported_names = ["f"]} { + %val = "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor + // CHECK: "tf.Const"() {value = dense<1.000000e+00> : tensor} + return + } +} + +// ----- + + +module attributes {tf_saved_model.semantics} { + + // Test case: Sanity check handling of non-bound inputs. + // The pass shouldn't do anything in this case. 
+ + // CHECK: func @f(%arg0: tensor>> {tf_saved_model.index_path = [0]}) + func @f(%arg0: tensor>> {tf_saved_model.index_path = [0]}) + attributes {tf_saved_model.exported_names = ["f"]} { + %val = "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor + // CHECK: "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor + return + } +} + +// ----- + +module attributes {tf_saved_model.semantics} { + + // Test case: Fail if mutable global tensors are found. + + // expected-error @+1 {{is not immutable}} + "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor, value = dense<1.0> : tensor } : () -> () + + func @f(%arg0: tensor>> {tf_saved_model.bound_input = @v}) + attributes {tf_saved_model.exported_names = ["f"]} { + return + } + +} + +// ----- + +module attributes {tf_saved_model.semantics} { + + // Test case: Fail if bound input user is not ReadVariableOp + + "tf_saved_model.global_tensor"() {sym_name = "v", type = tensor, value = dense<1.0> : tensor } : () -> () + + func @f(%arg0: tensor>> {tf_saved_model.bound_input = @v}) + attributes {tf_saved_model.exported_names = ["f"]} { + // expected-error @+1 {{could not rewrite use of immutable bound input}} + "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee} : (tensor>>) -> () + return + } + + func @f_callee(%arg0: tensor>>) { + return + } +} + +// ----- + +// expected-error @+1 {{could not freeze all global tensors in the module}} +module attributes {tf_saved_model.semantics} { + + // Test case: Fail if some global tensor ops remain + + "tf_saved_model.global_tensor"() {sym_name = "v", type = tensor, value = dense<1.0> : tensor } : () -> () + "tf_saved_model.global_tensor"() {sym_name = "v2", type = tensor, value = dense<1.0> : tensor } : () -> () + + func @f(%arg0: tensor>> {tf_saved_model.bound_input = @v}) + attributes {tf_saved_model.exported_names = ["f"]} { + %val = "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor + return + } +} + diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu-merge-variables-with-execute.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu-merge-variables-with-execute.mlir index 9e54ff43933..d5cd9004132 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu-merge-variables-with-execute.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu-merge-variables-with-execute.mlir @@ -12,37 +12,30 @@ func @merge_same_device_variables( %arg1: tensor<*x!tf.resource>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg2: tensor<*x!tf.resource>> {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg3: tensor) { - tf_executor.graph { - // CHECK: tf_executor.island - %island = tf_executor.island { - // CHECK-NEXT: %[[ID_0:.*]] = "tf.IdentityN"(%[[ARG_0]]) - %id0 = "tf.IdentityN"(%arg0) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} - : (tensor<*x!tf.resource>>) -> tensor<*x!tf.resource>> - // CHECK-NEXT: %[[READ_2:.*]] = "tf.ReadVariableOp"(%[[ARG_2]]) - %read0 = "tf.ReadVariableOp"(%id0) : (tensor<*x!tf.resource>>) -> tensor<32xf32> - %read1 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource>>) -> tensor<64xf32> - %read2 = "tf.ReadVariableOp"(%arg2) : (tensor<*x!tf.resource>>) -> tensor<16xf32> - // CHECK-NEXT: %[[EXE:.*]] = "tf_device.launch" - // CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[ID_0]], %[[ARG_1]], %[[READ_2]], %[[ARG_3]]) - // CHECK-SAME: device_var_reads_indices = [0, 1], - // CHECK-SAME: device_var_updates_indices = [0, -1] - %execute:2 = "tf_device.launch"() ( { - %0:2 = 
"tf.TPUExecute"(%read0, %read1, %read2, %arg3) { - Targs = [tensor<32xf32>, tensor<64xf32>, tensor<16xf32>], - Tresults = [tensor<32xf32>, tensor<16xf32>]} - : (tensor<32xf32>, tensor<64xf32>, tensor<16xf32>, tensor) -> (tensor<32xf32>, tensor<16xf32>) - tf_device.return %0#0, %0#1 : tensor<32xf32>, tensor<16xf32> - }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> (tensor<32xf32>, tensor<16xf32>) - // CHECK-NEXT: tf_device.return - // CHECK-NEXT: }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} - "tf.AssignVariableOp"(%id0, %execute#0) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () - // CHECK-NEXT: "tf.AssignVariableOp"(%[[ARG_2]], %[[EXE]]) - "tf.AssignVariableOp"(%arg2, %execute#1) : (tensor<*x!tf.resource>>, tensor<16xf32>) -> () - // CHECK-NEXT: tf_executor.yield - tf_executor.yield - } - tf_executor.fetch %island : !tf_executor.control - } + // CHECK-NEXT: %[[ID_0:.*]] = "tf.IdentityN"(%[[ARG_0]]) + %id0 = "tf.IdentityN"(%arg0) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} + : (tensor<*x!tf.resource>>) -> tensor<*x!tf.resource>> + // CHECK-NEXT: %[[READ_2:.*]] = "tf.ReadVariableOp"(%[[ARG_2]]) + %read0 = "tf.ReadVariableOp"(%id0) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + %read1 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource>>) -> tensor<64xf32> + %read2 = "tf.ReadVariableOp"(%arg2) : (tensor<*x!tf.resource>>) -> tensor<16xf32> + // CHECK-NEXT: %[[EXE:.*]] = "tf_device.launch" + // CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[ID_0]], %[[ARG_1]], %[[READ_2]], %[[ARG_3]]) + // CHECK-SAME: device_var_reads_indices = [0, 1], + // CHECK-SAME: device_var_updates_indices = [0, -1] + %execute:2 = "tf_device.launch"() ( { + %0:2 = "tf.TPUExecute"(%read0, %read1, %read2, %arg3) { + Targs = [tensor<32xf32>, tensor<64xf32>, tensor<16xf32>], + Tresults = [tensor<32xf32>, tensor<16xf32>]} + : (tensor<32xf32>, tensor<64xf32>, tensor<16xf32>, tensor) -> (tensor<32xf32>, tensor<16xf32>) + tf_device.return %0#0, %0#1 : tensor<32xf32>, tensor<16xf32> + }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> (tensor<32xf32>, tensor<16xf32>) + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} + "tf.AssignVariableOp"(%id0, %execute#0) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + // CHECK-NEXT: "tf.AssignVariableOp"(%[[ARG_2]], %[[EXE]]) + "tf.AssignVariableOp"(%arg2, %execute#1) : (tensor<*x!tf.resource>>, tensor<16xf32>) -> () + // CHECK-NEXT: return return } @@ -59,35 +52,28 @@ func @merge_replicated_variables( %arg1: tensor, %arg2: tensor<*x!tf.resource>>, %arg3: tensor<*x!tf.resource>>) { - tf_executor.graph { - // CHECK: tf_executor.island - %island = tf_executor.island { - // CHECK-NEXT: %[[READ_0:.*]] = "tf.ReadVariableOp"(%[[ARG_0]]) - %read0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>>) -> tensor<32xf32> - // CHECK-NEXT: tf_device.replicate([%[[ARG_2]], %[[ARG_3]]] as %[[R_ARG:.*]]: tensor<*x!tf.resource>>) - tf_device.replicate([%arg2, %arg3] as %r: tensor<*x!tf.resource>>) {n = 2 : i32} { - // CHECK-NEXT: "tf_device.launch" - // CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[READ_0]], %[[R_ARG]], %[[ARG_1]]) - // CHECK-SAME: device_var_reads_indices = [1], - // CHECK-SAME: device_var_updates_indices = [0] - %read1 = "tf.ReadVariableOp"(%r) : (tensor<*x!tf.resource>>) -> tensor<32xf32> - %execute = "tf_device.launch"() ( { - %0 = "tf.TPUExecute"(%read0, %read1, %arg1) - : (tensor<32xf32>, tensor<32xf32>, tensor) -> tensor<32xf32> - 
tf_device.return %0 : tensor<32xf32> - }) {device = ""} : () -> tensor<32xf32> - // CHECK-NEXT: tf_device.return - // CHECK-NEXT: }) {device = ""} - "tf.AssignVariableOp"(%r, %execute) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () - // CHECK-NEXT: tf_device.return - tf_device.return - // CHECK-NEXT: } - } - // CHECK-NEXT: tf_executor.yield - tf_executor.yield - } - tf_executor.fetch %island : !tf_executor.control + // CHECK-NEXT: %[[READ_0:.*]] = "tf.ReadVariableOp"(%[[ARG_0]]) + %read0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + // CHECK-NEXT: tf_device.replicate([%[[ARG_2]], %[[ARG_3]]] as %[[R_ARG:.*]]: tensor<*x!tf.resource>>) + tf_device.replicate([%arg2, %arg3] as %r: tensor<*x!tf.resource>>) {n = 2 : i32} { + // CHECK-NEXT: "tf_device.launch" + // CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[READ_0]], %[[R_ARG]], %[[ARG_1]]) + // CHECK-SAME: device_var_reads_indices = [1], + // CHECK-SAME: device_var_updates_indices = [0] + %read1 = "tf.ReadVariableOp"(%r) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + %execute = "tf_device.launch"() ( { + %0 = "tf.TPUExecute"(%read0, %read1, %arg1) + : (tensor<32xf32>, tensor<32xf32>, tensor) -> tensor<32xf32> + tf_device.return %0 : tensor<32xf32> + }) {device = ""} : () -> tensor<32xf32> + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: }) {device = ""} + "tf.AssignVariableOp"(%r, %execute) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + // CHECK-NEXT: tf_device.return + tf_device.return + // CHECK-NEXT: } } + // CHECK-NEXT: return return } @@ -112,46 +98,39 @@ func @interferencing_accesses( %arg4: tensor<*x!tf.resource>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg5: tensor<*x!tf.resource>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg6: tensor<2xf32>) -> (tensor<8xf32>) { - %graph = tf_executor.graph { - // CHECK: tf_executor.island - %island:2 = tf_executor.island { - // CHECK-NEXT: %[[READ_0:.*]] = "tf.ReadVariableOp"(%[[ARG_0]]) - %read0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>>) -> tensor<32xf32> - // CHECK-NEXT: %[[READ_5:.*]] = "tf.ReadVariableOp"(%[[ARG_5]]) - %read5 = "tf.ReadVariableOp"(%arg5) : (tensor<*x!tf.resource>>) -> tensor<2xf32> - // CHECK-NEXT: "tf.AssignVariableOp"(%[[ARG_0]], %[[ARG_2]]) - "tf.AssignVariableOp"(%arg0, %arg2) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () - // CHECK-NEXT: "tf.AssignVariableOp"(%[[ARG_5]], %[[ARG_6]]) - "tf.AssignVariableOp"(%arg5, %arg6) : (tensor<*x!tf.resource>>, tensor<2xf32>) -> () - %read1 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource>>) -> tensor<64xf32> - %read2 = "tf.ReadVariableOp"(%arg4) : (tensor<*x!tf.resource>>) -> tensor<8xf32> - // CHECK-NEXT: %[[EXE:.*]]:2 = "tf_device.launch" - // CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[READ_0]], %[[ARG_1]], %[[ARG_4]], %[[READ_5]], %[[ARG_3]]) - // CHECK-SAME: device_var_reads_indices = [1, 2], - // CHECK-SAME: device_var_updates_indices = [1, -1] - %execute:3 = "tf_device.launch"() ( { - %0:3 = "tf.TPUExecute"(%read0, %read1, %read2, %read5, %arg3) { - Targs = [tensor<32xf32>, tensor<64xf32>, tensor<8xf32>, tensor<2xf32>], - Tresults = [tensor<32xf32>, tensor<64xf32>, tensor<8xf32>]} - : (tensor<32xf32>, tensor<64xf32>, tensor<8xf32>, tensor<2xf32>, tensor) - -> (tensor<32xf32>, tensor<64xf32>, tensor<8xf32>) - tf_device.return %0#0, %0#1, %0#2 : tensor<32xf32>, tensor<64xf32>, tensor<8xf32> - }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> (tensor<32xf32>, tensor<64xf32>, tensor<8xf32>) - // 
CHECK-NEXT: tf_device.return - // CHECK-NEXT: }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} - "tf.AssignVariableOp"(%arg1, %execute#1) : (tensor<*x!tf.resource>>, tensor<64xf32>) -> () - // CHECK-NEXT: "tf.AssignVariableOp"(%[[ARG_0]], %[[EXE]]#0) - "tf.AssignVariableOp"(%arg0, %execute#0) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () - // CHECK-NEXT: %[[READ_3:.*]] = "tf.ReadVariableOp"(%[[ARG_4]]) - %read3 = "tf.ReadVariableOp"(%arg4) : (tensor<*x!tf.resource>>) -> tensor<8xf32> - // CHECK-NEXT: "tf.AssignVariableOp"(%[[ARG_4]], %[[EXE]]#1) - "tf.AssignVariableOp"(%arg4, %execute#2) : (tensor<*x!tf.resource>>, tensor<8xf32>) -> () - // CHECK-NEXT: tf_executor.yield %[[READ_3]] - tf_executor.yield %read3 : tensor<8xf32> - } - tf_executor.fetch %island#0 : tensor<8xf32> - } - return %graph : tensor<8xf32> + // CHECK-NEXT: %[[READ_0:.*]] = "tf.ReadVariableOp"(%[[ARG_0]]) + %read0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + // CHECK-NEXT: %[[READ_5:.*]] = "tf.ReadVariableOp"(%[[ARG_5]]) + %read5 = "tf.ReadVariableOp"(%arg5) : (tensor<*x!tf.resource>>) -> tensor<2xf32> + // CHECK-NEXT: "tf.AssignVariableOp"(%[[ARG_0]], %[[ARG_2]]) + "tf.AssignVariableOp"(%arg0, %arg2) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + // CHECK-NEXT: "tf.AssignVariableOp"(%[[ARG_5]], %[[ARG_6]]) + "tf.AssignVariableOp"(%arg5, %arg6) : (tensor<*x!tf.resource>>, tensor<2xf32>) -> () + %read1 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource>>) -> tensor<64xf32> + %read2 = "tf.ReadVariableOp"(%arg4) : (tensor<*x!tf.resource>>) -> tensor<8xf32> + // CHECK-NEXT: %[[EXE:.*]]:2 = "tf_device.launch" + // CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[READ_0]], %[[ARG_1]], %[[ARG_4]], %[[READ_5]], %[[ARG_3]]) + // CHECK-SAME: device_var_reads_indices = [1, 2], + // CHECK-SAME: device_var_updates_indices = [1, -1] + %execute:3 = "tf_device.launch"() ( { + %0:3 = "tf.TPUExecute"(%read0, %read1, %read2, %read5, %arg3) { + Targs = [tensor<32xf32>, tensor<64xf32>, tensor<8xf32>, tensor<2xf32>], + Tresults = [tensor<32xf32>, tensor<64xf32>, tensor<8xf32>]} + : (tensor<32xf32>, tensor<64xf32>, tensor<8xf32>, tensor<2xf32>, tensor) + -> (tensor<32xf32>, tensor<64xf32>, tensor<8xf32>) + tf_device.return %0#0, %0#1, %0#2 : tensor<32xf32>, tensor<64xf32>, tensor<8xf32> + }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> (tensor<32xf32>, tensor<64xf32>, tensor<8xf32>) + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} + "tf.AssignVariableOp"(%arg1, %execute#1) : (tensor<*x!tf.resource>>, tensor<64xf32>) -> () + // CHECK-NEXT: "tf.AssignVariableOp"(%[[ARG_0]], %[[EXE]]#0) + "tf.AssignVariableOp"(%arg0, %execute#0) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + // CHECK-NEXT: %[[READ_3:.*]] = "tf.ReadVariableOp"(%[[ARG_4]]) + %read3 = "tf.ReadVariableOp"(%arg4) : (tensor<*x!tf.resource>>) -> tensor<8xf32> + // CHECK-NEXT: "tf.AssignVariableOp"(%[[ARG_4]], %[[EXE]]#1) + "tf.AssignVariableOp"(%arg4, %execute#2) : (tensor<*x!tf.resource>>, tensor<8xf32>) -> () + // CHECK-NEXT: return %[[READ_3]] + return %read3 : tensor<8xf32> } // ----- @@ -165,30 +144,23 @@ func @interferencing_accesses( func @do_not_merge_multi_read( %arg0: tensor<*x!tf.resource>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg1: tensor) { - tf_executor.graph { - // CHECK: tf_executor.island - %island = tf_executor.island { - // CHECK-NEXT: %[[READ_0:.*]] = "tf.ReadVariableOp"(%[[ARG_0]]) - %read0 = 
"tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>>) -> tensor<32xf32> - // CHECK-NEXT: %[[READ_1:.*]] = "tf.ReadVariableOp"(%[[ARG_0]]) - %read1 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>>) -> tensor<32xf32> - // CHECK-NEXT: %[[EXE:.*]] = "tf_device.launch" - // CHECK-NEXT: "tf.TPUExecute"(%[[READ_0]], %[[READ_1]], %[[ARG_1]]) - %execute = "tf_device.launch"() ( { - %0 = "tf.TPUExecute"(%read0, %read1, %arg1) { - Targs = [tensor<32xf32>, tensor<32xf32>], Tresults = [tensor<32xf32>]} - : (tensor<32xf32>, tensor<32xf32>, tensor) -> (tensor<32xf32>) - tf_device.return %0 : tensor<32xf32> - }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> tensor<32xf32> - // CHECK-NEXT: tf_device.return - // CHECK-NEXT: }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} - // CHECK-NEXT: "tf.AssignVariableOp"(%[[ARG_0]], %[[EXE]]) - "tf.AssignVariableOp"(%arg0, %execute) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () - // CHECK-NEXT: tf_executor.yield - tf_executor.yield - } - tf_executor.fetch %island : !tf_executor.control - } + // CHECK-NEXT: %[[READ_0:.*]] = "tf.ReadVariableOp"(%[[ARG_0]]) + %read0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + // CHECK-NEXT: %[[READ_1:.*]] = "tf.ReadVariableOp"(%[[ARG_0]]) + %read1 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + // CHECK-NEXT: %[[EXE:.*]] = "tf_device.launch" + // CHECK-NEXT: "tf.TPUExecute"(%[[READ_0]], %[[READ_1]], %[[ARG_1]]) + %execute = "tf_device.launch"() ( { + %0 = "tf.TPUExecute"(%read0, %read1, %arg1) { + Targs = [tensor<32xf32>, tensor<32xf32>], Tresults = [tensor<32xf32>]} + : (tensor<32xf32>, tensor<32xf32>, tensor) -> (tensor<32xf32>) + tf_device.return %0 : tensor<32xf32> + }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> tensor<32xf32> + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} + // CHECK-NEXT: "tf.AssignVariableOp"(%[[ARG_0]], %[[EXE]]) + "tf.AssignVariableOp"(%arg0, %execute) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + // CHECK-NEXT: return return } @@ -203,29 +175,118 @@ func @do_not_merge_multi_read( func @do_not_merge_multi_assign( %arg0: tensor<*x!tf.resource>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg1: tensor) { - tf_executor.graph { - // CHECK: tf_executor.island - %island = tf_executor.island { - // CHECK-NEXT: %[[READ_0:.*]] = "tf.ReadVariableOp"(%[[ARG_0]]) - %read0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>>) -> tensor<32xf32> - // CHECK-NEXT: %[[EXE:.*]]:2 = "tf_device.launch" - // CHECK-NEXT: "tf.TPUExecute"(%[[READ_0]], %[[ARG_1]]) - %execute:2 = "tf_device.launch"() ( { - %0:2 = "tf.TPUExecute"(%read0, %arg1) { - Targs = [tensor<32xf32>], Tresults = [tensor<32xf32>, tensor<32xf32>]} - : (tensor<32xf32>, tensor) -> (tensor<32xf32>, tensor<32xf32>) - tf_device.return %0#0, %0#1 : tensor<32xf32>, tensor<32xf32> - }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> (tensor<32xf32>, tensor<32xf32>) + // CHECK-NEXT: %[[READ_0:.*]] = "tf.ReadVariableOp"(%[[ARG_0]]) + %read0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + // CHECK-NEXT: %[[EXE:.*]]:2 = "tf_device.launch" + // CHECK-NEXT: "tf.TPUExecute"(%[[READ_0]], %[[ARG_1]]) + %execute:2 = "tf_device.launch"() ( { + %0:2 = "tf.TPUExecute"(%read0, %arg1) { + Targs = [tensor<32xf32>], Tresults = [tensor<32xf32>, tensor<32xf32>]} + : (tensor<32xf32>, tensor) -> (tensor<32xf32>, tensor<32xf32>) + 
tf_device.return %0#0, %0#1 : tensor<32xf32>, tensor<32xf32> + }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> (tensor<32xf32>, tensor<32xf32>) + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} + // CHECK-NEXT: "tf.AssignVariableOp"(%[[ARG_0]], %[[EXE]]#0) + "tf.AssignVariableOp"(%arg0, %execute#0) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + // CHECK-NEXT: "tf.AssignVariableOp"(%[[ARG_0]], %[[EXE]]#1) + "tf.AssignVariableOp"(%arg0, %execute#1) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + // CHECK-NEXT: return + return +} + +// ----- + +// Tests that the pass merges only variable reads/writes on the same device, +// with TPUExecutes in a tf_device.parallel_execute. + +// CHECK-LABEL: func @parallel_execute +// CHECK-SAME: %[[ARG_0:.*]]: tensor<*x!tf.resource>> +// CHECK-SAME: %[[ARG_1:.*]]: tensor<*x!tf.resource>> +// CHECK-SAME: %[[ARG_2:.*]]: tensor +func @parallel_execute( + %arg0: tensor<*x!tf.resource>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, + %arg1: tensor<*x!tf.resource>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:1"}, + %arg2: tensor) { + %read0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + %read1 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource>>) -> tensor<64xf32> + // CHECK-NOT: "tf.ReadVariableOp" + // CHECK: "tf_device.parallel_execute" + %pe:2 = "tf_device.parallel_execute"() ( { + // CHECK: "tf_device.launch" + %execute0 = "tf_device.launch"() ( { + // CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[ARG_0]], %[[ARG_2]]) + %0 = "tf.TPUExecute"(%read0, %arg2) : (tensor<32xf32>, tensor) -> tensor<32xf32> // CHECK-NEXT: tf_device.return - // CHECK-NEXT: }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} - // CHECK-NEXT: "tf.AssignVariableOp"(%[[ARG_0]], %[[EXE]]#0) - "tf.AssignVariableOp"(%arg0, %execute#0) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () - // CHECK-NEXT: "tf.AssignVariableOp"(%[[ARG_0]], %[[EXE]]#1) - "tf.AssignVariableOp"(%arg0, %execute#1) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () - // CHECK-NEXT: tf_executor.yield - tf_executor.yield - } - tf_executor.fetch %island : !tf_executor.control + tf_device.return %0 : tensor<32xf32> + // CHECK-NEXT: device = "/job:localhost/replica:0/task:0/device:TPU:0" + }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> tensor<32xf32> + tf_device.return %execute0 : tensor<32xf32> + }, { + // CHECK: "tf_device.launch" + %execute1 = "tf_device.launch"() ( { + // CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[ARG_1]], %[[ARG_2]]) + %1 = "tf.TPUExecute"(%read1, %arg2) : (tensor<64xf32>, tensor) -> tensor<64xf32> + // CHECK-NEXT: tf_device.return + tf_device.return %1 : tensor<64xf32> + // CHECK-NEXT: device = "/job:localhost/replica:0/task:0/device:TPU:1" + }) {device = "/job:localhost/replica:0/task:0/device:TPU:1"} : () -> tensor<64xf32> + tf_device.return %execute1 : tensor<64xf32> + }) : () -> (tensor<32xf32>, tensor<64xf32>) + // CHECK-NOT: "tf.AssignVariableOp" + "tf.AssignVariableOp"(%arg0, %pe#0) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + "tf.AssignVariableOp"(%arg1, %pe#1) : (tensor<*x!tf.resource>>, tensor<64xf32>) -> () + return +} + +// ----- + +// Tests that the pass merges variable reads/writes for TPUExecutes in a +// tf_device.parallel_execute that is replicated (tf_device.replicate). 
+ +// CHECK-LABEL: func @replicated_parallel_execute +// CHECK-SAME: %[[ARG_0:[a-z0-9]+]]: tensor<*x!tf.resource>> +// CHECK-SAME: %[[ARG_1:[a-z0-9]+]]: tensor<*x!tf.resource>> +// CHECK-SAME: %[[ARG_2:[a-z0-9]+]]: tensor<*x!tf.resource>> +// CHECK-SAME: %[[ARG_3:[a-z0-9]+]]: tensor<*x!tf.resource>> +// CHECK-SAME: %[[ARG_4:[a-z0-9]+]]: tensor +func @replicated_parallel_execute( + %arg0: tensor<*x!tf.resource>>, + %arg1: tensor<*x!tf.resource>>, + %arg2: tensor<*x!tf.resource>>, + %arg3: tensor<*x!tf.resource>>, + %arg4: tensor) { + // CHECK: tf_device.replicate + // CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]+]]: tensor<*x!tf.resource>> + // CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]+]]: tensor<*x!tf.resource>> + tf_device.replicate([%arg0, %arg1] as %ri0: tensor<*x!tf.resource>>, + [%arg2, %arg3] as %ri1: tensor<*x!tf.resource>>) {n = 2 : i32} { + // CHECK-NOT: "tf.ReadVariableOp" + %read0 = "tf.ReadVariableOp"(%ri0) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + %read1 = "tf.ReadVariableOp"(%ri1) : (tensor<*x!tf.resource>>) -> tensor<64xf32> + // CHECK: "tf_device.parallel_execute" + %pe:2 = "tf_device.parallel_execute"() ( { + // CHECK: "tf_device.launch" + %execute0 = "tf_device.launch"() ( { + // CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[RI_0]], %[[ARG_4]]) + %0 = "tf.TPUExecute"(%read0, %arg4) : (tensor<32xf32>, tensor) -> tensor<32xf32> + // CHECK-NEXT: tf_device.return + tf_device.return %0 : tensor<32xf32> + }) {device = ""} : () -> tensor<32xf32> + tf_device.return %execute0 : tensor<32xf32> + }, { + // CHECK: "tf_device.launch" + %execute1 = "tf_device.launch"() ( { + // CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[RI_1]], %[[ARG_4]]) + %1 = "tf.TPUExecute"(%read1, %arg4) : (tensor<64xf32>, tensor) -> tensor<64xf32> + // CHECK-NEXT: tf_device.return + tf_device.return %1 : tensor<64xf32> + }) {device = ""} : () -> tensor<64xf32> + tf_device.return %execute1 : tensor<64xf32> + }) : () -> (tensor<32xf32>, tensor<64xf32>) + // CHECK-NOT: "tf.AssignVariableOp" + "tf.AssignVariableOp"(%ri0, %pe#0) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + "tf.AssignVariableOp"(%ri1, %pe#1) : (tensor<*x!tf.resource>>, tensor<64xf32>) -> () } return } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir index d9107185954..7ee20d23df3 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir @@ -5,7 +5,7 @@ // expected-error@+1 {{requires attribute 'tf.versions'}} module attributes {tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @missing_tf_versions() { - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = []} : () -> () + "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -20,7 +20,7 @@ module attributes {tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", 
"/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_devices() { // expected-error@+1 {{error in fetching TPU compilation/execution devices: no TPU_SYSTEM devices found}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = []} : () -> () + "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -36,7 +36,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @missing_num_cores_per_replica() { // expected-error@+1 {{requires attribute 'num_cores_per_replica'}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = []} : () -> () + "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -51,7 +51,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_num_cores_per_replica() { // expected-error@+1 {{requires attribute 'num_cores_per_replica'}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = "", step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = []} : () -> () + "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = "", step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -66,7 +66,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_num_cores_per_replica() { // expected-error@+1 {{requires attribute 'step_marker_location'}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, padding_map = []} : () -> () + "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -81,7 +81,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = 
{producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_step_marker_location() { // expected-error@+1 {{requires attribute 'step_marker_location'}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = 1, padding_map = []} : () -> () + "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = 1, padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -96,7 +96,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @unparsable_step_marker_location() { // expected-error@+1 {{bad 'step_marker_location' attribute with value 'test'}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "test", padding_map = []} : () -> () + "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "test", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -111,7 +111,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @missing_padding_map() { // expected-error@+1 {{requires attribute 'padding_map'}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP"} : () -> () + "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -126,7 +126,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_padding_map() { // expected-error@+1 {{requires attribute 'padding_map'}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ""} : () -> () + "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = "", topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () 
return } func @empty_func() { @@ -141,7 +141,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_element_padding_map() { // expected-error@+1 {{bad 'padding_map' attribute at index 0, not a string}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [1]} : () -> () + "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [1], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -155,8 +155,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @unparsable_element_padding_map() { - // expected-error@+1 {{bad 'padding_map' attribute at index 0 with value 'test'}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["test"]} : () -> () + // expected-error@+1 {{bad 'padding_map' attribute at index 0 with value 'test': failed to parse to tpu::PaddingMap}} + "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["test"], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -166,12 +166,193 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- +// Tests `tf_device.launch_func` with missing `topology` attribute. + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { + func @missing_topology() { + // expected-error@+1 {{requires attribute 'topology'}} + "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + return + } + func @empty_func() { + return + } +} + +// ----- + +// Tests `tf_device.launch_func` with bad `topology` attribute. 
+ +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { + func @bad_topology() { + // expected-error@+1 {{requires attribute 'topology'}} + "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = 1 : i32, device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + return + } + func @empty_func() { + return + } +} + +// ----- + +// Tests `tf_device.launch_func` with `topology` attribute resulting in device assignment error. + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { + func @invalid_topology() { + // expected-error@+1 {{error in fetching TPU compilation/execution devices}} + "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "test", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + return + } + func @empty_func() { + return + } +} + +// ----- + +// Tests `tf_device.launch_func` with missing `device_assignment` attribute. + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { + func @missing_device_assignment() { + // expected-error@+1 {{requires attribute 'device_assignment'}} + "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + return + } + func @empty_func() { + return + } +} + +// ----- + +// Tests `tf_device.launch_func` with bad `device_assignment` attribute. + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { + func @bad_device_assignment() { + // expected-error@+1 {{requires attribute 'device_assignment'}} + "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = "", input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + return + } + func @empty_func() { + return + } +} + +// ----- + +// Tests `tf_device.launch_func` with bad element in `device_assignment` attribute. 
+ +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { + func @bad_element_device_assignment() { + // expected-error@+1 {{bad 'device_assignment' attribute at index 0, not an int}} + "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [""], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + return + } + func @empty_func() { + return + } +} + +// ----- + +// The following topology is used in subsequent test cases: +// Proto debug string: +// mesh_shape: 1 +// mesh_shape: 1 +// mesh_shape: 1 +// mesh_shape: 2 +// num_tasks: 1 +// num_tpu_devices_per_task: 2 +// device_coordinates: 0 +// device_coordinates: 0 +// device_coordinates: 0 +// device_coordinates: 0 +// device_coordinates: 0 +// device_coordinates: 0 +// device_coordinates: 0 +// device_coordinates: 1 +// Serialized string: +// "\0A\04\01\01\01\02\10\01\18\02\22\06\00\00\00\00\00\00\00\01" + +// ----- + +// Tests `tf_device.launch_func` with `device_assignment` attribute resulting in device assignment error. + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { + func @invalid_device_assignment() { + // expected-error@+1 {{error in fetching TPU compilation/execution devices}} + "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "\0A\03\01\01\02\10\01\18\02\22\06\00\00\00\00\00\01", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + return + } + func @empty_func() { + return + } +} + +// ----- + +// Tests `tf_device.launch_func` with missing `input_sharding_configuration` attribute. + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { + func @missing_input_sharding_configuration(%arg0: tensor) { + // expected-error@+1 {{requires attribute 'input_sharding_configuration'}} + %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_ENTRY", padding_map = [], topology = "", device_assignment = [], output_sharding_configuration = []} : (tensor) -> tensor + return + } + func @empty_func(%arg0: tensor) -> tensor { + return %arg0 : tensor + } +} + +// ----- + +// The following op sharding is used in subsequent test cases: +// Proto debug string: +// type: MAXIMAL +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 0 +// Serialized string: +// "\08\01\1A\01\01\22\01\00" + +// ----- + +// Tests `tf_device.launch_func` with bad `input_sharding_configuration` attribute. 
+ +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { + func @bad_input_sharding_configuration(%arg0: tensor) { + // expected-error@+1 {{requires attribute 'input_sharding_configuration'}} + %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = "", output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + return + } + func @empty_func(%arg0: tensor) -> tensor { + return %arg0 : tensor + } +} + +// ----- + +// Tests `tf_device.launch_func` with mismatched `input_sharding_configuration` attribute size. + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { + func @mismatched_size_input_sharding_configuration(%arg0: tensor) { + // expected-error@+1 {{bad 'input_sharding_configuration' attribute, expected array attribute of size 1, got size 0}} + %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + return + } + func @empty_func(%arg0: tensor) -> tensor { + return %arg0 : tensor + } +} + +// ----- + // Tests `tf_device.launch_func` with unsupported operand type. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @unsupported_operand_type(%arg0: tensor) { // expected-error@+1 {{failed to determine operand type at index 0: Converting i2 to DataType}} - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_ENTRY", padding_map = []} : (tensor) -> tensor + %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_ENTRY", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor return } func @empty_func(%arg0: tensor) -> tensor { @@ -181,6 +362,112 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- +// Tests `tf_device.launch_func` with bad element in `input_sharding_configuration` attribute. 
+ +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { + func @bad_element_input_sharding_configuration(%arg0: tensor) { + // expected-error@+1 {{bad 'input_sharding_configuration' attribute at index 0, not a string}} + %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [1], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + return + } + func @empty_func(%arg0: tensor) -> tensor { + return %arg0 : tensor + } +} + +// ----- + +// Tests `tf_device.launch_func` with unparsable element in `input_sharding_configuration` attribute. + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { + func @unparsable_element_input_sharding_configuration(%arg0: tensor) { + // expected-error@+1 {{bad 'input_sharding_configuration' attribute at index 0 with value 'test': failed to parse to xla::OpSharding}} + %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["test"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + return + } + func @empty_func(%arg0: tensor) -> tensor { + return %arg0 : tensor + } +} + +// ----- + +// Tests `tf_device.launch_func` with missing `output_sharding_configuration` attribute. + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { + func @missing_output_sharding_configuration(%arg0: tensor) { + // expected-error@+1 {{requires attribute 'output_sharding_configuration'}} + %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_ENTRY", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + return + } + func @empty_func(%arg0: tensor) -> tensor { + return %arg0 : tensor + } +} + +// ----- + +// Tests `tf_device.launch_func` with bad `output_sharding_configuration` attribute. 
+ +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { + func @bad_output_sharding_configuration(%arg0: tensor) { + // expected-error@+1 {{requires attribute 'output_sharding_configuration'}} + %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ""} : (tensor) -> tensor + return + } + func @empty_func(%arg0: tensor) -> tensor { + return %arg0 : tensor + } +} + +// ----- + +// Tests `tf_device.launch_func` with mismatched `output_sharding_configuration` attribute size. + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { + func @mismatched_size_output_sharding_configuration(%arg0: tensor) { + // expected-error@+1 {{bad 'output_sharding_configuration' attribute, expected array attribute of size 1, got size 0}} + %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = []} : (tensor) -> tensor + return + } + func @empty_func(%arg0: tensor) -> tensor { + return %arg0 : tensor + } +} + +// ----- + + +// Tests `tf_device.launch_func` with bad element in `output_sharding_configuration` attribute. + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { + func @bad_element_output_sharding_configuration(%arg0: tensor) { + // expected-error@+1 {{bad 'output_sharding_configuration' attribute at index 0, not a string}} + %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = [1]} : (tensor) -> tensor + return + } + func @empty_func(%arg0: tensor) -> tensor { + return %arg0 : tensor + } +} + +// ----- + +// Tests `tf_device.launch_func` with unparsable element in `output_sharding_configuration` attribute. 
+ +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { + func @unparsable_element_output_sharding_configuration(%arg0: tensor) { + // expected-error@+1 {{bad 'output_sharding_configuration' attribute at index 0 with value 'test': failed to parse to xla::OpSharding}} + %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["test"]} : (tensor) -> tensor + return + } + func @empty_func(%arg0: tensor) -> tensor { + return %arg0 : tensor + } +} + +// ----- + // Tests `tf_device.launch_func` with empty `step_marker_location` attribute // defaults to `STEP_MARK_AT_ENTRY`. // @@ -191,7 +478,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { // CHECK-LABEL: func @default_step_marker_location func @default_step_marker_location() { - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = []} : () -> () + "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () // CHECK: metadata // CHECK-SAME: num_replicas: 1 // CHECK-SAME: num_cores_per_replica: 1 @@ -210,7 +497,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { // CHECK-LABEL: func @unranked_shape_arg func @unranked_shape_arg(%arg0: tensor<*xi32>) -> tensor<*xi32> { - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = []} : (tensor<*xi32>) -> tensor<*xi32> + %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<*xi32>) -> tensor<*xi32> // CHECK: metadata // CHECK-SAME: shape {\0A unknown_rank: true @@ -228,7 +515,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { // CHECK-LABEL: func @partial_shape_arg func @partial_shape_arg(%arg0: tensor) -> tensor { - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @_func, 
num_cores_per_replica = 1, step_marker_location = "", padding_map = []} : (tensor) -> tensor + %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: metadata // CHECK-SAME: args // CHECK-SAME: shape {\0A dim {\0A size: -1\0A }\0A dim {\0A size: -1\0A }\0A dim {\0A size: 3\0A }\0A } @@ -259,7 +546,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { // CHECK-LABEL: func @static_shape_arg func @static_shape_arg(%arg0: tensor<1x2x3xi32>) -> tensor<1x2x3xi32> { - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = []} : (tensor<1x2x3xi32>) -> tensor<1x2x3xi32> + %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<1x2x3xi32>) -> tensor<1x2x3xi32> // CHECK: metadata // CHECK-SAME: args // CHECK-SAME: shape @@ -284,7 +571,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { // CHECK-LABEL: func @resource_arg func @resource_arg(%arg0: tensor<*x!tf.resource>) -> tensor<*x!tf.resource> { - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = []} : (tensor<*x!tf.resource>) -> tensor<*x!tf.resource> + %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<*x!tf.resource>) -> tensor<*x!tf.resource> // CHECK: metadata // CHECK: dtype: DT_RESOURCE // CHECK-SAME: kind: VARIABLE @@ -303,7 +590,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { // CHECK-LABEL: func @parameter_arg func @parameter_arg(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = []} : (tensor<*xf32>) -> tensor<*xf32> + %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = 
[], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<*xf32>) -> tensor<*xf32> // CHECK: metadata // CHECK: dtype: DT_FLOAT // CHECK-SAME: kind: PARAMETER @@ -363,7 +650,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { // CHECK-LABEL: func @metadata func @metadata(%arg0: tensor<8xi32>) -> tensor<8xi32> { - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor<8xi32>) -> tensor<8xi32> + %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>) -> tensor<8xi32> // CHECK: metadata // CHECK-SAME: args // CHECK-SAME: dtype: DT_INT32 @@ -407,9 +694,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NOT: "tf.Shape"(%[[ARG_3]]) // CHECK: %[[ARG_0_SHAPE:[0-9]*]] = "tf.Shape"(%[[ARG_0]]) // CHECK: %[[ARG_2_SHAPE:[0-9]*]] = "tf.Shape"(%[[ARG_2]]) - %0 = "tf_device.launch_func"(%arg0, %arg1, %arg2, %arg3) {_tpu_replicate = "cluster0", device = "", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = []} : (tensor<*xi32>, tensor<8xi32>, tensor<*xi32>, tensor<8xi32>) -> tensor<8xi32> + %0 = "tf_device.launch_func"(%arg0, %arg1, %arg2, %arg3) {_tpu_replicate = "cluster0", device = "", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<*xi32>, tensor<8xi32>, tensor<*xi32>, tensor<8xi32>) -> tensor<8xi32> // CHECK: "tf._TPUCompileMlir"(%[[ARG_0_SHAPE]], %[[ARG_2_SHAPE]]) - // CHECK-SAME: NumDynamicShapes = 2 return %0: tensor<8xi32> } @@ -429,11 +715,10 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %0 = "tf.A"(%arg0) : (tensor) -> tensor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor) -> tensor + %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = 
"tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) - // CHECK-SAME: NumDynamicShapes = 1 // CHECK-SAME: metadata // CHECK-SAME: mlir_module // CHECK-SAME: func @main @@ -479,7 +764,6 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[RI_0]]) // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) - // CHECK-SAME: NumDynamicShapes = 1 // CHECK-SAME: metadata // CHECK-SAME: mlir_module // CHECK-SAME: func @main @@ -491,7 +775,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0" // CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch" // CHECK-NEXT: "tf.TPUExecute"(%[[RI_0]], %[[COMPILE_OUTPUT]]#1) - %2 = "tf_device.launch_func"(%ri_0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor) -> tensor + %2 = "tf_device.launch_func"(%ri_0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: tf_device.return %[[EXECUTE_OUTPUT]] tf_device.return %2 : tensor @@ -519,7 +803,7 @@ module attributes {tf.versions = {producer = 888 : i32}} { func @single_gpu_launch_func(%arg0: tensor) -> tensor { %0 = "tf.A"(%arg0) : (tensor) -> tensor - %1 = "tf_device.launch_func"(%0) {device = "gpu0", func = @gpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor) -> tensor + %1 = "tf_device.launch_func"(%0) {device = "gpu0", func = @gpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: tf_device.launch_func // CHECK-SAME: device = "gpu0" // CHECK-SAME: func = @gpu0_func @@ -547,11 +831,10 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %0 = "tf.A"(%arg0) : (tensor) -> tensor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor) -> tensor + %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) - // CHECK-SAME: NumDynamicShapes = 1 // CHECK-SAME: metadata // 
CHECK-SAME: mlir_module // CHECK-SAME: func @main @@ -597,11 +880,10 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %0 = "tf.A"(%arg0) : (tensor) -> tensor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor) -> tensor + %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) - // CHECK-SAME: NumDynamicShapes = 1 // CHECK-SAME: metadata // CHECK-SAME: mlir_module // CHECK-SAME: func @main @@ -643,11 +925,10 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %0 = "tf.A"(%arg0) : (tensor) -> tensor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor) -> tensor + %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) - // CHECK-SAME: NumDynamicShapes = 1 // CHECK-SAME: metadata // CHECK-SAME: mlir_module // CHECK-SAME: func @main @@ -697,11 +978,10 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %0 = "tf.A"(%arg0) : (tensor) -> tensor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor) -> tensor + %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) - // CHECK-SAME: NumDynamicShapes = 1 // CHECK-SAME: metadata // CHECK-SAME: mlir_module // CHECK-SAME: func @main @@ -745,11 +1025,10 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %0 = 
"tf.A"(%arg0) : (tensor) -> tensor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func0, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor) -> tensor + %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func0, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) // CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) - // CHECK-SAME: NumDynamicShapes = 1 // CHECK-SAME: metadata // CHECK-SAME: mlir_module // CHECK-SAME: func @main @@ -760,11 +1039,10 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[EXECUTE0_OUTPUT:[0-9]*]] = "tf_device.launch" // CHECK-NEXT: "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE0_OUTPUT]]#1) - %2 = "tf_device.launch_func"(%1) {_tpu_replicate = "cluster1", device = "", func = @tpu0_func1, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor) -> tensor + %2 = "tf_device.launch_func"(%1) {_tpu_replicate = "cluster1", device = "", func = @tpu0_func1, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: %[[EXECUTE0_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[EXECUTE0_OUTPUT]]) // CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[EXECUTE0_SHAPE_OUTPUT]]) - // CHECK-SAME: NumDynamicShapes = 1 // CHECK-SAME: metadata // CHECK-SAME: mlir_module // CHECK-SAME: func @main @@ -803,11 +1081,10 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %0 = "tf.A"(%arg0) : (tensor) -> tensor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor) -> tensor + %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) // CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) - // CHECK-SAME: NumDynamicShapes = 1 // CHECK-SAME: metadata // CHECK-SAME: mlir_module // CHECK-SAME: func @main @@ -818,11 +1095,10 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[EXECUTE0_OUTPUT:[0-9]*]] = "tf_device.launch" // CHECK-NEXT: 
"tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE0_OUTPUT]]#1) - %2 = "tf_device.launch_func"(%1) {_tpu_replicate = "cluster1", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor) -> tensor + %2 = "tf_device.launch_func"(%1) {_tpu_replicate = "cluster1", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: %[[EXECUTE0_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[EXECUTE0_OUTPUT]]) // CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[EXECUTE0_SHAPE_OUTPUT]]) - // CHECK-SAME: NumDynamicShapes = 1 // CHECK-SAME: metadata // CHECK-SAME: mlir_module // CHECK-SAME: func @main @@ -857,11 +1133,10 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %0 = "tf.A"(%arg0) : (tensor) -> tensor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor) -> tensor + %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) - // CHECK-SAME: NumDynamicShapes = 1 // CHECK-SAME: metadata // CHECK-SAME: mlir_module // CHECK-SAME: func @main @@ -916,7 +1191,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests that TPUCompilationResult operations are properly rewritten +// Tests that TPUCompilationResult operations are properly rewritten. 
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { // CHECK-LABEL: func @tpu_compilation_result @@ -928,7 +1203,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: "tf.TPUCompileSucceededAssert" // CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch" // CHECK-NEXT: "tf.TPUExecute" - %1 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = []} : (tensor) -> tensor + %1 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor %compile_result = "tf.TPUCompilationResult"() {_tpu_replicate = "cluster0"} : () -> tensor %compile_result2 = "tf.TPUCompilationResult"() {_tpu_replicate = "cluster0"} : () -> tensor @@ -944,3 +1219,179 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor return %0 : tensor } } + +// ----- + +// Tests devices are set properly for non replicated model parallelism. + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"]} { + // CHECK-LABEL: func @non_replicated_parallel_execute + func @non_replicated_parallel_execute(%arg0: tensor<8xi32>) -> tensor<8xi32> { + // CHECK: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch" + // CHECK-NEXT: "tf._TPUCompileMlir"() + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "/job:localhost/replica:0/task:0/device:CPU:0" + // CHECK: "tf_device.launch" + // CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0) + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "/job:localhost/replica:0/task:0/device:CPU:0" + // CHECK: "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK-NEXT: "tf.TPUExecute" + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "/job:localhost/replica:0/task:0/device:TPU:0" + // CHECK: "tf_device.launch" + // CHECK-NEXT: "tf.TPUExecute" + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "/job:localhost/replica:0/task:0/device:TPU:1" + %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>) -> tensor<8xi32> + return %0 : tensor<8xi32> + } + func @tpu0_func(%arg0: tensor<8xi32>) -> tensor<8xi32> { + return %arg0 : tensor<8xi32> + } +} + +// ----- + +// The following topology is used in subsequent test cases: +// Proto debug string: +// mesh_shape: 1 +// mesh_shape: 2 +// mesh_shape: 1 +// mesh_shape: 2 +// num_tasks: 2 +// num_tpu_devices_per_task: 2 +// device_coordinates: 0 +// device_coordinates: 
0 +// device_coordinates: 0 +// device_coordinates: 0 +// device_coordinates: 0 +// device_coordinates: 0 +// device_coordinates: 0 +// device_coordinates: 1 +// device_coordinates: 0 +// device_coordinates: 1 +// device_coordinates: 0 +// device_coordinates: 0 +// device_coordinates: 0 +// device_coordinates: 1 +// device_coordinates: 0 +// device_coordinates: 1 +// Serialized string: +// "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01" +// ----- + +// Tests devices are set properly for replicated model parallelism. + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} { + // CHECK-LABEL: func @replicated_parallel_execute + func @replicated_parallel_execute(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> (tensor<8xi32>, tensor<8xi32>) { + // CHECK: tf_device.replicate + // CHECK-SAME: devices = + // CHECK-SAME: TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"] + // CHECK-SAME: TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"] + %0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<8xi32>) {n = 2 : i32} { + // CHECK-NEXT: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch" + // CHECK-NEXT: "tf._TPUCompileMlir"() + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "/job:localhost/replica:0/task:0/device:CPU:0" + // CHECK: "tf_device.launch" + // CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0) + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "/job:localhost/replica:0/task:0/device:CPU:0" + // CHECK: "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK-NEXT: "tf.TPUExecute" + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "TPU_REPLICATED_CORE_0" + // CHECK: "tf_device.launch" + // CHECK-NEXT: "tf.TPUExecute" + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "TPU_REPLICATED_CORE_1" + %1 = "tf_device.launch_func"(%ri) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>) -> tensor<8xi32> + tf_device.return %1 : tensor<8xi32> + } + return %0#0, %0#1 : tensor<8xi32>, tensor<8xi32> + } + func @tpu0_func(%arg0: tensor<8xi32>) -> tensor<8xi32> { + return %arg0 : tensor<8xi32> + } +} + +// ----- + +// Tests devices are set properly for replicated model parallelism with +// outputs to TPU computation placed on logical device 0. 
+ +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} { + // CHECK-LABEL: func @parallel_execute_with_different_outputs + func @parallel_execute_with_different_outputs(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> (tensor<8xi32>, tensor<8xi32>) { + // CHECK: tf_device.replicate + // CHECK-SAME: devices = + // CHECK-SAME: TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"] + // CHECK-SAME: TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"] + %0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<8xi32>) {n = 2 : i32} { + // CHECK-NEXT: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch" + // CHECK-NEXT: "tf._TPUCompileMlir"() + // CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0" + // CHECK: "tf_device.launch" + // CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0) + // CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0" + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute" + // CHECK-NEXT: tf_device.return %[[EXECUTE_OUTPUT]] + // CHECK-NEXT: device = "TPU_REPLICATED_CORE_0" + // CHECK: "tf_device.launch" + // CHECK-NEXT: "tf.TPUExecute" + // CHECK: device = "TPU_REPLICATED_CORE_1" + %1 = "tf_device.launch_func"(%ri) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>) -> tensor<8xi32> + tf_device.return %1 : tensor<8xi32> + } + return %0#0, %0#1 : tensor<8xi32>, tensor<8xi32> + } + func @tpu0_func(%arg0: tensor<8xi32>) -> tensor<8xi32> { + return %arg0 : tensor<8xi32> + } +} + +// ----- + +// Tests devices are set properly for replicated model parallelism with +// TPU computation with maximal and replicated outputs. 
+ +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} { + // CHECK-LABEL: func @parallel_execute_with_replicated_output + func @parallel_execute_with_replicated_output(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> (tensor<*xi32>, tensor<*xi1>) { + // CHECK: tf_device.replicate + // CHECK-SAME: devices = + // CHECK-SAME: TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"] + // CHECK-SAME: TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"] + %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<8xi32>) {n = 2 : i32} { + // CHECK-NEXT: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch" + // CHECK-NEXT: "tf._TPUCompileMlir"() + // CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0" + // CHECK: "tf_device.launch" + // CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0) + // CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0" + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:3 = "tf_device.parallel_execute" + // CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute" + // CHECK-NEXT: tf_device.return %[[EXECUTE_0_OUTPUT]] + // CHECK-NEXT: device = "TPU_REPLICATED_CORE_0" + // CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute" + // CHECK-NEXT: tf_device.return %[[EXECUTE_1_OUTPUT]] + // CHECK: device = "TPU_REPLICATED_CORE_1" + %1, %2 = "tf_device.launch_func"(%ri) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<8xi32>) -> (tensor<*xi32>, tensor<*xi1>) + tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> + } + return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> + } + func @tpu0_func(%arg0: tensor<8xi32>) -> (tensor<*xi32>, tensor<*xi1>) { + %1, %2 = "tf.A"(%arg0) : (tensor<8xi32>) -> (tensor<*xi32>, tensor<*xi1>) + %3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1> + return %1, %3 : tensor<*xi32>, tensor<*xi1> + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir index 7b0c82aaf6a..87eb02eda94 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir @@ -16,6 +16,24 @@ func @empty_func() { // ----- +// Tests with a block argument inputs/outputs with no xla sharding op attached +// gets default maximal(0) sharding configuration. 
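+// Note (added for clarity; assumes the standard xla.OpSharding encoding): the +// expected string "\08\01\1A\01\01\22\01\00" below is the serialized proto for a +// MAXIMAL sharding assigned to logical device 0, i.e. the default maximal(0) +// configuration mentioned above.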
+// CHECK-LABEL: func @check_default_sharding_for_block_arg_inputs_outputs +func @check_default_sharding_for_block_arg_inputs_outputs(%arg0: tensor<*xi32>) { + "tf_device.launch_func"(%arg0) {device = "", func = @func_without_sharding, step_marker_location = ""} : (tensor<*xi32>) -> () + // CHECK: input_sharding_configuration + // CHECK-SAME: ["\08\01\1A\01\01\22\01\00"] + // CHECK: output_sharding_configuration + // CHECK-SAME: ["\08\01\1A\01\01\22\01\00"] + return +} + +func @func_without_sharding(%arg0: tensor<*xi32>) -> tensor<*xi32> { + return %arg0 : tensor<*xi32> +} + +// ----- + // Tests with a inputs/outputs with no xla sharding op attached gets // default maximal(0) sharding configuration. // CHECK-LABEL: func @check_default_sharding_for_inputs_outputs diff --git a/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir b/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir index 4a27e74ad70..5a3f0b6e997 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir @@ -234,26 +234,6 @@ func @batchMatMulMatrix(%arg0: tensor<4x5xf32>, %arg1: tensor<5x6xf32>) -> tenso // ----- -func @batchMatMulV2VectorLhsInputMatchFailure(%arg0: tensor<10xf32>, %arg1: tensor<10x20xf32>) -> tensor<10x20xf32> { - %0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<10xf32>, tensor<10x20xf32>) -> tensor<10x20xf32> - return %0 : tensor<10x20xf32> - - // CHECK-LABEL: batchMatMulV2VectorLhs - // CHECK: %0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<10xf32>, tensor<10x20xf32>) -> tensor<10x20xf32> -} - -// ----- - -func @batchMatMulV2VectorRhsInputMatchFailure(%arg0: tensor<10x20xf32>, %arg1: tensor<10xf32>) -> tensor<10x20xf32> { - %0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<10x20xf32>, tensor<10xf32>) -> tensor<10x20xf32> - return %0 : tensor<10x20xf32> - - // CHECK-LABEL: batchMatMulV2VectorRhs - // CHECK: %0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<10x20xf32>, tensor<10xf32>) -> tensor<10x20xf32> -} - -// ----- - func @batchMatMulVectorLhsInputMatchFailure(%arg0: tensor<10xf32>, %arg1: tensor<10x20xf32>) -> tensor<10x20xf32> { %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<10xf32>, tensor<10x20xf32>) -> tensor<10x20xf32> return %0 : tensor<10x20xf32> diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc new file mode 100644 index 00000000000..6cd82d1472d --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc @@ -0,0 +1,125 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.h" + +#include +#include +#include + +#include "absl/memory/memory.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/OpImplementation.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Support/Functional.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/core/util/matmul_bcast.h" + +namespace mlir { +namespace TF { + +namespace { +// Replace TF BatchMatMul by TF Einsum +struct BatchMatMulToEinsumPass : public FunctionPass { + void runOnFunction() override; +}; + +void BatchMatMulToEinsumPass::runOnFunction() { + OwningRewritePatternList patterns; + auto func = getFunction(); + + patterns.insert, + ConvertTFBatchMatMulToEinsumOp>( + &getContext()); + applyPatternsGreedily(func, patterns); +} + +} // namespace + +template +PatternMatchResult +ConvertTFBatchMatMulToEinsumOp::matchAndRewrite( + BatchMatMulOpType op, PatternRewriter& rewriter) const { + Value input_lhs = op.x(); + Value input_rhs = op.y(); + + if (!input_lhs.getType().isa()) { + // LHS must be a ranked tensor type + return this->matchFailure(); + } + if (!input_rhs.getType().isa()) { + // RHS must be a ranked tensor type + return this->matchFailure(); + } + + auto lhs_type = input_lhs.getType().dyn_cast(); + auto rhs_type = input_rhs.getType().dyn_cast(); + + if (!lhs_type || !rhs_type) { + return this->matchFailure(); + } + + auto lhs_shape = lhs_type.getShape(); + auto rhs_shape = rhs_type.getShape(); + + Location loc = op.getLoc(); + + // Ensure that input ranks are at least 2. + const int dims_a = lhs_shape.size(); + const int dims_b = rhs_shape.size(); + if (dims_a < 2 || dims_b < 2) { + // Both inputs must have rank >= 2 + return this->matchFailure(); + } + + // einsum equation for batchmatmul + std::string equation("...mk,...kn->...mn"); + + if (op.adj_x()) { + std::swap(equation[3], equation[4]); + } + if (op.adj_y()) { + std::swap(equation[6 + 3], equation[6 + 4]); + } + + llvm::SmallVector inputs = {input_lhs, input_rhs}; + rewriter.replaceOpWithNewOp(op, op.getType(), + /*inputs=*/ValueRange(inputs), + /*equation=*/equation); + + return this->matchSuccess(); +} + +static PassRegistration pass( + "tf-batch-matmul-to-tf-einsum", + "Replace TF BatchMatMul op by TF Einsum op."); + +std::unique_ptr> CreateBatchMatMulToEinsumPass() { + return std::make_unique(); +} + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.h b/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.h new file mode 100644 index 00000000000..cd836892ae9 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.h @@ -0,0 +1,43 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_BATCHMATMUL_TO_EINSUM_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_BATCHMATMUL_TO_EINSUM_H_ + +#include "llvm/ADT/ArrayRef.h" +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/IR/TypeUtilities.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/core/util/matmul_bcast.h" + +namespace mlir { +namespace TF { + +// Replace TF BatchMatMul by TF Einsum op +template +class ConvertTFBatchMatMulToEinsumOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite( + BatchMatMulOpType op, + PatternRewriter& rewriter) const override; // NOLINT +}; + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_BATCHMATMUL_TO_EINSUM_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc index 992284b320b..73110a724ea 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc @@ -43,6 +43,8 @@ void AddGraphExportLoweringPasses(OpPassManager &pm) { pm.addNestedPass(CreateBreakUpIslandsPass()); pm.addNestedPass(TFDevice::CreateReplicateToIslandPass()); pm.addNestedPass(CreateBreakUpIslandsPass()); + pm.addNestedPass(TFDevice::CreateParallelExecuteToIslandsPass()); + pm.addNestedPass(CreateBreakUpIslandsPass()); pm.addNestedPass(TFDevice::CreateLaunchToDeviceAttributePass()); } @@ -71,6 +73,8 @@ void CreateTPUBridgePipeline(OpPassManager &pm) { // Run island coarsening before shape inference to allow more exact shape // inference using constant folding within islands. pm.addNestedPass(tf_executor::CreateTFExecutorIslandCoarseningPass()); + // TODO(b/150462212): Move graph pruning before island coarsening. + pm.addNestedPass(tf_executor::CreateTFExecutorGraphPruningPass()); // Run shape inference so that tf_executor/tf_device ops created later will // likely to inherit more concrete types. 
pm.addPass(TF::CreateTFShapeInferencePass()); @@ -90,6 +94,7 @@ void CreateTPUBridgePipeline(OpPassManager &pm) { pm.addPass(TF::CreateResourceDeviceInferencePass()); pm.addPass(TFDevice::CreateClusterOutliningPass()); pm.addPass(CreateTPUDynamicPaddingMapperPass()); + pm.addPass(CreateTPUShardingIdentificationPass()); pm.addPass(TFDevice::CreateAnnotateParameterReplicationPass()); pm.addPass(CreateTPURewritePass()); pm.addNestedPass(TFDevice::CreateReplicateInvariantOpHoistingPass()); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td index 7c4030ed3f4..05d7a22261a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td @@ -27,6 +27,10 @@ def SingleResultAndOperandHaveSameType : Constraint< def IsRank2Tensor : Type, "Rank 2 tensor">; +// Checks if the value has only one user. +def HasOneUse : Constraint>; + + //===----------------------------------------------------------------------===// // Add op patterns. //===----------------------------------------------------------------------===// @@ -175,3 +179,11 @@ def TruncateDivWithSqrtDivisor : Pat<(TF_TruncateDivOp $arg0, def XdivyWithSqrtDivisor : Pat<(TF_XdivyOp $arg0, (TF_SqrtOp $arg1)), (TF_MulNoNanOp (TF_RsqrtOp $arg1), $arg0)>; + + +//===----------------------------------------------------------------------===// +// Cast op followed by a ReadVariable op can be folded into the ReadVariable +//===----------------------------------------------------------------------===// + +def ReadVariableOfCast : Pat<(TF_ReadVariableOp (TF_CastOp:$output $x, BoolAttr:$Truncate)), (TF_ReadVariableOp $x), [(HasOneUse $output)]>; + diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc b/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc new file mode 100644 index 00000000000..71426b04d99 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc @@ -0,0 +1,167 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace mlir { +namespace TF { +namespace collection_ops_util { + +Value CreateScalarConst(int value, OpBuilder builder, Location loc) { + tensorflow::Tensor scalar_tensor(tensorflow::DT_INT32, {}); + scalar_tensor.scalar()() = value; + return builder.create( + loc, tensorflow::ConvertTensor(scalar_tensor, &builder).ValueOrDie()); +} + +Value GetR1Const(ArrayRef r1, OpBuilder builder, Location loc) { + tensorflow::Tensor shape_tensor(tensorflow::DT_INT32, + {static_cast(r1.size())}); + for (int i = 0; i < r1.size(); ++i) { + shape_tensor.vec()(i) = r1[i]; + } + return builder.create( + loc, tensorflow::ConvertTensor(shape_tensor, &builder).ValueOrDie()); +} + +Value GetIndicesForElement(Value index, Value buffer, OpBuilder builder, + Location loc) { + auto buffer_type = buffer.getType().cast(); + if (buffer_type.getShape().size() == 1) return index; + // Create a concat of index and trailing zeros. + llvm::SmallVector zeros(buffer_type.getShape().size() - 1, 0); + auto zeros_tensor = GetR1Const(zeros, builder, loc); + return builder.create( + loc, + ArrayRef{RankedTensorType::get( + {static_cast(buffer_type.getShape().size())}, + getElementTypeOrSelf(index.getType()))}, + ArrayRef{index, zeros_tensor, CreateScalarConst(0, builder, loc)}, + ArrayRef{}); +} + +Value GetElement(Value index, Value buffer, OpBuilder builder, Location loc) { + auto buffer_type = buffer.getType().cast(); + // Create a slice then reshape to remove the leading trivial dimension of + // size 1. 
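+ // For example (illustrative note, not from the original source): for a buffer + // of type tensor<5x3x7xf32> and index i, this emits a tf.Slice of size + // [1, 3, 7] at offsets [i, 0, 0] followed by a tf.Reshape to tensor<3x7xf32>.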
+ llvm::SmallVector slice_size = + llvm::to_vector<8>(buffer_type.getShape()); + slice_size[0] = 1; + auto size_const = GetR1Const(slice_size, builder, loc); + auto slice_type = + RankedTensorType::get(slice_size, buffer_type.getElementType()); + auto slice = builder.create( + loc, ArrayRef{slice_type}, + ArrayRef{buffer, GetIndicesForElement(index, buffer, builder, loc), + size_const}, + ArrayRef{}); + auto element_type = RankedTensorType::get(buffer_type.getShape().drop_front(), + buffer_type.getElementType()); + auto reshape = builder.create( + loc, ArrayRef{element_type}, + ArrayRef{slice, GetR1Const(element_type.getShape(), builder, loc)}, + ArrayRef{}); + return reshape.output(); +} + +Value SetElement(Value index, Value buffer, Value element, OpBuilder builder, + Location loc) { + auto buffer_type = buffer.getType().cast(); + // Reshape the element to add a leading dimension of size 1, then perform a + // dynamic update slice. + auto slice_shape = llvm::to_vector<8>(buffer_type.getShape()); + slice_shape[0] = 1; + auto update_slice = builder.create( + loc, + ArrayRef{ + RankedTensorType::get(slice_shape, buffer_type.getElementType())}, + ArrayRef{element, GetR1Const(slice_shape, builder, loc)}, + ArrayRef{}); + return builder + .create( + loc, ArrayRef{buffer.getType()}, + ArrayRef{buffer, update_slice, + GetIndicesForElement(index, buffer, builder, loc)}, + ArrayRef{}) + .output(); +} + +TensorType GetSizeType(OpBuilder builder) { + return RankedTensorType::get({1}, builder.getIntegerType(32)); +} + +Value ReshapeScalarToSizeType(OpBuilder builder, Value scalar, Location loc) { + auto size_type = GetSizeType(builder); + return builder.create( + loc, ArrayRef{size_type}, + ArrayRef{scalar, GetR1Const(size_type.getShape(), builder, loc)}, + ArrayRef{}); +} + +LogicalResult CreateInitBufferValue(ArrayRef element_shape, + Value max_size, Operation* op, + Type element_dtype, OpBuilder builder, + Value* buffer) { + auto max_count_op = max_size.getDefiningOp(); + if (!max_count_op) return op->emitOpError("unknown max element count"); + auto max_count_const_op = llvm::dyn_cast(max_count_op); + if (!max_count_const_op) return op->emitOpError("unknown max element count"); + int64_t max_size_const = + (*max_count_const_op.value().getValues().begin()).getSExtValue(); + llvm::SmallVector buffer_shape; + buffer_shape.push_back(max_size_const); + for (int64_t dim : element_shape) { + buffer_shape.push_back(dim); + } + auto zero = CreateScalarConst(0, builder, op->getLoc()); + if (getElementTypeOrSelf(zero.getType()) != element_dtype) { + zero = builder.create( + op->getLoc(), ArrayRef{RankedTensorType::get({}, element_dtype)}, + ArrayRef{zero}, ArrayRef{}); + } + auto buffer_type = RankedTensorType::get(buffer_shape, element_dtype); + auto broadcast = builder.create( + op->getLoc(), ArrayRef{buffer_type}, + ArrayRef{zero, GetR1Const(buffer_shape, builder, op->getLoc())}, + ArrayRef{}); + *buffer = broadcast.output(); + return success(); +} +} // namespace collection_ops_util +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h b/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h new file mode 100644 index 00000000000..6b86cafed3f --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h @@ -0,0 +1,72 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_COLLECTION_OPS_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_COLLECTION_OPS_UTIL_H_ + +#include "llvm/ADT/ArrayRef.h" +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/Value.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" + +namespace mlir { +namespace TF { +namespace collection_ops_util { + +// This file includes utilities for decomposing collection ops (stack, tensor +// list, tensor array) in TF. We represent such a data structure as a buffer of +// shape [max_element_count, element_shape]. + +// Creates an i32 scalar tf.Const. +Value CreateScalarConst(int value, OpBuilder builder, Location loc); + +// Creates an i32 vector tf.Const. +Value GetR1Const(ArrayRef r1, OpBuilder builder, Location loc); + +// Returns the type of the size tensor used to track a data structure's element +// count. It is a tensor<1xi32>, and we use R1 instead of a scalar because it is +// easier to concat it with other offsets. +TensorType GetSizeType(OpBuilder builder); + +// Reshapes a scalar value to match the size type tensor. +Value ReshapeScalarToSizeType(OpBuilder builder, Value scalar, Location loc); + +// Creates ops that represent the indices of the slice for an element in the +// buffer. Requires `index` to have tensor<1xi32> type. +Value GetIndicesForElement(Value index, Value buffer, OpBuilder builder, + Location loc); + +// Creates ops that slice the element out of a buffer at the given index. +// Requires `index` to have tensor<1xi32> type. +Value GetElement(Value index, Value buffer, OpBuilder builder, Location loc); + +// Creates ops that copy the buffer and update an element at the given index. +// Requires `index` to have tensor<1xi32> type. +Value SetElement(Value index, Value buffer, Value element, OpBuilder builder, + Location loc); + +// Creates the buffer for the data structure with given element shape, type and +// maximum size. +LogicalResult CreateInitBufferValue(ArrayRef element_shape, + Value max_size, Operation* op, + Type element_dtype, OpBuilder builder, + Value* buffer); +} // namespace collection_ops_util +} // namespace TF +} // namespace mlir +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_COLLECTION_OPS_UTIL_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc index 7b46c6aec04..c1a87c289bf 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include +#include "mlir/Interfaces/SideEffects.h" // TF:llvm-project #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -31,7 +32,8 @@ LogicalResult ConstantFoldFallbackHook( SmallVectorImpl& results) { // NOLINT // Instructions with side effects should not be constant folded to preserve // the original semantics. - if (!inst->hasNoSideEffect()) return failure(); + if (inst->getNumRegions() != 0 || !MemoryEffectOpInterface::hasNoEffect(inst)) + return failure(); // If any of the result types are variants, don't try to constant fold them. // This creates opaque variant constants which lose information and would diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc new file mode 100644 index 00000000000..5410ce4faf7 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc @@ -0,0 +1,296 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h" + +#include +#include + +#include "absl/memory/memory.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Regex.h" +#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/OpImplementation.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Support/Functional.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/core/util/matmul_bcast.h" + +namespace mlir { +namespace TF { + +namespace { + +// All supported Einsum equations. +enum EinsumEquation { + BatchMatMul, + FourDMatrixDotProd, + ThreeDReshapeTail, + FourDBatchMatMul, + UnsupportedEquation +}; + +// Tokens for parsing the given equation string. +enum EquationToken { + A, + B, + C, + D, + E, + COMMA, + ARROW, +}; +constexpr int kNumSupportedEquationVariables = 5; // A - E for now. 
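+// For example (added for clarity): the equation "BFND,NDH->BFH" tokenizes to +// {A, B, C, D, COMMA, C, D, E, ARROW, A, B, E}, since labels are assigned to +// equation variables in order of first appearance.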
+ +bool tokenizeEquation(const llvm::StringRef& equation, + std::vector* tokens) { + std::map label_axis_mapping; + int index = 0; + int variable_count = 0; + llvm::Regex r("[[:alpha:]]"); + while (index < equation.size()) { + if (r.match(equation.substr(index, 1))) { + const char ltr = equation[index]; + auto itr = label_axis_mapping.find(ltr); + if (itr == label_axis_mapping.end() && + variable_count < kNumSupportedEquationVariables) { + label_axis_mapping[ltr] = EquationToken(variable_count); + tokens->push_back(EquationToken(variable_count)); + variable_count++; + } else if (itr != label_axis_mapping.end()) { + tokens->push_back(itr->second); + } else { + // Ran out of equation variables. + return false; + } + } else if (equation.substr(index, 1).contains(",")) { + tokens->push_back(COMMA); + } else if ((index < (equation.size() - 1)) && + (equation.substr(index, 2).contains("->"))) { + tokens->push_back(ARROW); + index++; + } else { + // Unallowed character encountered. + return false; + } + index++; + } + return true; +} + +EinsumEquation parseEquation(const std::vector& eqn) { + auto is_equal = [](const std::vector& eqn1, + const std::initializer_list& eqn2) { + return std::equal(eqn1.begin(), eqn1.end(), eqn2.begin(), eqn2.end()); + }; + // IJK,IKM->IJM + if (is_equal(eqn, {A, B, C, COMMA, A, C, D, ARROW, A, B, D})) { + return EinsumEquation::BatchMatMul; + } + // BFND,NDH->BFH + if (is_equal(eqn, {A, B, C, D, COMMA, C, D, E, ARROW, A, B, E})) { + return EinsumEquation::FourDMatrixDotProd; + } + // BFNH,BTNH->BNFT + if (is_equal(eqn, {A, B, C, D, COMMA, A, E, C, D, ARROW, A, C, B, E})) { + return EinsumEquation::FourDBatchMatMul; + } + // BFD,DNH->BFNH + if (is_equal(eqn, {A, B, C, COMMA, C, D, E, ARROW, A, B, D, E})) { + return EinsumEquation::ThreeDReshapeTail; + } + return EinsumEquation::UnsupportedEquation; +} + +EinsumEquation tokenizeAndParse(const llvm::StringRef& equation) { + std::vector tokens; + if (tokenizeEquation(equation, &tokens)) { + return parseEquation(tokens); + } + return EinsumEquation::UnsupportedEquation; +} + +TF::TransposeOp createTransposeOp(Value value, Location loc, + llvm::ArrayRef permutation, + PatternRewriter* rewriter) { + auto value_type = value.getType().cast(); + auto shape = value_type.getShape(); + auto perm_type = RankedTensorType::get( + {static_cast(permutation.size())}, rewriter->getIntegerType(32)); + auto perm_attr = DenseElementsAttr::get(perm_type, permutation); + auto perm_op = rewriter->create(loc, perm_type, perm_attr); + std::vector transposed_shape(shape.begin(), shape.end()); + for (int i = 0; i < shape.size(); ++i) { + transposed_shape[i] = shape[permutation[i]]; + } + auto transposed_type = + RankedTensorType::get(transposed_shape, value_type.getElementType()); + return rewriter->create(loc, transposed_type, value, + perm_op); +} + +TF::ReshapeOp createReshapeOp(Value value, ArrayRef shape, + Type element_type, Location loc, + PatternRewriter* rewriter) { + int64_t shape_rank = shape.size(); + auto shape_spec_type = + RankedTensorType::get({shape_rank}, rewriter->getIntegerType(64)); + Type resultType = RankedTensorType::get(shape, element_type); + auto constant_attr = DenseElementsAttr::get(shape_spec_type, shape); + auto shape_tensor = + rewriter->create(loc, shape_spec_type, constant_attr); + return rewriter->create(loc, resultType, /*tensor=*/value, + /*shape=*/shape_tensor); +} + +} // namespace + +PatternMatchResult ConvertTFEinsumOp::matchAndRewrite( + TF::EinsumOp op, PatternRewriter& rewriter) const { + Type 
output_type = op.getResult().getType(); + Value lhs = op.getOperand(0); + Value rhs = op.getOperand(1); + Location loc = op.getLoc(); + + if (!lhs.getType().isa()) { + // LHS must be a ranked tensor type + return matchFailure(); + } + if (!rhs.getType().isa()) { + // RHS must be a ranked tensor type + return matchFailure(); + } + + auto lhs_type = lhs.getType().cast(); + auto rhs_type = rhs.getType().cast(); + auto lhs_shape = lhs_type.getShape(); + auto rhs_shape = rhs_type.getShape(); + + // Currently only support static shapes. + if (!(lhs_type.hasStaticShape() && rhs_type.hasStaticShape())) { + return matchFailure(); + } + + // Currently support use cases of LHS, RHS dims = 3 or 4 + const int dims_lhs = lhs_shape.size(); + const int dims_rhs = rhs_shape.size(); + if (dims_rhs < 3 || dims_rhs > 4 || dims_lhs < 3 || dims_lhs > 4) { + return matchFailure(); + } + + EinsumEquation einsum_eqn = tokenizeAndParse(op.equation()); + if (einsum_eqn == EinsumEquation::BatchMatMul) { + // Case "IJK,IKM->IJM" + auto bmm_op = rewriter.create( + loc, ArrayRef{output_type}, lhs, rhs, rewriter.getBoolAttr(false), + rewriter.getBoolAttr(false)); + rewriter.replaceOp(op, bmm_op.getResult()); + return matchSuccess(); + } + if (einsum_eqn == EinsumEquation::ThreeDReshapeTail) { + // Case "BFD,DNH->BFNH" + auto lhs_type = lhs.getType().cast(); + auto lhs_shape = lhs_type.getShape(); + const int lhs_dim0 = lhs_shape[0]; + const int lhs_dim1 = lhs_shape[1]; + // Reshape RHS + auto rhs_type = rhs.getType().cast(); + auto rhs_shape = rhs_type.getShape(); + auto rhs_element_type = rhs_type.getElementType(); + const int rhs_dim0 = rhs_shape[0]; + const int rhs_dim1 = rhs_shape[1]; + const int rhs_dim2 = rhs_shape[2]; + auto reshaped_rhs = createReshapeOp(rhs, {rhs_dim0, rhs_dim1 * rhs_dim2}, + rhs_element_type, loc, &rewriter); + + std::vector bmm_shape = {lhs_dim0, lhs_dim1, rhs_dim1 * rhs_dim2}; + auto bmm_type = RankedTensorType::get(bmm_shape, rhs_type.getElementType()); + auto bmm_op = rewriter.create( + loc, ArrayRef{bmm_type}, lhs, reshaped_rhs, + rewriter.getBoolAttr(false), rewriter.getBoolAttr(false)); + auto bmm_element_type = bmm_type.getElementType(); + auto final_reshape = + createReshapeOp(bmm_op, {lhs_dim0, lhs_dim1, rhs_dim1, rhs_dim2}, + bmm_element_type, loc, &rewriter); + rewriter.replaceOp(op, {final_reshape.getResult()}); + return matchSuccess(); + } + if (einsum_eqn == EinsumEquation::FourDMatrixDotProd) { + // Case "BFND,NDH->BFH" + // Reshape LHS + auto lhs_element_type = lhs_type.getElementType(); + const int lhs_dim0 = lhs_shape[0]; + const int lhs_dim1 = lhs_shape[1]; + const int lhs_dim2 = lhs_shape[2]; + const int lhs_dim3 = lhs_shape[3]; + auto reshaped_lhs = + createReshapeOp(lhs, {lhs_dim0, lhs_dim1, lhs_dim2 * lhs_dim3}, + lhs_element_type, loc, &rewriter); + // Reshape RHS + auto rhs_element_type = rhs_type.getElementType(); + const int rhs_dim0 = rhs_shape[0]; + const int rhs_dim1 = rhs_shape[1]; + const int rhs_dim2 = rhs_shape[2]; + auto reshaped_rhs = createReshapeOp(rhs, {rhs_dim0 * rhs_dim1, rhs_dim2}, + rhs_element_type, loc, &rewriter); + auto bmm_op = rewriter.create( + loc, ArrayRef{output_type}, reshaped_lhs, reshaped_rhs, + rewriter.getBoolAttr(false), rewriter.getBoolAttr(false)); + rewriter.replaceOp(op, {bmm_op.getResult()}); + return matchSuccess(); + } + if (einsum_eqn == EinsumEquation::FourDBatchMatMul) { + // Case "BFNH,BTNH->BNFT" + // Transpose LHS + lhs = createTransposeOp(lhs, loc, {0, 2, 1, 3}, &rewriter); + // Transpose RHS + rhs = 
createTransposeOp(rhs, loc, {0, 2, 3, 1}, &rewriter); + auto bmm_op = rewriter.create( + loc, ArrayRef{output_type}, lhs, rhs, rewriter.getBoolAttr(false), + rewriter.getBoolAttr(false)); + rewriter.replaceOp(op, {bmm_op.getResult()}); + return matchSuccess(); + } + return matchFailure(); +} + +// Transform Einsum to other TF Ops for the supported variants. +struct TransformEinsumPass : public FunctionPass { + void runOnFunction() override; +}; + +void TransformEinsumPass::runOnFunction() { + OwningRewritePatternList patterns; + auto func = getFunction(); + + patterns.insert(&getContext()); + applyPatternsGreedily(func, patterns); +} + +static PassRegistration pass( + "tf-einsum", "Transform Einsum to other TF Ops for the supported variants"); + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.h b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.h new file mode 100644 index 00000000000..77b0c72aaef --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.h @@ -0,0 +1,55 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This pass identifies patterns for certain Einsum Ops and replaces them +// with other equivalent TF Ops. + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_EINSUM_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_EINSUM_H_ + +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Matchers.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/TypeUtilities.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/core/util/matmul_bcast.h" + +namespace mlir { +namespace TF { + +// TF.Einsum provides fully general tensor contractions. For a few select +// cases, we can convert this op to other TF Ops, which in later passes +// properly convert to TF Lite ops. 
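+// The equations currently handled by this pattern (summarized from the pass +// implementation above) are: "IJK,IKM->IJM" -> tf.BatchMatMulV2; +// "BFD,DNH->BFNH" -> reshape RHS + tf.BatchMatMulV2 + reshape; +// "BFND,NDH->BFH" -> reshape both operands + tf.BatchMatMulV2; and +// "BFNH,BTNH->BNFT" -> transpose both operands + tf.BatchMatMulV2.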
+struct ConvertTFEinsumOp : public OpRewritePattern { + public: + explicit ConvertTFEinsumOp(MLIRContext* context) + : OpRewritePattern(context) {} + + PatternMatchResult matchAndRewrite(TF::EinsumOp op, + PatternRewriter& rewriter) const override; +}; + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_EINSUM_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc b/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc new file mode 100644 index 00000000000..82c198ac82f --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc @@ -0,0 +1,120 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/UseDefLists.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" + +namespace mlir { +namespace tf_saved_model { +namespace { + +// This pass will replace a func's bound inputs (global tensors read through +// tf.ReadVariableOp ops) with tf.Const ops inside the func's body. +// If this pass runs successfully, the resultant IR will be guaranteed to: +// +// 1. Not contain any tf_saved_model.global_tensor ops +// 2. Not contain any tf_saved_model.bound_input arg attrs on tf_saved_model +// exported functions +// Otherwise, the pass fails. +// +// The reason this pass has this contract is that, once it succeeds, we know +// the IR is in the correct form for inference backends (like lite) that do not +// support resources/variables. Further, this contract also ensures that this +// pass lowers from saved model to pure TF. Hence it fails if it cannot lower. +struct FreezeGlobalTensorsPass : public ModulePass { + void runOnModule() override; +}; + +void FreezeGlobalTensorsPass::runOnModule() { + auto module = getModule(); + SymbolTable symbol_table(module); + DenseSet frozen_global_tensors; + + for (auto func : module.getOps()) { + SmallVector args_to_erase; + OpBuilder builder(func.getBody()); + + for (int i = 0, e = func.getNumArguments(); i < e; ++i) { + SmallVector read_variable_ops_to_erase; + auto global_tensor = LookupBoundInput(func, i, symbol_table); + + if (!global_tensor) continue; + frozen_global_tensors.insert(global_tensor); + + // This pass assumes that all global tensors are immutable (e.g. by a + // previous optimize global tensors pass). If not, this pass has to fail + // since it cannot perform one of its goals.
+ if (global_tensor.is_mutable()) { + global_tensor.emitError() << "is not immutable"; + return signalPassFailure(); + } + + auto arg = func.getArgument(i); + for (auto user : arg.getUsers()) { + if (auto read_op = llvm::dyn_cast(user)) { + // Collect all read variable ops so that all its uses can be replaced + // with the tf.constant corresponding to the global tensor op. + read_variable_ops_to_erase.push_back(read_op); + } else { + // Current assumption is all users are tf.ReadVariableOp. Need to + // expand this to handle control flow and call ops. + user->emitError() << "could not rewrite use of immutable bound input"; + return signalPassFailure(); + } + } + + // Replace the arg with a tf.Const op in the function body. + auto const_op = builder.create(global_tensor.getLoc(), + global_tensor.value()); + args_to_erase.push_back(i); + for (auto read_op : read_variable_ops_to_erase) { + read_op.getResult().replaceAllUsesWith(const_op.getResult()); + read_op.erase(); + } + } + func.eraseArguments(args_to_erase); + } + // Erase all global tensors that were frozen. + for (auto global_tensor : frozen_global_tensors) { + global_tensor->erase(); + } + + if (!module.getOps().empty()) { + module.emitError() << "could not freeze all global tensors in the module"; + return signalPassFailure(); + } +} + +} // namespace + +// For "opt" to pick up this pass. +static PassRegistration pass( + "tf-saved-model-freeze-global-tensors", + "Freeze tf_saved_model.global_tensor's in func bodies."); + +std::unique_ptr> CreateFreezeGlobalTensorsPass() { + return std::make_unique(); +} + +} // namespace tf_saved_model +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc b/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc new file mode 100644 index 00000000000..0a8d261ee39 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc @@ -0,0 +1,134 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "llvm/ADT/STLExtras.h" +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassManager.h" // TF:llvm-project +#include "mlir/Pass/PassRegistry.h" // TF:llvm-project +#include "mlir/Transforms/Passes.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" + +#define DEBUG_TYPE "tf-gpu-op-fusion" + +namespace mlir { +namespace TF { + +namespace { + +// GpuOpFusionPass is a pass performing fusion specific to GPU targets. +// This is an ad-hoc pass for now, but should be integrated with some notion +// of "target" in the MLIR pipeline in the future. 
+class GpuOpFusionPass : public FunctionPass { + public: + void runOnFunction() final; +}; + +// %y:6 = "tf.FusedBatchNormV3"(%x, %scale, %offset, %mean, %variance) +// %0 = "tf.Relu"(%y#0) +// -> +// %y:6 = "tf._FusedBatchNormEx"(%x, %scale, %offset, %mean, %variance) +// +// Or: +// %y:6 = "tf.FusedBatchNormV3"(%x, %scale, %offset, %mean, %variance) +// %0 = "tf.AddV2"(%y#0, %side_input) +// %1 = "tf.Relu"(%0) +// -> +// %y:6 = "tf._FusedBatchNormEx"(%x, %scale, %offset, %mean, %variance, +// %side_input) +// TODO(aminim): we should revisit this as a declarative pattern. +// For the second pattern, there is not good way in the framework to handle the +// commutativity of the AddV2: we want the FusedBatchNormV3 on any side. +// Also we need some native calls to handle the "hasOneUse" aspects and the +// optional extra operands for the AddV2 case. +struct ReluToFusedBatchNorm : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(ReluOp relu_op, + PatternRewriter &rewriter) const override { + Operation *relu_input = relu_op.features().getDefiningOp(); + if (!relu_input) return matchFailure(); + auto batch_norm = dyn_cast_or_null(relu_input); + AddV2Op add_op; + Value side_input; + if (!batch_norm) { + // We don't have a FusedBatchNorm as input to the ReLu, but we can get + // through an AddV2 as well. + add_op = dyn_cast_or_null(relu_input); + if (!add_op) return matchFailure(); + + batch_norm = + dyn_cast_or_null(add_op.x().getDefiningOp()); + if (batch_norm) { + side_input = add_op.y(); + } else { + // Didn't get a FusedBatchNorm on the LHS of the AddV2, try the RHS. + batch_norm = + dyn_cast_or_null(add_op.y().getDefiningOp()); + if (!batch_norm) return matchFailure(); + side_input = add_op.x(); + } + } + assert(batch_norm); + if (batch_norm.is_training()) return matchFailure(); + if (!batch_norm.y().hasOneUse()) return matchFailure(); + + // Build the newly fused operation to replace the batch norm + OperationState state(batch_norm.getLoc(), + FusedBatchNormExOp::getOperationName()); + state.addOperands(batch_norm.getOperands()); + if (side_input) state.operands.push_back(side_input); + state.addTypes(batch_norm.getResultTypes()); + state.addAttributes(batch_norm.getAttrs()); + Operation *op = rewriter.createOperation(state); + rewriter.replaceOp(batch_norm, op->getResults()); + + // Depending on the case, we may fuse the add, the relu, or both. + if (!add_op || add_op.z().hasOneUse()) { + // We fuse the Relu only if the add has a single use, otherwise we only + // fuse the add itself. 
+ op->setAttr("activation_mode", rewriter.getStringAttr("Relu")); + rewriter.replaceOp(relu_op, op->getResult(0)); + } + if (add_op) { + rewriter.replaceOp(add_op, op->getResult(0)); + } + + return matchSuccess(); + } +}; + +void GpuOpFusionPass::runOnFunction() { + FuncOp func = getFunction(); + OwningRewritePatternList patterns; + patterns.insert(&getContext()); + applyPatternsGreedily(func, patterns); +} + +} // namespace + +std::unique_ptr> CreateGpuOpFusionPass() { + return std::make_unique(); +} + +static PassRegistration layout_assignment( + "tf-gpu-op-fusion", "Fusion optimization for GPU targets"); + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/libtftpu.h b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc similarity index 55% rename from tensorflow/compiler/xla/python/tpu_driver/client/libtftpu.h rename to tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc index 0562afb2141..281a6011af6 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/libtftpu.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc @@ -13,23 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_CLIENT_LIBTFTPU_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_CLIENT_LIBTFTPU_H_ +#include "tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h" -#ifdef __cplusplus -extern "C" { -#endif +namespace tensorflow { -typedef struct TfTpuDriver_CompileOp TfTpuDriver_CompileOp; +Status MlirGraphOptimizationPass::Run(const ConfigProto& config_proto, + mlir::ModuleOp module) { + if (!config_proto.experimental().enable_mlir_graph_optimization()) { + VLOG(1) << "Skipping MLIR Graph Optimization Pass" + << ", session flag not enabled"; + return Status::OK(); + } -TfTpuDriver_CompileOp* TfTpuDriver_CompileOpConstructor(void* ctx); + // TODO(ezhulenev): Add something here. -void TfTpuDriver_CompileOpExecute(TfTpuDriver_CompileOp* op, void* ctx); - -void TfTpuDriver_CompileOpFree(TfTpuDriver_CompileOp* op); - -#ifdef __cplusplus + return Status::OK(); } -#endif -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_CLIENT_LIBTFTPU_H_ +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h new file mode 100644 index 00000000000..955da470494 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h @@ -0,0 +1,38 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_GRAPH_OPTIMIZATION_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_GRAPH_OPTIMIZATION_PASS_H_ + +#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h" + +namespace tensorflow { + +// Bundle generic MLIR graph optimization passes (some derived from TF Grappler +// graph optimizers) into a single MLIR optimization pass. +class MlirGraphOptimizationPass : public MlirOptimizationPass { + public: + llvm::StringRef name() const override { return "graph_optimization"; } + + bool IsEnabled(const ConfigProto& config_proto) const override { + return config_proto.experimental().enable_mlir_graph_optimization(); + } + + Status Run(const ConfigProto& config_proto, mlir::ModuleOp module) override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_GRAPH_OPTIMIZATION_PASS_H_ diff --git a/tensorflow/lite/python/testdata/test_registerer.i b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass_registration.cc similarity index 55% rename from tensorflow/lite/python/testdata/test_registerer.i rename to tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass_registration.cc index 1cd41c9164d..4681f8a0f33 100644 --- a/tensorflow/lite/python/testdata/test_registerer.i +++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass_registration.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,8 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -%{ -#include "tensorflow/lite/python/testdata/test_registerer.h" -%} +#include -%include "tensorflow/lite/python/testdata/test_registerer.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h" + +namespace tensorflow { +namespace { +constexpr int kMlirGraphOptimizationPriority = 0; +} + +static mlir_pass_registration::MlirOptimizationPassRegistration + register_mlir_graph_optimization_pass( + kMlirGraphOptimizationPriority, + std::make_unique()); + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc b/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc index e6c4024d5ec..7d65d16e42d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc @@ -17,12 +17,15 @@ limitations under the License. 
#include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project #include "mlir/Pass/Pass.h" // TF:llvm-project #include "mlir/Pass/PassManager.h" // TF:llvm-project #include "mlir/Pass/PassRegistry.h" // TF:llvm-project #include "mlir/Transforms/Passes.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" #define DEBUG_TYPE "tf-layout-optimization" @@ -90,22 +93,32 @@ Permutation GetDataFormatPermutation(StringRef from_data_format, void LayoutAssignmentPass::runOnFunction() { FuncOp func = getFunction(); - // TODO(ezhulenev): LayoutSensitiveInterface should select the optimal data - // layout if there is no explicitly forced data format. - if (force_data_format_.empty()) return; + // Get runtime devices information from the closest parent module. + RuntimeDevices devices; + ::tensorflow::GetDevicesFromOp(func.getParentOfType(), &devices); + + // If there is no runtime device information and data format is not explicitly + // forced, there is nothing to do. + if (devices.NumDevices() == 0 && force_data_format_.empty()) return; func.walk([&](LayoutSensitiveInterface layout_sensitive_interface) { + // Get desired op data format. + StringRef target_data_format = force_data_format_; + if (target_data_format.empty()) { + target_data_format = layout_sensitive_interface.GetOptimalLayout(devices); + } + // Skip ops that already use target data format. auto data_format = layout_sensitive_interface.data_format(); - if (data_format == force_data_format_) return; + if (data_format == target_data_format) return; // Transpose arguments into the target data format. Permutation args_permutation = - GetDataFormatPermutation(data_format, force_data_format_); + GetDataFormatPermutation(data_format, target_data_format); // Transpose results back to the original data format. Permutation res_permutation = - GetDataFormatPermutation(force_data_format_, data_format); + GetDataFormatPermutation(target_data_format, data_format); if (args_permutation.empty() || res_permutation.empty()) return; @@ -119,7 +132,7 @@ void LayoutAssignmentPass::runOnFunction() { }; // Change operation data format. - if (failed(layout_sensitive_interface.UpdateDataFormat(force_data_format_))) + if (failed(layout_sensitive_interface.UpdateDataFormat(target_data_format))) return; // Permute arguments into the target data format. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc index 1aaceb8ecc7..68617e36f0c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include "llvm/ADT/DenseMap.h" -#include "mlir/Analysis/CallInterfaces.h" // TF:llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/Function.h" // TF:llvm-project @@ -29,6 +28,7 @@ limitations under the License. 
#include "mlir/IR/StandardTypes.h" // TF:llvm-project #include "mlir/IR/SymbolTable.h" // TF:llvm-project #include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // TF:llvm-project #include "mlir/Pass/Pass.h" // TF:llvm-project #include "mlir/Support/LLVM.h" // TF:llvm-project #include "mlir/Support/LogicalResult.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index 3bd7e164ec7..332e181c9ed 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -46,9 +46,15 @@ std::unique_ptr> CreateTFShapeInferencePass(); // Optional pass which will unroll BatchMatMul and use only MatMul std::unique_ptr> CreateUnrollBatchMatMulPassPass(); +// Optional pass which will map TF BatchMatMul to TF Einsum +std::unique_ptr> CreateBatchMatMulToEinsumPass(); + // Optimizes Tensorflow graph. std::unique_ptr> CreateTFOptimizePass(); +// Performs specific fusion for GPU targets. +std::unique_ptr> CreateGpuOpFusionPass(); + struct LayoutOptimizationPipelineOptions : public PassPipelineOptions { Option force_data_format{ @@ -107,6 +113,10 @@ LogicalResult ResourceLiftingForFunctionalControlFlow(FuncOp function); // removed by resource lifting. Requires known maximum sizes of stacks and // known element shapes of push ops. std::unique_ptr> CreateStackOpsDecompositionPass(); + +// Converts tensor list operations into operations on buffers and sizes. Needs +// static shapes and known max element count. +std::unique_ptr> CreateTensorListOpsDecompositionPass(); } // namespace TF namespace TFControlFlow { @@ -244,6 +254,9 @@ namespace tf_saved_model { // Creates a pass that optimizes tf_saved_model.global_tensor ops. std::unique_ptr> CreateOptimizeGlobalTensorsPass(); +// Creates a pass that freezes tf_saved_model.global_tensor ops. +std::unique_ptr> CreateFreezeGlobalTensorsPass(); + // Creates a pass that uses tf_saved_model dialect linkage information // to mark function visibility. That is, exported functions are marked with // public visibility while the other functions are marked with private diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc index b7f8ea263b6..32dbb6f5d34 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc @@ -54,8 +54,6 @@ namespace mlir { namespace { -constexpr char kDTypeAttr[] = "dtype"; - // This pass lifts resource variable operations outside of device computation. // This is useful because a lot of accelerator devices can not interact with // resource variables directly.. @@ -188,7 +186,7 @@ void ForwardStoreToLoad(Block* block) { } // Moves resource load operations with the provided `move_load` function. This -// assumes load-store forwarding has been performed on this launch_op such that +// assumes load-store forwarding has been performed on this block such that // all loads of same resource are on its initial values. A `skip_load` functions // is used to indicate whether a load should be skipped. 
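The new declarations in passes.h only expose factory functions; callers still assemble them into a pipeline. A sketch of such wiring, assuming the BatchMatMul-to-Einsum rewrite is a function-level pass and the tensor list decomposition runs on the whole module (consistent with its runOnModule() definition later in this patch). The function name and nesting here are illustrative, not TensorFlow's actual bridge pipeline:

```cpp
#include "mlir/IR/Function.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/Pass/PassManager.h"  // TF:llvm-project
#include "mlir/Support/LogicalResult.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"

// Run a small example pipeline built from the newly declared passes.
mlir::LogicalResult RunExamplePipeline(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addNestedPass<mlir::FuncOp>(mlir::TF::CreateBatchMatMulToEinsumPass());
  pm.addPass(mlir::TF::CreateTensorListOpsDecompositionPass());
  return pm.run(module);
}
```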
If there are multiple // loads on the same resource, only the first one will be moved, and the later @@ -198,7 +196,7 @@ void HoistResourceLoads( llvm::function_ref move_load) { llvm::SmallDenseMap resource_to_read_ops; - // Only iterate through ops directly in launch_op's body as we can't handle + // Only iterate through ops directly in the body as we can't handle // ops nested deeper in regions. for (Operation& op : llvm::make_early_inc_range(*block)) { auto read_variable_op = dyn_cast(&op); @@ -220,28 +218,25 @@ void HoistResourceLoads( } } -// If there are any stores to resource defined outside of launch_op's body -// region, the stored values must be returned by launch_op and its return op so -// that new values can be used by sunk resource stores. +// If there are any stores to resource defined outside of the block then the +// stored values must be returned so that new values can be used by sunk +// resource stores. // Returns true if any resource variable stored values are appended, otherwise // false. -bool AppendResourceStoreValueToReturn(tf_device::LaunchOp launch_op) { +bool AppendResourceStoreValueToReturn(Block* body) { bool has_resource_store = false; - Block* body = &launch_op.GetBody(); auto old_return = body->getTerminator(); llvm::SmallVector new_return_operands(old_return->getOperands()); - // Only iterate through ops directly in launch_op's body as we can't handle - // ops nested deeper in regions. - for (Operation& op : launch_op.GetBody()) { - auto assign_variable_op = dyn_cast(&op); - if (!assign_variable_op) continue; + // Only iterate through ops directly in the body as we can't handle ops nested + // deeper in regions. + for (auto assign_variable_op : body->getOps()) { Value resource = assign_variable_op.resource(); if (!resource) continue; - // Skip resources created inside of launch_op. - if (resource.getParentRegion() == &launch_op.body()) continue; + // Skip resources created inside of the body. + if (resource.getParentRegion() == body->getParent()) continue; // TODO(ycao): Prevent same value from being returned multiple times. // TODO(ycao): Do not return resource store value if it is defined outside @@ -267,8 +262,7 @@ tf_device::LaunchOp SinkResourceStores(tf_device::LaunchOp launch_op, OpBuilder* builder) { // Update ReturnOp inside launch_op's body to output final values of updated // external resources. - bool has_resource_store = AppendResourceStoreValueToReturn(launch_op); - if (!has_resource_store) return launch_op; + if (!AppendResourceStoreValueToReturn(&launch_op.GetBody())) return launch_op; auto new_return_op = launch_op.GetBody().getTerminator(); llvm::SmallVector new_launch_return_types( @@ -352,9 +346,9 @@ LogicalResult HoistResourceOpsFromLaunchOp(tf_device::LaunchOp launch_op) { // Holds information about a function's use of a resource argument. struct ResourceArgUseInfo { - bool used; Type data_type; bool updated; + bool used; }; // Finds the ResourceArgUseInfo for each resource argument. Forwarding to the @@ -501,13 +495,13 @@ void LiftArgRetResourcesForFunction( }); // Record the stores in resource_arg_read. for (auto& op : llvm::make_early_inc_range(func_op.front())) { - if (auto write = llvm::dyn_cast(&op)) { - auto arg = write.resource().dyn_cast(); - if (!arg) continue; - // After ForwardStoreToLoad(), there should be just one store for each - // resource. 
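The filtered loop introduced in AppendResourceStoreValueToReturn above lost its angle-bracketed template argument in this copy of the patch; from the surrounding code it is presumably TF::AssignVariableOp. A self-contained sketch of the presumed pattern (the helper name is made up):

```cpp
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Block.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

// Collect stores to resources defined outside `body`; their stored values are
// the ones that must be returned so new values can reach the sunk stores.
llvm::SmallVector<mlir::TF::AssignVariableOp, 4> ExternalResourceStores(
    mlir::Block* body) {
  llvm::SmallVector<mlir::TF::AssignVariableOp, 4> stores;
  for (auto assign_variable_op : body->getOps<mlir::TF::AssignVariableOp>()) {
    mlir::Value resource = assign_variable_op.resource();
    if (!resource) continue;
    // Skip resources created inside of the body itself.
    if (resource.getParentRegion() == body->getParent()) continue;
    stores.push_back(assign_variable_op);
  }
  return stores;
}
```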
- resource_arg_write[arg] = write; - } + auto write = llvm::dyn_cast(&op); + if (!write) continue; + auto arg = write.resource().dyn_cast(); + if (!arg) continue; + // After ForwardStoreToLoad(), there should be just one store for each + // resource. + resource_arg_write[arg] = write; } // Now change the input types to non-resource and remove the internal loads. auto new_types = llvm::to_vector<8>(func_op.getType().getInputs()); @@ -542,8 +536,8 @@ llvm::SmallVector FilterRange( llvm::SmallVector filtered; for (auto entry : llvm::enumerate(range)) { auto it = resource_arg_uses.find(entry.index()); - if (it != resource_arg_uses.end() && !it->getSecond().used) continue; - filtered.push_back(entry.value()); + if (it == resource_arg_uses.end() || it->getSecond().used) + filtered.push_back(entry.value()); } return filtered; } @@ -882,13 +876,6 @@ LogicalResult HandlePartitionedCallOpCallee( auto module = callee.getParentOfType(); name_base += "_resource_lifted"; auto name = name_base; - { - int64_t counter = 0; - while (module.lookupSymbol(name)) { - auto name = name_base; - name += "_" + std::to_string(counter++); - } - } callee = callee.clone(); callee.setName(name); SymbolTable(module).insert(callee); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index b3474e2faf1..6d2ce76eca8 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -31,6 +31,7 @@ limitations under the License. #include "mlir/IR/Diagnostics.h" // TF:llvm-project #include "mlir/IR/Location.h" // TF:llvm-project #include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/OperationSupport.h" // TF:llvm-project #include "mlir/IR/StandardTypes.h" // TF:llvm-project #include "mlir/IR/SymbolTable.h" // TF:llvm-project #include "mlir/IR/Value.h" // TF:llvm-project @@ -184,20 +185,79 @@ bool InferShapeForNonTFDialectOperation(Operation* op, Dialect* tf_dialect) { return false; } +// Gets the subtype's shape and data type for `type`. Templated to support both +// ResourceType and VariantType. +template +std::unique_ptr>> +GetSubtypesHelper(Type type) { + auto type_with_subtypes = + type.cast().getElementType().dyn_cast(); + if (!type_with_subtypes || type_with_subtypes.getSubtypes().empty()) { + return nullptr; + } + auto shapes_and_types = absl::make_unique>>(); + for (auto subtype : type_with_subtypes.getSubtypes()) { + auto shape = GetShapeFromMlirType(subtype); + // handle_shapes_and_types requires all shapes to be known. So if any + // subtype is unknown, clear the vector. + if (!shape) { + shapes_and_types = nullptr; + break; + } + tensorflow::DataType dtype; + auto status = + tensorflow::ConvertToDataType(subtype.getElementType(), &dtype); + assert(status.ok() && "Unknown element type"); + shapes_and_types->emplace_back(*shape, dtype); + } + return shapes_and_types; +} + +// Gets the subtype's shape and data type for `type`. +std::unique_ptr>> +GetSubtypes(Type type) { + auto subclasses = GetSubtypesHelper(type); + if (subclasses) return subclasses; + return GetSubtypesHelper(type); +} + +// Makes result types match the operand types. Returns if anything is changed. 
+bool PassThroughOperandTypes(OperandRange operands, ResultRange results) { + bool changed = false; + for (auto entry : llvm::zip(operands, results)) { + Type operand_type = std::get<0>(entry).getType(); + if (operand_type == std::get<1>(entry).getType()) continue; + std::get<1>(entry).setType(operand_type); + changed = true; + } + return changed; +} + } // namespace bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, int64_t graph_version) { assert(tf_dialect == op->getDialect()); + // The shape function of these ops sometimes does not propagate subtypes + // (handle shapes) for resource and variant types. We use a simple passthrough + // to make sure they are preserved in the output. + if (isa(op) || isa(op) || + isa(op) || isa(op)) { + return PassThroughOperandTypes(op->getOperands(), op->getResults()); + } // If no result for this op needs shape inference, we have a fast-path return. - // But if the type is a resource, we do not skip it because we might not have - // the handle shapes. + // But if the type is a resource/variant, we do not skip it because we might + // not have the handle shapes. if (llvm::all_of(op->getResultTypes(), [](Type type) { auto shape_type = type.dyn_cast(); return !shape_type || (shape_type.hasStaticShape() && - !shape_type.getElementType().isa()); + !shape_type.getElementType().isa() && + !shape_type.getElementType().isa()); })) { LLVM_DEBUG(llvm::dbgs() << "Skipping inference for statically shaped op '" << op->getName() << "'.\n";); @@ -282,29 +342,8 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, if (auto shape = GetShapeFromMlirType(operand_type)) { input_shapes[index] = *shape; } - // Collect the handle shapes and types for a resource. - if (auto resource_type = operand_type.cast() - .getElementType() - .dyn_cast()) { - if (resource_type.getSubtypes().empty()) continue; - auto shapes_and_types = absl::make_unique>>(); - for (auto subtype : resource_type.getSubtypes()) { - auto shape = GetShapeFromMlirType(subtype); - // handle_shapes_and_types requires all shapes to be known. So if any - // subtype is unknown, clear the vector. - if (!shape) { - shapes_and_types = nullptr; - break; - } - tensorflow::DataType dtype; - auto status = - tensorflow::ConvertToDataType(subtype.getElementType(), &dtype); - assert(status.ok() && "Unknown element type"); - shapes_and_types->emplace_back(*shape, dtype); - } - handle_shapes_and_types[index] = std::move(shapes_and_types); - } + // Collect the handle shapes and types for a resource/variant. + handle_shapes_and_types[index] = GetSubtypes(operand_type); } // Perform the shape inference using an InferenceContext with the input @@ -346,8 +385,9 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, return RankedTensorType::get(shape, element_type); }; auto new_element_type = shaped_type.getElementType(); - // Populate the handle shapes for a resource. - if (auto resource_type = new_element_type.dyn_cast()) { + // Populate the handle shapes for a resource/variant. 
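The isa<> tests in the fast-path condition above were stripped of their template arguments in this copy of the patch; the accompanying comment indicates they check for resource and variant element types. A sketch of the presumed check, factored into a helper (the function name is made up):

```cpp
#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/Operation.h"  // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"

// An op can skip shape inference only if every result already has a static
// shape and its element type is neither a resource nor a variant handle
// (whose subtypes may still need to be refined).
bool CanSkipShapeInference(mlir::Operation* op) {
  return llvm::all_of(op->getResultTypes(), [](mlir::Type type) {
    auto shape_type = type.dyn_cast<mlir::ShapedType>();
    return !shape_type ||
           (shape_type.hasStaticShape() &&
            !shape_type.getElementType().isa<mlir::TF::ResourceType>() &&
            !shape_type.getElementType().isa<mlir::TF::VariantType>());
  });
}
```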
+ if (new_element_type.isa() || + new_element_type.isa()) { auto handle_shapes_types = c.output_handle_shapes_and_types(output); if (handle_shapes_types) { llvm::SmallVector subtypes; @@ -359,7 +399,11 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, assert(status.ok() && "Unknown element type"); subtypes.push_back(get_tensor_type(shape_n_type.shape, element_type)); } - new_element_type = TF::ResourceType::get(subtypes, op->getContext()); + if (new_element_type.isa()) { + new_element_type = TF::ResourceType::get(subtypes, op->getContext()); + } else { + new_element_type = TF::VariantType::get(subtypes, op->getContext()); + } } } auto new_type = get_tensor_type(shape_handle, new_element_type); @@ -452,11 +496,13 @@ LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op, return PropagateShapeToFunctions(module, while_op.getOperandTypes(), {while_op.cond(), while_op.body()}, graph_version, max_iteration); - } else if (auto call_op = dyn_cast(op)) { - if (call_op.f().isa()) - return PropagateShapeToFunctions(module, call_op.getOperandTypes(), - {call_op.f().getRootReference()}, - graph_version, max_iteration); + } else if (auto call_op = dyn_cast(op)) { + CallInterfaceCallable callable = call_op.getCallableForCallee(); + if (SymbolRefAttr sym = callable.dyn_cast()) { + return PropagateShapeToFunctions( + module, call_op.getArgOperands().getTypes(), {sym.getRootReference()}, + graph_version, max_iteration); + } } // TODO(ycao): Implement support for Call op, including function reuse. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc index b58a2402f4e..4033d522091 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc @@ -40,6 +40,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" @@ -52,6 +53,8 @@ namespace mlir { namespace { +namespace cutil = TF::collection_ops_util; + // A pass that converts stack operations to tensor operations and read/assign // ops on local variables. A later resource lifting pass can further remove the // local variables. @@ -106,88 +109,14 @@ TF::AssignVariableOp WriteLocalVariable(Value local_var, Value value, ArrayRef{}); } -// Creates an i32 scalar tf.Const. -TF::ConstOp CreateScalarConst(int value, OpBuilder builder, Location loc) { - tensorflow::Tensor scalar_tensor(tensorflow::DT_INT32, {}); - scalar_tensor.scalar()() = value; - return builder.create( - loc, tensorflow::ConvertTensor(scalar_tensor, &builder).ValueOrDie()); -} - -// Creates an i32 vector tf.Const. 
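The same stripping affects the new CallOpInterface branch in PropagateShapeIntoAttachedFunctions above. A self-contained sketch of the presumed resolution step, generalizing from TF::PartitionedCallOp to any op implementing CallOpInterface (the helper name and the visit callback are illustrative):

```cpp
#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/Function.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/Interfaces/CallInterfaces.h"  // TF:llvm-project

// Resolve the callee of any call-like op and hand it to `visit`.  Region-based
// (non-symbol) callees are ignored, matching the guard in the patch above.
void VisitSymbolCallee(mlir::Operation* op, mlir::ModuleOp module,
                       llvm::function_ref<void(mlir::FuncOp)> visit) {
  auto call = llvm::dyn_cast<mlir::CallOpInterface>(op);
  if (!call) return;
  mlir::CallInterfaceCallable callable = call.getCallableForCallee();
  auto sym = callable.dyn_cast<mlir::SymbolRefAttr>();
  if (!sym) return;
  if (auto callee = module.lookupSymbol<mlir::FuncOp>(sym.getRootReference()))
    visit(callee);
}
```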
-TF::ConstOp GetR1Const(ArrayRef r1, OpBuilder builder, Location loc) { - tensorflow::Tensor shape_tensor(tensorflow::DT_INT32, - {static_cast(r1.size())}); - for (int i = 0; i < r1.size(); ++i) { - shape_tensor.vec()(i) = r1[i]; - } - return builder.create( - loc, tensorflow::ConvertTensor(shape_tensor, &builder).ValueOrDie()); -} - -// Creates a rank-1 op that represents the offsets of the stack element in the -// stack buffer. -Value GetIndicesForStackElement(Value index, Value stack_value, - OpBuilder builder, Location loc) { - auto stack_type = stack_value.getType().cast(); - if (stack_type.getShape().size() == 1) return index; - llvm::SmallVector zeros(stack_type.getShape().size() - 1, 0); - auto zeros_tensor = GetR1Const(zeros, builder, loc); - return builder.create( - loc, - ArrayRef{RankedTensorType::get( - {static_cast(stack_type.getShape().size())}, - getElementTypeOrSelf(index.getType()))}, - ArrayRef{index, zeros_tensor, CreateScalarConst(0, builder, loc)}, - ArrayRef{}); -} - -// Returns the type of the local variable for the stack size. It is a -// tensor<1xi32>, and we use R1 instead of a scalar because it is easier to -// concat it with other offsets. +// Returns the type of the local variable for the stack size. Type GetSizeVarType(OpBuilder builder) { - auto size_type = RankedTensorType::get({1}, builder.getIntegerType(32)); + auto size_type = cutil::GetSizeType(builder); return RankedTensorType::get( {}, TF::ResourceType::get(ArrayRef{size_type}, builder.getContext())); } -// Creates the buffer and size local variables for a stack. -std::pair CreateVariablesForStack(TensorType stack_tensor_type, - TF::StackV2Op stack) { - OpBuilder builder(stack); - auto size_var_type = GetSizeVarType(builder); - auto var_type = RankedTensorType::get( - {}, TF::ResourceType::get(ArrayRef{stack_tensor_type}, - stack.getContext())); - auto local_var = builder.create( - stack.getLoc(), ArrayRef{var_type}, ArrayRef{}, - ArrayRef{}); - auto local_size_var = builder.create( - stack.getLoc(), ArrayRef{size_var_type}, ArrayRef{}, - ArrayRef{}); - - // Zero-initialize the local vars. - WriteLocalVariable(local_size_var, GetR1Const({0LL}, builder, stack.getLoc()), - builder, stack.getLoc()); - auto zero = CreateScalarConst(0, builder, stack.getLoc()).output(); - if (getElementTypeOrSelf(zero.getType()) != - stack_tensor_type.getElementType()) { - zero = builder.create( - stack.getLoc(), - ArrayRef{ - RankedTensorType::get({}, stack_tensor_type.getElementType())}, - ArrayRef{zero}, ArrayRef{}); - } - auto broadcast = builder.create( - stack.getLoc(), ArrayRef{stack_tensor_type}, - ArrayRef{zero, GetR1Const(stack_tensor_type.getShape(), builder, - stack.getLoc())}, - ArrayRef{}); - WriteLocalVariable(local_var, broadcast, builder, stack.getLoc()); - return {local_var, local_size_var}; -} - // Tries to infer the stack element type with full shape based on its uses. llvm::Optional GetStackElementType(Value stack, ModuleOp module) { @@ -449,7 +378,7 @@ LogicalResult HandlePartitionedCallOp( if (arg_it == info.stack_var_arg_to_size_arg.end()) continue; auto it = data_var_to_size_var.find(call.getOperand(i)); if (it == data_var_to_size_var.end()) { - call.emitOpError("Unknown stack."); + call.emitOpError("unknown stack"); return failure(); } assert(arg_it->second == new_operands.size()); @@ -532,25 +461,32 @@ LogicalResult HandleStackV2Op( // Create a buffer variable and a size variable to replace the stack. 
auto elem_type = GetStackElementType(stack.handle(), module); if (!elem_type.hasValue()) { - return stack.emitOpError("cannot infer element shape of stack."); + return stack.emitOpError("cannot infer element shape of stack"); } - auto size_op = stack.max_size().getDefiningOp(); - if (!size_op || !llvm::isa(size_op)) { - return stack.emitOpError("max size of stack is not a constant."); + OpBuilder builder(stack); + Value buffer; + if (failed(cutil::CreateInitBufferValue( + elem_type->getShape(), stack.max_size(), stack, + elem_type->getElementType(), builder, &buffer))) { + return failure(); } - int64_t max_size = - (*llvm::cast(size_op).value().getValues().begin()) - .getSExtValue(); - llvm::SmallVector stack_shape; - stack_shape.push_back(max_size); - for (int64_t dim : elem_type->getShape()) stack_shape.push_back(dim); - auto stack_tensor_type = - RankedTensorType::get(stack_shape, elem_type->getElementType()); - Value local_var; - Value local_size_var; - std::tie(local_var, local_size_var) = - CreateVariablesForStack(stack_tensor_type, stack); - stack.replaceAllUsesWith(local_var); + auto size_var_type = GetSizeVarType(builder); + auto var_type = RankedTensorType::get( + {}, TF::ResourceType::get( + ArrayRef{buffer.getType().cast()}, + stack.getContext())); + auto local_var = builder.create( + stack.getLoc(), ArrayRef{var_type}, ArrayRef{}, + ArrayRef{}); + auto local_size_var = builder.create( + stack.getLoc(), ArrayRef{size_var_type}, ArrayRef{}, + ArrayRef{}); + // Zero-initialize the local vars. + WriteLocalVariable(local_size_var, + cutil::GetR1Const({0LL}, builder, stack.getLoc()), builder, + stack.getLoc()); + WriteLocalVariable(local_var, buffer, builder, stack.getLoc()); + stack.handle().replaceAllUsesWith(local_var); (*data_var_to_size_var)[local_var] = local_size_var; stack.erase(); return success(); @@ -561,7 +497,7 @@ LogicalResult HandleStackPushV2Op( llvm::SmallDenseMap* data_var_to_size_var) { auto it = data_var_to_size_var->find(push.handle()); if (it == data_var_to_size_var->end()) { - return push.emitOpError("unknown stack."); + return push.emitOpError("unknown stack"); } // Push output simply forward the input element. push.replaceAllUsesWith(push.elem()); @@ -569,31 +505,13 @@ LogicalResult HandleStackPushV2Op( // Read the current buffer and size. auto stack_val = ReadLocalVariable(push.handle(), builder, push.getLoc()); auto index = ReadLocalVariable(it->getSecond(), builder, push.getLoc()); - auto stack_buffer_type = stack_val.getType().cast(); - auto slice_shape = llvm::to_vector<8>(stack_buffer_type.getShape()); - slice_shape[0] = 1; - // Caculate the updated buffer. - auto update_slice = builder.create( - push.getLoc(), - ArrayRef{RankedTensorType::get(slice_shape, - stack_buffer_type.getElementType())}, - ArrayRef{push.elem(), - GetR1Const(slice_shape, builder, push.getLoc())}, - ArrayRef{}); stack_val = - builder - .create( - push.getLoc(), ArrayRef{stack_val.getType()}, - ArrayRef{stack_val, update_slice, - GetIndicesForStackElement( - index, stack_val, builder, push.getLoc())}, - ArrayRef{}) - .output(); + cutil::SetElement(index, stack_val, push.elem(), builder, push.getLoc()); // Assign the new buffer and size. 
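The HandleStackV2Op/HandleStackPushV2Op rewrites above reduce a tf.StackV2 to a zero-initialized buffer variable plus a separate size variable, now created through the shared collection_ops_util helpers. A plain-C++ model of that representation (illustration only, no MLIR), with comments pointing at the ops each step roughly corresponds to:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Model of the decomposed stack: a fixed-capacity, zero-initialized buffer and
// a separate size value, mirroring the buffer/size local variables above.
struct DecomposedStack {
  std::vector<float> buffer;  // ~ the cutil::CreateInitBufferValue buffer
  int64_t size = 0;           // ~ the size local variable

  explicit DecomposedStack(int64_t max_size) : buffer(max_size, 0.0f) {}

  void Push(float element) {
    assert(size < static_cast<int64_t>(buffer.size()));
    buffer[size] = element;  // ~ cutil::SetElement (dynamic update slice)
    ++size;                  // ~ add 1 to the size variable
  }

  float Pop() {
    assert(size > 0);
    --size;                  // ~ subtract 1 from the size variable
    return buffer[size];     // ~ cutil::GetElement (slice)
  }
};
```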
WriteLocalVariable(push.handle(), stack_val, builder, push.getLoc()); index = builder.create( push.getLoc(), ArrayRef{index.getType()}, - ArrayRef{index, GetR1Const({1}, builder, push.getLoc())}, + ArrayRef{index, cutil::GetR1Const({1}, builder, push.getLoc())}, ArrayRef{}); WriteLocalVariable(it->getSecond(), index, builder, push.getLoc()); push.erase(); @@ -605,7 +523,7 @@ LogicalResult HandleStackPopV2Op( llvm::SmallDenseMap* data_var_to_size_var) { auto it = data_var_to_size_var->find(pop.handle()); if (it == data_var_to_size_var->end()) { - return pop.emitOpError("unknown stack."); + return pop.emitOpError("unknown stack"); } OpBuilder builder(pop); // Read the current buffer and size. @@ -613,31 +531,10 @@ LogicalResult HandleStackPopV2Op( auto size = ReadLocalVariable(it->getSecond(), builder, pop.getLoc()); auto new_size = builder.create( pop.getLoc(), ArrayRef{size.getType()}, - ArrayRef{size, GetR1Const({1}, builder, pop.getLoc())}, + ArrayRef{size, cutil::GetR1Const({1}, builder, pop.getLoc())}, ArrayRef{}); - auto stack_val_type = stack_val.getType().cast(); - auto elem_type = RankedTensorType::get(stack_val_type.getShape().drop_front(), - stack_val_type.getElementType()); - // Slice the buffer to get the element. - llvm::SmallVector slice_size; - slice_size.push_back(1); - for (int64_t dim : elem_type.getShape()) slice_size.push_back(dim); - auto size_const = GetR1Const(slice_size, builder, pop.getLoc()); - auto slice_type = - RankedTensorType::get(slice_size, stack_val_type.getElementType()); - auto slice = builder.create( - pop.getLoc(), ArrayRef{slice_type}, - ArrayRef{ - stack_val, - GetIndicesForStackElement(new_size, stack_val, builder, pop.getLoc()), - size_const}, - ArrayRef{}); - auto pop_val = builder.create( - pop.getLoc(), ArrayRef{elem_type}, - ArrayRef{slice, - GetR1Const(elem_type.getShape(), builder, pop.getLoc())}, - ArrayRef{}); - pop.replaceAllUsesWith(pop_val.output()); + auto pop_val = cutil::GetElement(new_size, stack_val, builder, pop.getLoc()); + pop.replaceAllUsesWith(pop_val); // Update the size. WriteLocalVariable(it->getSecond(), new_size, builder, pop.getLoc()); pop.erase(); @@ -688,8 +585,7 @@ LogicalResult DecomposeStackOpsInternal( } else if (auto pcall = llvm::dyn_cast(&op)) { if (!pcall.f().isa()) { return pcall.emitOpError( - "Stack decomposition does not support call with nested " - "references."); + "stack decomposition does not support call with nested references"); } if (failed(HandlePartitionedCallOp( pcall, module.lookupSymbol(pcall.f().getRootReference()), diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc new file mode 100644 index 00000000000..8b1ba7d1d30 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc @@ -0,0 +1,695 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/FormatVariadic.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/TypeUtilities.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace mlir { + +namespace { + +namespace cutil = TF::collection_ops_util; + +// A pass that rewrites tensor list operations to tensor operations on buffers +// and size values. +// +// This pass requires that the full shape of the tensor list can be inferred: 1) +// the maximum size needs to be a constant and 2) the element shape needs to be +// constant. +// +// A tensor list creation op "tf.EmptyTensorList"/"tf.TensorListReserve" will be +// turned in to a zero-initialized buffer, and the size is initialized to a 0 +// for "tf.EmptyTensorList" or the specified size for "tf.TensorListReserve". +// Each push will be turned into "tf.XlaDynamicUpdateSlice" with the incremented +// size, and each pop will be turned into a "tf.Slice" and a copy of the buffer +// with decremented size. Each SetItem will be turned into a +// "tf.XlaDynamicUpdateSlice" with unchanged size, and each GetItem will be +// turned into a "tf.Slice". +// +// The pass also works across control flow and functional calls. +struct TensorListOpsDecompositionPass + : public ModulePass { + void runOnModule() override; +}; + +// Updates func's type according to its current arguments and return values. +void UpdateFuncType(FuncOp func) { + llvm::SmallVector arg_types; + for (auto arg : func.getArguments()) arg_types.push_back(arg.getType()); + func.setType(FunctionType::get( + arg_types, + llvm::to_vector<8>(func.front().getTerminator()->getOperandTypes()), + func.getContext())); +} + +// Holds the size value of a tensor list and whether the size is statically +// known (fixed). +struct SizeInfo { + Value size; + bool fixed; +}; + +// Modifies a function's signature to rewrite tensor list arguments to buffers +// and sizes. 
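The new pass's header comment above describes the same buffer-plus-size lowering for tensor lists, with one extra piece of state: the SizeInfo.fixed bit that distinguishes fixed-size lists (tf.TensorListReserve) from growable ones (tf.EmptyTensorList). A plain-C++ model of those semantics (illustration only; the float element type and known maximum size are assumptions):

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Buffer-plus-size model of a decomposed tensor list, including the `fixed`
// bit tracked in SizeInfo.
struct DecomposedTensorList {
  std::vector<float> buffer;
  int64_t size;
  bool fixed;

  // tf.TensorListReserve: size is known up front and stays fixed.
  static DecomposedTensorList Reserve(int64_t num_elements) {
    return {std::vector<float>(num_elements, 0.0f), num_elements,
            /*fixed=*/true};
  }
  // tf.EmptyTensorList: starts empty and may grow up to max_num_elements.
  static DecomposedTensorList Empty(int64_t max_num_elements) {
    return {std::vector<float>(max_num_elements, 0.0f), 0, /*fixed=*/false};
  }

  void SetItem(int64_t index, float v) { buffer[index] = v; }   // ~ update slice
  float GetItem(int64_t index) const { return buffer[index]; }  // ~ tf.Slice

  void PushBack(float v) {
    if (fixed) throw std::logic_error("cannot push on a fixed-size tensor list");
    assert(size < static_cast<int64_t>(buffer.size()));
    buffer[size++] = v;
  }
};
```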
+void ModifyFunctionSignature( + FuncOp func, Type size_type, + llvm::SmallDenseMap* buffer_to_size, + llvm::function_ref(int64_t)> arg_to_buffer_type, + llvm::function_ref arg_buffer_size_is_fixed) { + auto new_input_types = llvm::to_vector<8>(func.getType().getInputs()); + int64_t original_arg_count = new_input_types.size(); + for (int64_t i = 0; i < original_arg_count; ++i) { + auto buffer_type = arg_to_buffer_type(i); + if (!buffer_type.hasValue()) continue; + func.getArgument(i).setType(*buffer_type); + new_input_types[i] = *buffer_type; + auto size_arg = func.front().addArgument(size_type); + new_input_types.push_back(size_arg.getType()); + if (buffer_to_size) { + (*buffer_to_size)[func.getArgument(i)] = {size_arg, + arg_buffer_size_is_fixed(i)}; + } + } + UpdateFuncType(func); +} + +// Holds information about a decomposed callee function for +// PartitionedCall/StatefulPartitionedCall. +struct PartitionedCallDecompositionInfo { + bool signature_change; + FuncOp decomposed_callee; + llvm::SmallDenseMap buffer_arg_to_size_arg; + // Each element is a tuple of (buffer_return_index, size_return_index, + // fixed_size). + llvm::SmallVector, 8> + buffer_ret_to_size_ret; +}; + +LogicalResult DecomposeTensorListOpsInternal( + Block*, ModuleOp, llvm::SmallDenseMap*, + llvm::SmallDenseMap*); + +// Adds the corresponding sizes of tensor list buffers in func's return values +// to the list of return values. Returns the mapping from the buffer indices to +// the added size indices, which is a list of tuples (buffer_return_index, +// size_return_index, fixed_size). +llvm::SmallVector, 8> +AddTensorListSizesToReturn( + FuncOp func, const llvm::SmallDenseMap& buffer_to_size) { + auto old_return = func.front().getTerminator(); + auto new_returns = llvm::to_vector<8>(old_return->getOperands()); + llvm::SmallVector, 8> + output_buffer_to_size; + for (auto retval : llvm::enumerate(old_return->getOperands())) { + auto it = buffer_to_size.find(retval.value()); + if (it == buffer_to_size.end()) continue; + output_buffer_to_size.emplace_back(retval.index(), new_returns.size(), + it->getSecond().fixed); + new_returns.push_back(it->getSecond().size); + } + OpBuilder(old_return).create(old_return->getLoc(), new_returns); + old_return->erase(); + UpdateFuncType(func); + return output_buffer_to_size; +} + +LogicalResult HandleWhileOp( + TF::WhileOp while_op, ModuleOp module, + llvm::SmallDenseMap* buffer_to_size, + llvm::SmallDenseMap* + decomposed_partitioned_call_callees) { + // Rewrite body. + auto body = module.lookupSymbol(while_op.body()); + llvm::SmallDenseMap body_map; + auto find_arg_tensor_list_type = [&](int64_t index) -> llvm::Optional { + auto it = buffer_to_size->find(while_op.getOperand(index)); + if (it == buffer_to_size->end()) return llvm::None; + return it->getFirst().getType(); + }; + auto arg_buffer_size_is_fixed = [&](int64_t index) { + return (*buffer_to_size)[while_op.getOperand(index)].fixed; + }; + OpBuilder builder(while_op); + ModifyFunctionSignature(body, cutil::GetSizeType(builder), &body_map, + find_arg_tensor_list_type, arg_buffer_size_is_fixed); + if (failed(DecomposeTensorListOpsInternal( + &body.front(), module, &body_map, + decomposed_partitioned_call_callees))) { + return failure(); + } + auto output_buffer_to_size = AddTensorListSizesToReturn(body, body_map); + + // Rewrite cond. 
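HandleWhileOp above (and HandleIfOp/HandlePartitionedCallOp after it) rely on the (buffer_return_index, size_return_index, fixed) tuples produced by AddTensorListSizesToReturn to remap results of the rebuilt op. A stripped-down model of that bookkeeping in plain C++ (the types and names are stand-ins):

```cpp
#include <cstdint>
#include <tuple>
#include <vector>

// Stand-in for a function return value: whether it is a tensor-list buffer
// and, if so, whether its size is fixed.
struct ReturnValue {
  bool is_buffer;
  bool fixed_size;
};

// Append one size value per returned buffer and record
// (buffer_return_index, size_return_index, fixed) so callers can remap results.
std::vector<std::tuple<int64_t, int64_t, bool>> AppendSizesToReturn(
    std::vector<ReturnValue>* returns) {
  std::vector<std::tuple<int64_t, int64_t, bool>> mapping;
  const int64_t original_count = returns->size();
  for (int64_t i = 0; i < original_count; ++i) {
    if (!(*returns)[i].is_buffer) continue;
    mapping.emplace_back(i, static_cast<int64_t>(returns->size()),
                         (*returns)[i].fixed_size);
    // The appended entry models the size value returned alongside the buffer.
    returns->push_back({/*is_buffer=*/false, /*fixed_size=*/false});
  }
  return mapping;
}
```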
+ auto cond = module.lookupSymbol(while_op.cond()); + llvm::SmallDenseMap cond_map; + ModifyFunctionSignature(cond, cutil::GetSizeType(builder), &cond_map, + find_arg_tensor_list_type, arg_buffer_size_is_fixed); + if (failed(DecomposeTensorListOpsInternal( + &cond.front(), module, &cond_map, + decomposed_partitioned_call_callees))) { + return failure(); + } + if (output_buffer_to_size.empty()) { + return success(); + } + // Create the new while op. + auto new_while_operands = llvm::to_vector<8>(while_op.getOperands()); + auto new_output_shapes = + llvm::to_vector<8>(while_op.output_shapes().getValue()); + for (int64_t i = 0; i < while_op.getNumResults(); ++i) { + auto it = buffer_to_size->find(while_op.getOperand(i)); + if (it == buffer_to_size->end()) continue; + new_while_operands.push_back(it->getSecond().size); + if (!new_output_shapes.empty()) { + // Size is a scalar shape. + tensorflow::TensorShapeProto shape_proto; + new_output_shapes.push_back(builder.getStringAttr( + tensorflow::mangling_util::MangleShape(shape_proto))); + } + } + auto new_while = + builder.create(while_op.getLoc(), body.getType().getInputs(), + new_while_operands, while_op.getAttrs()); + new_while.setAttr("output_shapes", builder.getArrayAttr(new_output_shapes)); + for (const auto& entry : output_buffer_to_size) { + (*buffer_to_size)[new_while.getResult(std::get<0>(entry))] = { + new_while.getResult(std::get<1>(entry)), std::get<2>(entry)}; + } + while_op.replaceAllUsesWith( + new_while.getResults().take_front(while_op.getNumResults())); + while_op.erase(); + return success(); +} + +LogicalResult HandleIfOp( + TF::IfOp if_op, ModuleOp module, + llvm::SmallDenseMap* buffer_to_size, + llvm::SmallDenseMap* + decomposed_partitioned_call_callees) { + // Rewrite the branches. + auto then_branch = module.lookupSymbol(if_op.then_branch()); + auto else_branch = module.lookupSymbol(if_op.else_branch()); + llvm::SmallDenseMap then_map; + llvm::SmallDenseMap else_map; + + auto find_arg_buffer_type = [&](int64_t index) -> llvm::Optional { + auto it = buffer_to_size->find(if_op.getOperand(index + 1)); + if (it == buffer_to_size->end()) return llvm::None; + return it->getFirst().getType(); + }; + auto arg_buffer_size_is_fixed = [&](int64_t index) { + return (*buffer_to_size)[if_op.getOperand(index + 1)].fixed; + }; + OpBuilder builder(if_op); + ModifyFunctionSignature(then_branch, cutil::GetSizeType(builder), &then_map, + find_arg_buffer_type, arg_buffer_size_is_fixed); + ModifyFunctionSignature(else_branch, cutil::GetSizeType(builder), &else_map, + find_arg_buffer_type, arg_buffer_size_is_fixed); + const bool arg_no_changed = then_map.empty(); + if (failed(DecomposeTensorListOpsInternal( + &then_branch.front(), module, &then_map, + decomposed_partitioned_call_callees)) || + failed(DecomposeTensorListOpsInternal( + &else_branch.front(), module, &else_map, + decomposed_partitioned_call_callees))) { + return failure(); + } + auto output_buffer_to_size = + AddTensorListSizesToReturn(then_branch, then_map); + AddTensorListSizesToReturn(else_branch, else_map); + if (output_buffer_to_size.empty() && arg_no_changed) return success(); + // Recreate the If op. 
+ auto new_if_operands = llvm::to_vector<8>(if_op.getOperands()); + auto new_output_shapes = llvm::to_vector<8>(if_op.output_shapes().getValue()); + for (int64_t i = 1; i < if_op.getNumOperands(); ++i) { + auto it = buffer_to_size->find(if_op.getOperand(i)); + if (it == buffer_to_size->end()) continue; + new_if_operands.push_back(it->getSecond().size); + if (!new_output_shapes.empty()) { + // Size is a scalar shape. + tensorflow::TensorShapeProto shape_proto; + new_output_shapes.push_back(builder.getStringAttr( + tensorflow::mangling_util::MangleShape(shape_proto))); + } + } + auto new_if = OpBuilder(if_op).create( + if_op.getLoc(), then_branch.getType().getResults(), new_if_operands, + if_op.getAttrs()); + new_if.setAttr("output_shapes", builder.getArrayAttr(new_output_shapes)); + for (const auto& entry : output_buffer_to_size) { + (*buffer_to_size)[new_if.getResult(std::get<0>(entry))] = { + new_if.getResult(std::get<1>(entry)), std::get<2>(entry)}; + } + if_op.replaceAllUsesWith( + new_if.getResults().take_front(if_op.getNumResults())); + if_op.erase(); + return success(); +} + +template +LogicalResult HandlePartitionedCallOp( + CallOp call, FuncOp callee, ModuleOp module, + llvm::SmallDenseMap* buffer_to_size, + llvm::SmallDenseMap* + decomposed_partitioned_call_callees) { + auto emplace_res = decomposed_partitioned_call_callees->try_emplace( + callee, PartitionedCallDecompositionInfo()); + auto& info = emplace_res.first->getSecond(); + // Recreates the call op with info. + auto recreate_caller = [&] { + auto new_operands = llvm::to_vector<8>(call.getOperands()); + for (int64_t i = 0; i < call.getNumOperands(); ++i) { + auto arg_it = info.buffer_arg_to_size_arg.find(i); + if (arg_it == info.buffer_arg_to_size_arg.end()) continue; + auto it = buffer_to_size->find(call.getOperand(i)); + if (it == buffer_to_size->end()) { + call.emitOpError("unknown tensor list."); + return failure(); + } + assert(arg_it->second == new_operands.size()); + new_operands.push_back(it->getSecond().size); + } + OpBuilder builder(call); + auto new_call = builder.create( + call.getLoc(), info.decomposed_callee.getType().getResults(), + new_operands, call.getAttrs()); + new_call.setAttr( + "f", builder.getSymbolRefAttr( + const_cast(info.decomposed_callee).getName())); + for (const auto& entry : info.buffer_ret_to_size_ret) { + (*buffer_to_size)[new_call.getResult(std::get<0>(entry))] = { + new_call.getResult(std::get<1>(entry)), std::get<2>(entry)}; + } + call.replaceAllUsesWith( + new_call.getResults().take_front(call.getNumResults())); + call.erase(); + return success(); + }; + if (!emplace_res.second) { + // This callee was handled before. + if (!info.signature_change) return success(); + return recreate_caller(); + } + // Rewrite the callee on a cloned function. 
+ llvm::SmallDenseMap callee_map; + auto callee_clone = callee.clone(); + auto find_arg_buffer_type = [&](int64_t index) -> llvm::Optional { + auto it = buffer_to_size->find(call.getOperand(index)); + if (it == buffer_to_size->end()) return llvm::None; + return it->getFirst().getType(); + }; + auto arg_buffer_size_is_fixed = [&](int64_t index) { + return (*buffer_to_size)[call.getOperand(index)].fixed; + }; + ModifyFunctionSignature(callee_clone, cutil::GetSizeType(OpBuilder(call)), + &callee_map, find_arg_buffer_type, + arg_buffer_size_is_fixed); + const bool args_no_changed = callee.empty(); + if (failed(DecomposeTensorListOpsInternal( + &callee_clone.front(), module, &callee_map, + decomposed_partitioned_call_callees))) { + return failure(); + } + info.buffer_ret_to_size_ret = + AddTensorListSizesToReturn(callee_clone, callee_map); + if (args_no_changed && info.buffer_ret_to_size_ret.empty()) { + // Signature is not modified. We do not need to keep two copies. + info.signature_change = false; + auto name = callee.getName(); + callee.erase(); + callee_clone.setName(name); + SymbolTable(module).insert(callee_clone); + } else { + info.signature_change = true; + info.decomposed_callee = callee_clone; + for (auto& entry : callee_map) { + auto buffer_arg = entry.getFirst().dyn_cast(); + if (!buffer_arg) continue; + info.buffer_arg_to_size_arg[buffer_arg.getArgNumber()] = + entry.getSecond().size.cast().getArgNumber(); + } + + // Add the clone with a new name. + auto name = llvm::join(std::vector{callee.getName().str(), + "tensorlist_decomposed"}, + "_"); + callee_clone.setName(name); + SymbolTable(module).insert(callee_clone); + callee = callee_clone; + } + if (info.signature_change) return recreate_caller(); + return success(); +} + +// Parses an R1 value to `shape` if it is a TF::ConstOp output. Otherwise, +// returns an error. 
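HandlePartitionedCallOp above memoizes per-callee work in decomposed_partitioned_call_callees, so each callee is decomposed once (on a clone renamed with a "_tensorlist_decomposed" suffix) and later call sites only re-point themselves at the recorded result. A stripped-down model of that try_emplace pattern in plain C++ (names and fields are illustrative):

```cpp
#include <map>
#include <string>

// What a call site needs to know about an already-decomposed callee.
struct DecompositionInfo {
  bool signature_change = false;
  std::string decomposed_name;
};

// Decompose `callee_name` at most once; subsequent calls return the cached
// record instead of redoing the rewrite.
const DecompositionInfo& DecomposeCalleeOnce(
    const std::string& callee_name,
    std::map<std::string, DecompositionInfo>* cache) {
  auto emplace_result = cache->try_emplace(callee_name, DecompositionInfo());
  DecompositionInfo& info = emplace_result.first->second;
  if (!emplace_result.second) return info;  // This callee was handled before.
  // First visit: rewrite a clone of the callee and record what changed.
  info.signature_change = true;
  info.decomposed_name = callee_name + "_tensorlist_decomposed";
  return info;
}
```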
+LogicalResult GetConstShapeValue(Value shape_value, + llvm::SmallVector* shape) { + auto shape_op = shape_value.getDefiningOp(); + if (!shape_op) return failure(); + auto shape_const_op = llvm::dyn_cast(shape_op); + if (!shape_const_op) return failure(); + for (auto v : shape_const_op.value().getValues()) { + shape->push_back(v.getSExtValue()); + } + return success(); +} + +LogicalResult HandleEmptyTensorListOp( + TF::EmptyTensorListOp list, + llvm::SmallDenseMap* buffer_to_size) { + Value buffer; + OpBuilder builder(list); + llvm::SmallVector element_shape; + if (failed(GetConstShapeValue(list.element_shape(), &element_shape))) { + return list.emitOpError("unknown tensor list element shape"); + } + if (failed(cutil::CreateInitBufferValue( + element_shape, list.max_num_elements(), list, list.element_dtype(), + builder, &buffer))) { + return failure(); + } + Value size = cutil::GetR1Const({0LL}, builder, list.getLoc()); + list.handle().replaceAllUsesWith(buffer); + (*buffer_to_size)[buffer] = {size, /*fixed=*/false}; + list.erase(); + return success(); +} + +LogicalResult HandleTensorListReserveOp( + TF::TensorListReserveOp list, + llvm::SmallDenseMap* buffer_to_size) { + Value buffer; + OpBuilder builder(list); + llvm::SmallVector element_shape; + if (failed(GetConstShapeValue(list.element_shape(), &element_shape))) { + return list.emitOpError("unknown tensor list element shape"); + } + if (failed(cutil::CreateInitBufferValue(element_shape, list.num_elements(), + list, list.element_dtype(), builder, + &buffer))) { + return failure(); + } + Value size = cutil::ReshapeScalarToSizeType(builder, list.num_elements(), + list.getLoc()); + (*buffer_to_size)[buffer] = {size, /*fixed=*/true}; + list.handle().replaceAllUsesWith(buffer); + list.erase(); + return success(); +} + +LogicalResult HandleTensorListFromTensorOp( + TF::TensorListFromTensorOp list, + llvm::SmallDenseMap* buffer_to_size) { + OpBuilder builder(list); + Value buffer = builder.create( + list.getLoc(), ArrayRef{list.tensor().getType()}, + ArrayRef{list.tensor()}, ArrayRef{}); + auto type = buffer.getType().cast(); + if (!type.hasStaticShape()) { + return list.emitOpError("TensorListFromTensorOp input has unknown shape."); + } + Value size = cutil::GetR1Const({type.getShape()[0]}, builder, list.getLoc()); + (*buffer_to_size)[buffer] = {size, /*fixed=*/true}; + list.output_handle().replaceAllUsesWith(buffer); + list.erase(); + return success(); +} + +LogicalResult HandleTensorListPushBackOp( + TF::TensorListPushBackOp push, + llvm::SmallDenseMap* buffer_to_size) { + auto buffer = push.input_handle(); + auto it = buffer_to_size->find(buffer); + if (it == buffer_to_size->end()) { + return push.emitOpError( + "found tf.TensorListPushBack on unknown TensorList."); + } + if (it->getSecond().fixed) { + return push.emitError("cannot push on a fixed-size tensor list"); + } + auto size = it->getSecond().size; + OpBuilder builder(push); + auto new_buffer = + cutil::SetElement(size, buffer, push.tensor(), builder, push.getLoc()); + auto new_size = builder.create( + push.getLoc(), ArrayRef{size.getType()}, + ArrayRef{size, cutil::GetR1Const({1LL}, builder, push.getLoc())}, + ArrayRef{}); + push.output_handle().replaceAllUsesWith(new_buffer); + (*buffer_to_size)[new_buffer] = {new_size, /*fixed=*/false}; + push.erase(); + return success(); +} + +LogicalResult HandleTensorListPopBackOp( + TF::TensorListPopBackOp pop, + llvm::SmallDenseMap* buffer_to_size) { + auto buffer = pop.input_handle(); + auto it = buffer_to_size->find(buffer); + if (it == 
buffer_to_size->end()) { + pop.emitOpError("found tf.TensorListPopBack on unknown TensorList."); + return failure(); + } + if (it->getSecond().fixed) { + return pop.emitError("cannot pop on a fixed-size tensor list"); + } + auto size = it->getSecond().size; + OpBuilder builder(pop); + auto new_buffer = builder.create( + pop.getLoc(), ArrayRef{buffer.getType()}, ArrayRef{buffer}, + ArrayRef{}); + auto new_size = builder.create( + pop.getLoc(), ArrayRef{size.getType()}, + ArrayRef{size, cutil::GetR1Const({1LL}, builder, pop.getLoc())}, + ArrayRef{}); + auto element = cutil::GetElement(new_size, new_buffer, builder, pop.getLoc()); + pop.output_handle().replaceAllUsesWith(new_buffer); + pop.tensor().replaceAllUsesWith(element); + pop.erase(); + (*buffer_to_size)[new_buffer] = {new_size, /*fixed=*/false}; + return success(); +} + +LogicalResult HandleTensorListGetItemOp( + TF::TensorListGetItemOp get_item, + const llvm::SmallDenseMap& buffer_to_size) { + auto buffer = get_item.input_handle(); + auto it = buffer_to_size.find(buffer); + if (it == buffer_to_size.end()) { + get_item.emitOpError("found tf.TensorListGetItemOp on unknown TensorList."); + return failure(); + } + OpBuilder builder(get_item); + auto index = cutil::ReshapeScalarToSizeType(builder, get_item.index(), + get_item.getLoc()); + auto element = + cutil::GetElement(index, buffer, OpBuilder(get_item), get_item.getLoc()); + get_item.item().replaceAllUsesWith(element); + get_item.erase(); + return success(); +} + +LogicalResult HandleTensorListSetItemOp( + TF::TensorListSetItemOp set_item, + llvm::SmallDenseMap* buffer_to_size) { + auto buffer = set_item.input_handle(); + auto it = buffer_to_size->find(buffer); + if (it == buffer_to_size->end()) { + set_item.emitOpError("found tf.TensorListSetItemOp on unknown TensorList."); + return failure(); + } + OpBuilder builder(set_item); + auto index = cutil::ReshapeScalarToSizeType(builder, set_item.index(), + set_item.getLoc()); + auto new_buffer = cutil::SetElement(index, buffer, set_item.item(), builder, + set_item.getLoc()); + set_item.output_handle().replaceAllUsesWith(new_buffer); + (*buffer_to_size)[new_buffer] = it->getSecond(); + set_item.erase(); + return success(); +} + +LogicalResult HandleTensorListLengthOp( + TF::TensorListLengthOp length, + const llvm::SmallDenseMap& buffer_to_size) { + auto it = buffer_to_size.find(length.input_handle()); + if (it == buffer_to_size.end()) { + length.emitOpError("found tf.TensorListLength on unknown TensorList."); + return failure(); + } + OpBuilder builder(length); + if (it->getSecond().fixed) { + auto dim = cutil::CreateScalarConst( + length.input_handle().getType().cast().getDimSize(0), + builder, length.getLoc()); + length.length().replaceAllUsesWith(dim); + } else { + auto current_size = it->getSecond().size; + // Reshapes the R1 length to a scalar. + auto reshape = builder.create( + length.getLoc(), + ArrayRef{RankedTensorType::get( + {}, getElementTypeOrSelf(current_size.getType()))}, + ArrayRef{current_size, + cutil::GetR1Const({}, builder, length.getLoc())}, + ArrayRef{}); + length.length().replaceAllUsesWith(reshape); + } + length.erase(); + return success(); +} + +LogicalResult DecomposeTensorListOpsInternal( + Block* block, ModuleOp module, + llvm::SmallDenseMap* buffer_to_size, + llvm::SmallDenseMap* + decomposed_partitioned_call_callees) { + for (auto& op : llvm::make_early_inc_range(block->getOperations())) { + // TODO(yuanzx): Add a pass to remove identities in device computation. 
+ if (llvm::isa(&op) || llvm::isa(&op)) { + op.replaceAllUsesWith(op.getOperands()); + op.erase(); + } else if (auto list = llvm::dyn_cast(&op)) { + if (failed(HandleEmptyTensorListOp(list, buffer_to_size))) { + return failure(); + } + } else if (auto list = llvm::dyn_cast(&op)) { + if (failed(HandleTensorListReserveOp(list, buffer_to_size))) { + return failure(); + } + } else if (auto list = llvm::dyn_cast(&op)) { + if (failed(HandleTensorListFromTensorOp(list, buffer_to_size))) { + return failure(); + } + } else if (auto push = llvm::dyn_cast(&op)) { + if (failed(HandleTensorListPushBackOp(push, buffer_to_size))) { + return failure(); + } + } else if (auto pop = llvm::dyn_cast(&op)) { + if (failed(HandleTensorListPopBackOp(pop, buffer_to_size))) { + return failure(); + } + } else if (auto get_item = llvm::dyn_cast(&op)) { + if (failed(HandleTensorListGetItemOp(get_item, *buffer_to_size))) { + return failure(); + } + } else if (auto set_item = llvm::dyn_cast(&op)) { + if (failed(HandleTensorListSetItemOp(set_item, buffer_to_size))) { + return failure(); + } + } else if (auto length = llvm::dyn_cast(&op)) { + if (failed(HandleTensorListLengthOp(length, *buffer_to_size))) { + return failure(); + } + } else if (auto stack = llvm::dyn_cast(&op)) { + stack.tensor().replaceAllUsesWith(stack.input_handle()); + stack.erase(); + } else if (auto addn = llvm::dyn_cast(&op)) { + auto it = buffer_to_size->find(addn.getOperand(0)); + if (it != buffer_to_size->end()) { + addn.sum().setType(addn.getOperand(0).getType()); + (*buffer_to_size)[addn.sum()] = it->getSecond(); + } + } else if (auto zeros = llvm::dyn_cast(&op)) { + if (buffer_to_size->count(zeros.x()) > 0) { + zeros.y().setType(zeros.x().getType()); + (*buffer_to_size)[zeros.y()] = (*buffer_to_size)[zeros.x()]; + } + } else if (auto while_op = llvm::dyn_cast(&op)) { + if (failed(HandleWhileOp(while_op, module, buffer_to_size, + decomposed_partitioned_call_callees))) { + return failure(); + } + } else if (auto if_op = llvm::dyn_cast(&op)) { + if (failed(HandleIfOp(if_op, module, buffer_to_size, + decomposed_partitioned_call_callees))) { + return failure(); + } + } else if (auto pcall = llvm::dyn_cast(&op)) { + if (!pcall.f().isa()) { + return pcall.emitOpError( + "TensorList decomposition does not support call with nested " + "references."); + } + if (failed(HandlePartitionedCallOp( + pcall, module.lookupSymbol(pcall.f().getRootReference()), + module, buffer_to_size, decomposed_partitioned_call_callees))) { + return failure(); + } + } else if (auto spcall = + llvm::dyn_cast(&op)) { + if (failed(HandlePartitionedCallOp( + spcall, module.lookupSymbol(spcall.f()), module, + buffer_to_size, decomposed_partitioned_call_callees))) { + return failure(); + } + } + } + return success(); +} + +LogicalResult DecomposeTensorListOps(Block* block, ModuleOp module) { + llvm::SmallDenseMap buffer_to_size; + llvm::SmallDenseMap + decomposed_partitioned_call_callees; + return DecomposeTensorListOpsInternal(block, module, &buffer_to_size, + &decomposed_partitioned_call_callees); +} + +void TensorListOpsDecompositionPass::runOnModule() { + auto module = getModule(); + auto main = module.lookupSymbol("main"); + if (!main) return; + if (failed(DecomposeTensorListOps(&main.front(), module))) { + signalPassFailure(); + } +} + +static PassRegistration pass( + "tf-tensor-list-ops-decomposition", + "Decompose tensor list operations into operations on buffers and sizes. 
" + "Needs static shapes."); + +} // namespace + +namespace TF { +std::unique_ptr> CreateTensorListOpsDecompositionPass() { + return std::make_unique(); +} +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc index 50059154ed7..7fe65b888d9 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc @@ -244,8 +244,7 @@ void TPUDynamicLayoutPass::runOnFunction() { if (!compile || !compile->getResult(1).hasOneUse()) return; auto compile_launch = llvm::dyn_cast(compile); if (!compile_launch || !compile_launch.WrapsSingleOp() || - compile_launch.GetBody().front().getName().getStringRef() != - "tf._TPUCompileMlir") + !llvm::isa(compile_launch.GetBody().front())) return; executes_and_compiles.emplace_back(execute, compile_launch); }); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc index 3b2815ec901..c1419873dba 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc @@ -41,7 +41,6 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // TF:llvm-project #include "mlir/Transforms/RegionUtils.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -136,6 +135,10 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo( // by inter-island dependencies. Operation* first_read = nullptr; Operation& execute = execute_launch.GetBody().front(); + auto parallel_execute = llvm::dyn_cast( + execute_launch.getParentOp()); + Operation* execute_parent = + parallel_execute ? parallel_execute.getOperation() : execute_launch; // Find inputs that are variable reads. for (auto operand : llvm::enumerate(execute.getOpOperands())) { infos.new_operand_values.push_back(operand.value().get()); @@ -144,9 +147,9 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo( operand.value().get().getDefiningOp()); if (!read_op) continue; if (check_same_region && - read_op.getParentRegion() != execute_launch.getParentRegion()) { + read_op.getParentRegion() != execute_parent->getParentRegion()) continue; - } + auto resource = read_op.resource(); if (check_device) { @@ -193,9 +196,9 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo( // work fine for the reads/assigns created by resource lifting, since they are // placed close to the TPUExecute. 
Operation* last_may_modify_resource_access_before_execute = nullptr; - for (Operation& op : llvm::reverse( - llvm::make_range(std::next(first_read->getIterator()), - execute_launch.getOperation()->getIterator()))) { + for (Operation& op : + llvm::reverse(llvm::make_range(std::next(first_read->getIterator()), + execute_parent->getIterator()))) { if (llvm::dyn_cast(&op)) continue; if (!OpAccessesResource(&op)) continue; last_may_modify_resource_access_before_execute = &op; @@ -232,10 +235,16 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo( llvm::SmallPtrSet all_assigns; llvm::SmallVector output_fused(execute_launch.getNumResults(), false); - for (int i = 0; i < execute_launch.getNumResults(); ++i) { + + auto execute_outputs = + parallel_execute + ? parallel_execute.GetRegionOutputs( + execute_launch.getParentRegion()->getRegionNumber()) + : execute_launch.getResults(); + for (auto execute_output : llvm::enumerate(execute_outputs)) { // TODO(lyandy): Handle updates to resource writes by remapping to parent // launch result and checking if launch result is an AssignVariableOp. - auto result = execute_launch.getResult(i); + auto result = execute_output.value(); if (!result.hasOneUse()) continue; auto assign_op = llvm::dyn_cast(*result.user_begin()); if (!assign_op) continue; @@ -250,21 +259,20 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo( infos.per_resource_info.shrink_and_clear(); return infos; } - info.execute_output_index = i; + info.execute_output_index = execute_output.index(); info.assign = assign_op; if (!last_assign || last_assign->isBeforeInBlock(assign_op)) { last_assign = assign_op; } all_assigns.insert(assign_op); - output_fused[i] = true; + output_fused[execute_output.index()] = true; } // Check if there are other resource accesses after execute. Operation* first_unknown_resource_access_after_execute = nullptr; if (last_assign) { - for (auto& op : llvm::make_range( - std::next(execute_launch.getOperation()->getIterator()), - last_assign->getIterator())) { + for (auto& op : llvm::make_range(std::next(execute_parent->getIterator()), + last_assign->getIterator())) { if (all_assigns.count(&op) > 0) continue; if (!OpAccessesResource(&op)) continue; first_unknown_resource_access_after_execute = &op; @@ -301,6 +309,115 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo( return infos; } +// Appends result types of tf_device.parallel_execute from `start` index region +// (inclusive) to `end` index region (exclusive) to `output_types` and returns +// the number of types added. +int AppendTypes(llvm::SmallVectorImpl* output_types, + tf_device::ParallelExecuteOp parallel_execute, int start, + int end) { + const int size_before = output_types->size(); + for (int index = start; index < end; ++index) { + Block& block = parallel_execute.GetRegionBlockWithIndex(index); + auto terminator_operand_types = block.getTerminator()->getOperandTypes(); + output_types->append(terminator_operand_types.begin(), + terminator_operand_types.end()); + } + return output_types->size() - size_before; +} + +// Replaces TPUExecute with TPUExecuteAndUpdateVariables in a +// tf_device.parallel_execute op. 
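In the tpu_merge_variables_with_execute.cc hunk above, the dyn_cast lost its template argument in this copy; from the later GetRegionOutputs call it is presumably tf_device::ParallelExecuteOp. The intent, as a small sketch (the helper name is made up): when the launch wrapping TPUExecute sits inside a tf_device.parallel_execute, same-region and before/after checks anchor on the parallel_execute op rather than on the launch itself.

```cpp
#include "llvm/Support/Casting.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"

// Pick the operation used for same-region and ordering checks: the enclosing
// tf_device.parallel_execute if present, otherwise the launch itself.
mlir::Operation* GetExecuteAnchor(mlir::tf_device::LaunchOp execute_launch) {
  if (auto parallel_execute =
          llvm::dyn_cast<mlir::tf_device::ParallelExecuteOp>(
              execute_launch.getParentOp()))
    return parallel_execute.getOperation();
  return execute_launch.getOperation();
}
```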
+void ReplaceParallelExecute(tf_device::ParallelExecuteOp parallel_execute, + tf_device::LaunchOp execute_launch, + tf_device::LaunchOp merged_execute_launch, + const VariableAccessesForTPUExecute& infos, + OpBuilder* builder) { + Operation* parallel_execute_op = parallel_execute.getOperation(); + + // Collect result types of tf_device.parallel_execute and update region + // result types with the new merged execute result types. + llvm::SmallVector output_types; + const int parallel_execute_num_results = parallel_execute_op->getNumResults(); + output_types.reserve(parallel_execute_num_results); + Region* execute_region = merged_execute_launch.getParentRegion(); + const int region_index = execute_region->getRegionNumber(); + const int num_results_before_region = + AppendTypes(&output_types, parallel_execute, 0, region_index); + // Append updated results from merged execute. + output_types.append(merged_execute_launch.getResultTypes().begin(), + merged_execute_launch.getResultTypes().end()); + const int num_regions = parallel_execute_op->getNumRegions(); + const int num_results_after_region = AppendTypes( + &output_types, parallel_execute, region_index + 1, num_regions); + + builder->setInsertionPoint(parallel_execute); + auto new_parallel_execute = builder->create( + parallel_execute.getLoc(), num_regions, output_types); + + // Replace the uses of the original parallel_execute before region containing + // merged execute. + Operation* new_parallel_execute_op = new_parallel_execute.getOperation(); + for (int i = 0; i < num_results_before_region; ++i) + parallel_execute_op->getResult(i).replaceAllUsesWith( + new_parallel_execute_op->getResult(i)); + + // Replace the uses of the original parallel_execute after region containing + // merged execute. The number of results changed in the region containing the + // merged execute, but they should match, so results are replaced starting + // from the ends of both parallel_execute. + const int new_parallel_execute_num_results = + new_parallel_execute_op->getNumResults(); + for (int i = 0; i < num_results_after_region; ++i) + parallel_execute_op->getResult(parallel_execute_num_results - i - 1) + .replaceAllUsesWith(new_parallel_execute_op->getResult( + new_parallel_execute_num_results - i - 1)); + + // Replace the uses of the original parallel_execute for the region containing + // the merged execute. + auto old_region_results = parallel_execute.GetRegionOutputs(region_index); + for (int i = 0; i < infos.old_to_new_output_mapping.size(); ++i) { + if (infos.old_to_new_output_mapping[i] < 0) continue; + old_region_results[i].replaceAllUsesWith(new_parallel_execute_op->getResult( + infos.old_to_new_output_mapping[i] + num_results_before_region)); + } + + // Replace original terminator with new terminator for returning merged + // execute results. + Operation* old_terminator = execute_region->front().getTerminator(); + builder->setInsertionPointToEnd(&execute_region->front()); + builder->create(old_terminator->getLoc(), + merged_execute_launch.getResults()); + old_terminator->erase(); + + // Remove the original TPUExecute op. + execute_launch.erase(); + + // Move all regions from old parallel_execute to new parallel_execute. + for (auto region : llvm::zip(new_parallel_execute_op->getRegions(), + parallel_execute_op->getRegions())) + std::get<0>(region).takeBody(std::get<1>(region)); + + // Remove the original parallel_execute. + parallel_execute_op->dropAllUses(); + parallel_execute.erase(); +} + +// Replaces TPUExecute with TPUExecuteAndUpdateVariables. 
+void ReplaceExecute(tf_device::LaunchOp execute_launch, + tf_device::LaunchOp merged_execute_launch, + const VariableAccessesForTPUExecute& infos) { + // Replace the uses. + for (int i = 0; i < infos.old_to_new_output_mapping.size(); ++i) { + if (infos.old_to_new_output_mapping[i] < 0) continue; + execute_launch.getResult(i).replaceAllUsesWith( + merged_execute_launch.getResult(infos.old_to_new_output_mapping[i])); + } + + // Remove the original TPUExecute op. + execute_launch.getOperation()->dropAllUses(); + execute_launch.erase(); +} + // Merges the variable accesses into one TPUExecute op. void MergeForOneTPUExecute(tf_device::LaunchOp execute_launch, bool check_device, bool check_same_region, @@ -352,19 +469,19 @@ void MergeForOneTPUExecute(tf_device::LaunchOp execute_launch, merged_execute.getOperation()->moveBefore( merged_execute_launch.GetBody().getTerminator()); - // Replace the uses. - for (int i = 0; i < infos.old_to_new_output_mapping.size(); ++i) { - if (infos.old_to_new_output_mapping[i] < 0) continue; - execute_launch.getResult(i).replaceAllUsesWith( - merged_execute_launch.getResult(infos.old_to_new_output_mapping[i])); - } + if (auto parallel_execute = llvm::dyn_cast( + execute_launch.getParentOp())) + ReplaceParallelExecute(parallel_execute, execute_launch, + merged_execute_launch, infos, builder); + else + ReplaceExecute(execute_launch, merged_execute_launch, infos); + // Remove the assign ops. for (const auto& entry : infos.per_resource_info) { const auto& info = entry.getSecond(); if (info.assign) info.assign->erase(); } - // Remove the original TPUExecute op. - execute_launch.erase(); + // Remove the read ops if they have no more uses. for (const auto& entry : infos.per_resource_info) { const auto& info = entry.getSecond(); @@ -372,25 +489,43 @@ void MergeForOneTPUExecute(tf_device::LaunchOp execute_launch, } } +// Checks if an op's parent is a tf_device.parallel_execute and the region the +// op is in is perfectly wrapped. +bool ParentParallelExecuteWrapsSingleOp(Operation* op) { + auto parallel_execute = + llvm::dyn_cast(op->getParentOp()); + if (!parallel_execute) return true; + + return parallel_execute.RegionWrapsSingleOp( + op->getParentRegion()->getRegionNumber()); +} + void TPUMergeVariablesWithExecutePass::runOnFunction() { // Find all the executes first, since we will mutate the nodes around each // execute. llvm::SmallVector execute_launches; getFunction().walk([&](tf_device::LaunchOp op) { - if (op.WrapsSingleOp() && llvm::isa(op.GetBody().front())) + if (op.WrapsSingleOp() && + llvm::isa(op.GetBody().front()) && + ParentParallelExecuteWrapsSingleOp(op)) execute_launches.push_back(op); }); for (auto execute_launch : execute_launches) { OpBuilder builder(&getContext()); const bool parent_is_replicate = - llvm::isa(execute_launch.getParentOp()); + llvm::isa(execute_launch.getParentOp()) || + (llvm::isa( + execute_launch.getParentOp()) && + llvm::isa( + execute_launch.getParentOp()->getParentOp())); + // If this is inside a tf_device::ReplicateOp, the variables are guaranteed // to be on the same device as the TPUExecute op. Skip device checking in // that case, but we need to check that we are only merging reads/assigns // that are also in this replicated region.
- MergeForOneTPUExecute(execute_launch, !parent_is_replicate, - parent_is_replicate, &builder); + MergeForOneTPUExecute(execute_launch, /*check_device=*/!parent_is_replicate, + /*check_same_region=*/parent_is_replicate, &builder); } } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc index 7b0291a2f9b..50b6555076d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc @@ -43,6 +43,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h" #include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -66,10 +67,21 @@ constexpr char kNumReplicasAttr[] = "num_replicas"; constexpr char kNumCoresPerReplicaAttr[] = "num_cores_per_replica"; constexpr char kStepMarkerLocationAttr[] = "step_marker_location"; constexpr char kPaddingMapAttr[] = "padding_map"; +constexpr char kTopologyAttr[] = "topology"; +constexpr char kDeviceAssignmentAttr[] = "device_assignment"; constexpr char kDeviceAttr[] = "device"; constexpr char kDevicesAttr[] = "devices"; constexpr char kVersionsAttr[] = "tf.versions"; +constexpr char kBadStringArrayElementMsg[] = + "bad '{0}' attribute at index {1}, not a string"; +constexpr char kBadIntArrayElementMsg[] = + "bad '{0}' attribute at index {1}, not an int"; +constexpr char kBadArrayElementMsg[] = + "bad '{0}' attribute at index {1} with value '{2}': failed to parse to {3}"; +constexpr char kBadArrayAttrLengthMsg[] = + "bad '{0}' attribute, expected array attribute of size {1}, got size {2}"; + // Rewrites `tf_device.launch_func` operations assigned to TPU into actual TPU // jit-compile runtime ops. // @@ -150,17 +162,37 @@ LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func, return success(); } -// Populates a TPUCompileMetadataProto from attributes of a -// `tf_device::LaunchFuncOp`. If any necessary attributes are missing from the -// op, a failure will be returned. -// TODO(lyandy): Support session handle and guaranteed consts. -LogicalResult SetMetadataProtoFromLaunchFuncOp( - tf_device::LaunchFuncOp op, int num_replicas, int num_cores_per_replica, - llvm::Optional&& xla_device_assignment, - tensorflow::tpu::TPUCompileMetadataProto* metadata) { - metadata->set_num_replicas(num_replicas); - metadata->set_num_cores_per_replica(num_cores_per_replica); +// Extracts device coordinates from a device assignment attribute on an op. 
+LogicalResult GetDeviceCoordinates( + tf_device::LaunchFuncOp op, + llvm::SmallVectorImpl* device_assignment) { + auto device_assignment_attr = + op.getAttrOfType(kDeviceAssignmentAttr); + if (!device_assignment_attr) + return op.emitOpError(CreateMissingAttributeMsg(kDeviceAssignmentAttr)); + device_assignment->reserve(device_assignment_attr.size()); + + for (auto device_coordinate_and_idx : + llvm::enumerate(device_assignment_attr)) { + auto device_coordinate = + device_coordinate_and_idx.value().dyn_cast(); + if (!device_coordinate) + return op.emitOpError(llvm::formatv(kBadIntArrayElementMsg, + kDeviceAssignmentAttr, + device_coordinate_and_idx.index())); + + device_assignment->push_back(device_coordinate.getInt()); + } + + return success(); +} + +// Populates a TPUCompileMetadataProto with StepMarkerLocation from a +// `tf_device::LaunchFuncOp`. +LogicalResult SetMetadataProtoStepMarkerLocation( + tf_device::LaunchFuncOp op, + tensorflow::tpu::TPUCompileMetadataProto* metadata) { auto step_marker_location = op.getAttrOfType(kStepMarkerLocationAttr); if (!step_marker_location) @@ -179,6 +211,14 @@ LogicalResult SetMetadataProtoFromLaunchFuncOp( metadata->set_step_marker_location(location); + return success(); +} + +// Populates a TPUCompileMetadataProto with PaddingMap from a +// `tf_device::LaunchFuncOp`. +LogicalResult SetMetadataProtoPaddingMap( + tf_device::LaunchFuncOp op, + tensorflow::tpu::TPUCompileMetadataProto* metadata) { auto padding_map = op.getAttrOfType(kPaddingMapAttr); if (!padding_map) return op.emitOpError(CreateMissingAttributeMsg(kPaddingMapAttr)); @@ -187,25 +227,56 @@ LogicalResult SetMetadataProtoFromLaunchFuncOp( auto& padding_attr = padding_and_idx.value(); auto padding_attr_str = padding_attr.dyn_cast(); if (!padding_attr_str) - return op.emitOpError( - llvm::formatv("bad '{0}' attribute at index {1}, not a string", - kPaddingMapAttr, padding_and_idx.index())); + return op.emitOpError(llvm::formatv( + kBadStringArrayElementMsg, kPaddingMapAttr, padding_and_idx.index())); tensorflow::tpu::PaddingMap* padding = metadata->mutable_padding_maps()->Add(); if (!padding->ParseFromString(std::string(padding_attr_str.getValue()))) return op.emitOpError(llvm::formatv( - "bad '{0}' attribute at index {1} with value '{2}'", kPaddingMapAttr, - padding_and_idx.index(), padding_attr_str.getValue())); + kBadArrayElementMsg, kPaddingMapAttr, padding_and_idx.index(), + padding_attr_str.getValue(), "tpu::PaddingMap")); } - if (xla_device_assignment.hasValue()) - *metadata->mutable_device_assignment() = - std::move(xla_device_assignment.getValue()); + return success(); +} + +// Parses a xla::OpSharding from a string attribute. +LogicalResult SetOpSharding(Operation* op, Attribute attr, llvm::StringRef name, + int index, xla::OpSharding* sharding) { + auto sharding_str = attr.dyn_cast(); + if (!sharding_str) + return op->emitOpError( + llvm::formatv(kBadStringArrayElementMsg, name, index)); + + if (!sharding->ParseFromString(sharding_str.getValue().str())) + return op->emitOpError(llvm::formatv(kBadArrayElementMsg, name, index, + sharding_str.getValue(), + "xla::OpSharding")); + + return success(); +} + +// Populates a TPUCompileMetadataProto with argument types and sharding from a +// `tf_device::LaunchFuncOp`. 
+LogicalResult SetMetadataProtoArgs( + tf_device::LaunchFuncOp op, + tensorflow::tpu::TPUCompileMetadataProto* metadata) { + auto input_shardings = + op.getAttrOfType(tensorflow::kInputShardingAttr); + if (!input_shardings) + return op.emitOpError( + CreateMissingAttributeMsg(tensorflow::kInputShardingAttr)); + + if (input_shardings.size() != op.getNumOperands()) + return op.emitOpError( + llvm::formatv(kBadArrayAttrLengthMsg, tensorflow::kInputShardingAttr, + op.getNumOperands(), input_shardings.size())); // Set args metadata in proto. for (auto operand_type_and_idx : llvm::enumerate(op.getOperandTypes())) { Type operand_type = operand_type_and_idx.value(); + int index = operand_type_and_idx.index(); tensorflow::tpu::TPUCompileMetadataProto::Arg* arg = metadata->add_args(); tensorflow::DataType dtype; tensorflow::Status status = @@ -213,7 +284,7 @@ LogicalResult SetMetadataProtoFromLaunchFuncOp( if (!status.ok()) return op.emitOpError( llvm::formatv("failed to determine operand type at index {0}: {1}", - operand_type_and_idx.index(), status.error_message())); + index, status.error_message())); arg->set_dtype(dtype); // TODO(lyandy): Support other arg kinds. @@ -232,29 +303,67 @@ LogicalResult SetMetadataProtoFromLaunchFuncOp( arg->mutable_shape()->set_unknown_rank(true); } - // TODO(lyandy): Determine proper sharding of args once topology and devices - // are propagated to the pass. - xla::OpSharding sharding; - sharding.set_type(xla::OpSharding::MAXIMAL); - sharding.add_tile_assignment_dimensions(1); - sharding.add_tile_assignment_devices(0); - *arg->mutable_sharding() = std::move(sharding); - } - - // Set retvals metadata in proto. - // TODO(lyandy): Determine proper sharding of retvals once topology and - // devices is propagated to the pass. - for (int i = 0; i < op.getNumResults(); ++i) { - xla::OpSharding sharding; - sharding.set_type(xla::OpSharding::MAXIMAL); - sharding.add_tile_assignment_dimensions(1); - sharding.add_tile_assignment_devices(0); - *metadata->add_retvals()->mutable_sharding() = std::move(sharding); + if (failed(SetOpSharding(op, input_shardings.getValue()[index], + tensorflow::kInputShardingAttr, index, + arg->mutable_sharding()))) + return failure(); } return success(); } +// Populates a TPUCompileMetadataProto with result sharding from a +// `tf_device::LaunchFuncOp`. +LogicalResult SetMetadataProtoRetvals( + tf_device::LaunchFuncOp op, + tensorflow::tpu::TPUCompileMetadataProto* metadata) { + auto output_shardings = + op.getAttrOfType(tensorflow::kOutputShardingAttr); + if (!output_shardings) + return op.emitOpError( + CreateMissingAttributeMsg(tensorflow::kOutputShardingAttr)); + + if (output_shardings.size() != op.getNumResults()) + return op.emitOpError( + llvm::formatv(kBadArrayAttrLengthMsg, tensorflow::kOutputShardingAttr, + op.getNumResults(), output_shardings.size())); + + // Set retvals metadata in proto. + for (auto output_sharding_and_idx : llvm::enumerate(output_shardings)) + if (failed(SetOpSharding(op, output_sharding_and_idx.value(), + tensorflow::kOutputShardingAttr, + output_sharding_and_idx.index(), + metadata->add_retvals()->mutable_sharding()))) + return failure(); + + return success(); +} + +// Populates a TPUCompileMetadataProto from attributes of a +// `tf_device::LaunchFuncOp`. If any necessary attributes are missing from the +// op, a failure will be returned. +// TODO(lyandy): Support session handle and guaranteed consts. 
+LogicalResult SetMetadataProtoFromLaunchFuncOp( + tf_device::LaunchFuncOp op, int num_replicas, int num_cores_per_replica, + llvm::Optional&& xla_device_assignment, + tensorflow::tpu::TPUCompileMetadataProto* metadata) { + metadata->set_num_replicas(num_replicas); + metadata->set_num_cores_per_replica(num_cores_per_replica); + + if (failed(SetMetadataProtoStepMarkerLocation(op, metadata))) + return failure(); + + if (failed(SetMetadataProtoPaddingMap(op, metadata))) return failure(); + + if (xla_device_assignment.hasValue()) + *metadata->mutable_device_assignment() = + std::move(xla_device_assignment.getValue()); + + if (failed(SetMetadataProtoArgs(op, metadata))) return failure(); + + return SetMetadataProtoRetvals(op, metadata); +} + // Wraps single op in `tf_device.launch` for explicit device assignment. tf_device::LaunchOp WrapOpInLaunch(OpBuilder* builder, Location loc, Operation* op, llvm::StringRef device) { @@ -282,9 +391,6 @@ Operation* BuildCompileOp( int num_cores_per_replica, llvm::StringRef compilation_device, llvm::Optional&& xla_device_assignment, OpBuilder* builder) { - // TODO(b/139377366): Use tf_tpu.compile build method when it is defined. - OperationState compile_op_state(launch_func.getLoc(), "tf._TPUCompileMlir"); - // Set metadata from attributes. tensorflow::tpu::TPUCompileMetadataProto metadata; if (failed(SetMetadataProtoFromLaunchFuncOp( @@ -298,9 +404,6 @@ Operation* BuildCompileOp( else metadata.SerializeToString(&txt_metadata); - compile_op_state.addAttribute("metadata", - builder->getStringAttr(txt_metadata)); - // Build a shape op for each input to launch_func. // TODO(b/139377366): When shape inference is ready, we can use compile time // shape inference to get inputs that have static shapes and only use shape @@ -320,63 +423,77 @@ Operation* BuildCompileOp( operand_and_idx.value()); compile_op_operands.emplace_back(shape_op.getResult()); } - compile_op_state.addOperands(compile_op_operands); - compile_op_state.addAttribute( - "NumDynamicShapes", - builder->getI64IntegerAttr(compile_op_operands.size())); - FlatSymbolRefAttr func_attr = - launch_func.getAttrOfType("func"); - if (!func_attr) { - launch_func.emitOpError("does not have `func` attribute"); - return nullptr; - } + FlatSymbolRefAttr func_attr = launch_func.funcAttr(); FuncOp func = launch_func.getParentOfType().lookupSymbol( func_attr.getValue()); std::string txt_module; if (failed(EncapsulateFuncAndSerialize(func, &txt_module))) return nullptr; - compile_op_state.addAttribute("mlir_module", - builder->getStringAttr(txt_module)); - // Result #0 is a string indicating whether compilation is successful or not. - compile_op_state.addTypes( - RankedTensorType::get({}, builder->getType())); + auto result_type = + RankedTensorType::get({}, builder->getType()); - // Result #1 is key to look up executable binary in compilation cache. - compile_op_state.addTypes( - RankedTensorType::get({}, builder->getType())); + auto compile_op = builder->create( + launch_func.getLoc(), /*compilation_status=*/result_type, /*program=*/ + llvm::SmallVector(num_cores_per_replica, result_type), + compile_op_operands, txt_module, txt_metadata); - Operation* compile_op = builder->createOperation(compile_op_state); - - return WrapOpInLaunch(builder, compile_op->getLoc(), compile_op, + return WrapOpInLaunch(builder, compile_op.getLoc(), compile_op, compilation_device); } -// Creates a `tf.TPUExecute` op that executes TPU program generated by -// `compile_op`. 
-Operation* BuildExecuteOp(Operation* compile_op, - tf_device::LaunchFuncOp launch_func, - OpBuilder* builder) { - // TPUExecute inherits all launch_func inputs, and takes an additional input - // for compilation cache key. - llvm::SmallVector tensor_inputs(launch_func.getOperands()); - tensor_inputs.push_back(compile_op->getResult(1)); +// Assigns explicit devices to replicate op. An aliased device is created per +// core, and all replica devices per core are grouped together. +void AssignDevicesToReplicate( + tf_device::ReplicateOp replicate, + llvm::ArrayRef> execution_devices, + OpBuilder* builder) { + if (!replicate) return; + const int num_replicas = execution_devices.size(); + const int num_cores_per_replica = execution_devices.front().size(); + + llvm::SmallVector device_attrs; + for (int core = 0; core < num_cores_per_replica; ++core) { + llvm::SmallVector devices_by_core; + devices_by_core.reserve(num_replicas); + for (int replica = 0; replica < num_replicas; ++replica) + devices_by_core.push_back(execution_devices[replica][core]); + + device_attrs.push_back( + builder->getNamedAttr(tensorflow::GetDeviceAliasForLogicalCore(core), + builder->getStrArrayAttr(devices_by_core))); + } + + replicate.setAttr(kDevicesAttr, builder->getDictionaryAttr(device_attrs)); +} + +// Creates a `tf.TPUExecute` op that executes TPU program. +Operation* BuildExecuteOp( + const int core_id, llvm::ArrayRef output_sharding_config, + llvm::ArrayRef inputs, tf_device::LaunchFuncOp launch_func, + OpBuilder* builder) { // TODO(b/139377366): Need to snapshot all resource variable inputs in // follow-up CLs. + auto output_types = tensorflow::GetOutputTypesForLogicalDeviceComputation( + core_id, output_sharding_config, launch_func); + // TPUExecute has same output types as launch_func. - return builder->create( - launch_func.getLoc(), launch_func.getResultTypes(), tensor_inputs, - llvm::ArrayRef{}); + return builder->create(launch_func.getLoc(), output_types, + inputs, + llvm::ArrayRef{}); } // Creates a tf_device.parallel_execute op that wraps TPUExecute op to // represent execution of TPU program in multiple logical cores. -Operation* BuildParallelExecuteOp(int num_logical_cores, Operation* compile_op, - tf_device::LaunchFuncOp launch_func, - OpBuilder* builder) { +tf_device::ParallelExecuteOp BuildParallelExecuteOp( + llvm::ArrayRef> execution_devices, + llvm::ArrayRef output_sharding_config, + Operation* compile_op, tf_device::LaunchFuncOp launch_func, + OpBuilder* builder) { + const int num_cores_per_replica = execution_devices.front().size(); // parallel_execute op returns concatenated list of return values of // all its regions. 
// @@ -385,18 +502,28 @@ Operation* BuildParallelExecuteOp(int num_logical_cores, Operation* compile_op, const auto& launch_result_types = launch_func.getResultTypes(); llvm::SmallVector concatenated_output_types; concatenated_output_types.reserve(launch_result_types.size() * - num_logical_cores); + num_cores_per_replica); - for (int core_id = 0; core_id < num_logical_cores; ++core_id) - for (Type t : launch_result_types) - concatenated_output_types.emplace_back(t); + for (int core = 0; core < num_cores_per_replica; ++core) { + auto output_types = tensorflow::GetOutputTypesForLogicalDeviceComputation( + core, output_sharding_config, launch_func); + for (Type t : output_types) concatenated_output_types.emplace_back(t); + } auto parallel_execute_op = builder->create( - launch_func.getLoc(), num_logical_cores, concatenated_output_types); + launch_func.getLoc(), num_cores_per_replica, concatenated_output_types); + // Extract inputs for each region of the parallel_execute op. The i-th + // element in the list represents the input lists to TPU computation for + // i-th logical core. + auto input_list = tensorflow::ExtractInputsForLogicalDevices( + num_cores_per_replica, launch_func); + + const bool replicated = execution_devices.size() != 1; // For each logical core, create a region with TPUExecute op. - for (int core_id = 0; core_id < num_logical_cores; ++core_id) { - auto& region = parallel_execute_op.GetRegionBlockWithIndex(core_id); + assert(input_list.size() == num_cores_per_replica); + for (int core = 0; core < num_cores_per_replica; ++core) { + auto& region = parallel_execute_op.GetRegionBlockWithIndex(core); builder->setInsertionPointToEnd(®ion); // Create Execute op. @@ -404,14 +531,21 @@ Operation* BuildParallelExecuteOp(int num_logical_cores, Operation* compile_op, // TODO(b/148913294): Identify inputs/return values specific to each // logical core TPU execution by parsing xla_sharding op in // launch_func. - auto execute = BuildExecuteOp(compile_op, launch_func, builder); + auto execute_inputs = input_list[core]; + execute_inputs.emplace_back(compile_op->getResult(core + 1)); - // Create a launch op for each region of parallel_execute. - // - // TODO(b/149102679): Add device attribute to launch op once device - // topology for multiple logical cores can be correctly parsed. - auto region_launch_op = WrapOpInLaunch( - builder, region.getParent()->getLoc(), execute, /*device=*/""); + auto execute = BuildExecuteOp(core, output_sharding_config, execute_inputs, + launch_func, builder); + + // If computation is replicated, use aliased device. Otherwise there is only + // one execution device per core and the device is assigned to the execute + // op. + std::string device = replicated + ? tensorflow::GetDeviceAliasForLogicalCore(core) + : execution_devices.front()[core]; + + auto region_launch_op = + WrapOpInLaunch(builder, region.getParent()->getLoc(), execute, device); builder->create(region.getParent()->getLoc(), region_launch_op.getResults()); @@ -420,43 +554,14 @@ Operation* BuildParallelExecuteOp(int num_logical_cores, Operation* compile_op, return parallel_execute_op; } -// As tf_device.parallel_execute wraps # logical cores number of TPUExecute -// ops, the number of return values of parallel_execute op exceeds that of -// launch_func op. As so, each return value of parallel_execute op must be -// mapped with corresponding return value usages of launch_func. 
-// -// TODO(b/148913294): Once argument and return value sharding of tpu computation -// is determined, correctly map outputs of parallel_execute op. -void RemapOutputsOfParallelExecute(tf_device::LaunchFuncOp launch_func, - Operation* op) { - for (auto outputs : llvm::zip(launch_func.getResults(), op->getResults())) - std::get<0>(outputs).replaceAllUsesWith(std::get<1>(outputs)); -} - tf_device::LaunchOp AssignDevicesToReplicatedExecute( llvm::ArrayRef> execution_devices, - tf_device::ReplicateOp replicate, Operation* execute_op, - OpBuilder* builder) { - // If computation is replicated, execution devices are assigned to the - // replicate. Otherwise there is only one execution device and the device is - // assigned to the execute op. - std::string device; - if (replicate) { - // Model parallelism is not support for now. Therefore, assign all ops - // in replicate op with virtual device alias specifying that ops will be - // executed on the zeroth core. - llvm::SmallVector replicate_execution_devices; - replicate_execution_devices.reserve(execution_devices.size()); - for (const auto& replica_execution_devices : execution_devices) - replicate_execution_devices.push_back(replica_execution_devices.front()); - - device = tensorflow::GetDeviceAliasForLogicalCore(0); - auto device_attr = builder->getNamedAttr( - device, builder->getStrArrayAttr(replicate_execution_devices)); - replicate.setAttr(kDevicesAttr, builder->getDictionaryAttr(device_attr)); - } else { - device = execution_devices.front().front(); - } + Operation* execute_op, OpBuilder* builder) { + const bool replicated = execution_devices.size() != 1; + // If computation is replicated, use aliased device. Otherwise there is only + // one execution device and the device is assigned to the execute op. + std::string device = replicated ? tensorflow::GetDeviceAliasForLogicalCore(0) + : execution_devices.front().front(); return WrapOpInLaunch(builder, execute_op->getLoc(), execute_op, device); } @@ -466,10 +571,8 @@ tf_device::LaunchOp AssignDevicesToReplicatedExecute( void BuildTPUCompileSucceededAssertOp(Operation* compile_op, llvm::StringRef compilation_device, OpBuilder* builder) { - OperationState assert_op_state(compile_op->getLoc(), - "tf.TPUCompileSucceededAssert"); - assert_op_state.addOperands(compile_op->getResult(0)); - Operation* assert_op = builder->createOperation(assert_op_state); + auto assert_op = builder->create( + compile_op->getLoc(), compile_op->getResult(0)); WrapOpInLaunch(builder, compile_op->getLoc(), assert_op, compilation_device); } @@ -551,11 +654,19 @@ LogicalResult Rewrite( int num_cores_per_replica = num_cores_per_replica_attr.getInt(); + auto topology_attr = launch_func.getAttrOfType(kTopologyAttr); + if (!topology_attr) + return launch_func.emitOpError(CreateMissingAttributeMsg(kTopologyAttr)); + + llvm::SmallVector device_assignment; + if (failed(GetDeviceCoordinates(launch_func, &device_assignment))) + return failure(); + // Determine compilation and execution devices. 
auto status_or_tpu_device_assignment = tensorflow::GetTPUCompilationAndExecutionDevices( - devices, num_replicas, num_cores_per_replica, /*topology_attr=*/"", - /*device_assignment_attr=*/{}); + devices, num_replicas, num_cores_per_replica, + topology_attr.getValue(), device_assignment); if (!status_or_tpu_device_assignment.ok()) return launch_func.emitError() << "error in fetching TPU compilation/execution devices: " @@ -581,21 +692,35 @@ LogicalResult Rewrite( BuildTPUCompileSucceededAssertOp( compile_op, tpu_device_assignment.compilation_device, builder); + AssignDevicesToReplicate(replicate, tpu_device_assignment.execution_devices, + builder); + + llvm::SmallVector output_shardings; + auto result = tensorflow::ParseAndValidateOutputSharding(launch_func, + &output_shardings); + if (failed(result)) return failure(); + if (num_cores_per_replica > 1) { // For model parallelism, tf_device.parallel_execute is used to express // concurrent device execution across multiple logical devices. - Operation* execute_op = BuildParallelExecuteOp( - num_cores_per_replica, compile_op, launch_func, builder); + tf_device::ParallelExecuteOp execute_op = BuildParallelExecuteOp( + tpu_device_assignment.execution_devices, output_shardings, compile_op, + launch_func, builder); - RemapOutputsOfParallelExecute(launch_func, execute_op); - - // TODO(hongjunchoi): Correctly parse TPU topology and assign logical device - // attributes to launch_op's within parallel_execute op. + // As tf_device.parallel_execute wraps # logical cores number of TPUExecute + // ops, the number of return values of parallel_execute op exceeds that of + // launch_func op. As such, each return value of parallel_execute op must be + // mapped with corresponding return value usages of launch_func. + tensorflow::RemapOutputsFromLogicalDevices(output_shardings, launch_func, + execute_op); } else { - Operation* execute_op = BuildExecuteOp(compile_op, launch_func, builder); + llvm::SmallVector execute_inputs(launch_func.getOperands()); + execute_inputs.emplace_back(compile_op->getResult(1)); + + Operation* execute_op = BuildExecuteOp( + /*core_id=*/0, output_shardings, execute_inputs, launch_func, builder); tf_device::LaunchOp launch_op = AssignDevicesToReplicatedExecute( - tpu_device_assignment.execution_devices, replicate, execute_op, - builder); + tpu_device_assignment.execution_devices, execute_op, builder); launch_func.replaceAllUsesWith(launch_op); } @@ -605,13 +730,14 @@ LogicalResult Rewrite( } void TPURewritePass::runOnModule() { - llvm::SmallVector devices; + mlir::TF::RuntimeDevices devices; if (failed(tensorflow::GetDevicesFromOp(getModule(), &devices))) return signalPassFailure(); OpBuilder builder(&getContext()); auto result = getModule().walk([&](tf_device::LaunchFuncOp op) { - if (failed(Rewrite(op, devices, &builder))) return WalkResult::interrupt(); + if (failed(Rewrite(op, devices.device_names(), &builder))) + return WalkResult::interrupt(); return WalkResult::advance(); }); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc index 244df85f482..c9838ff9651 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc @@ -26,6 +26,7 @@ limitations under the License.
#include "mlir/Support/LLVM.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h" #include "tensorflow/compiler/xla/client/sharding_builder.h" #include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -34,10 +35,6 @@ namespace mlir { namespace TFTPU { namespace { -constexpr char kXlaShardingAttr[] = "_XlaSharding"; -constexpr char kInputShardingAttr[] = "input_sharding_configuration"; -constexpr char kOutputShardingAttr[] = "output_sharding_configuration"; - struct TPUShardingIdentificationPass : public ModulePass { void runOnModule() override; @@ -68,13 +65,6 @@ void GetAdjacentToXlaShardingOp( } } -llvm::Optional ParseShardingAttribute(Operation* operation) { - const auto& sharding_attr = - operation->getAttrOfType(kXlaShardingAttr); - if (!sharding_attr) return llvm::Optional(); - return sharding_attr.getValue(); -} - // Parse XlaSharding op connected to input args. If Input to // tf_device.LaunchFunc op is of resource type, then XlaSharding op // will be connected to following ReadVariable op. @@ -97,7 +87,7 @@ llvm::Optional ParseInputSharding(const FuncOp func, } if (!parsed_sharding_op) return llvm::Optional(); - return ParseShardingAttribute(parsed_sharding_op->getOperation()); + return tensorflow::ParseShardingAttribute(parsed_sharding_op->getOperation()); } // If operand of return value of tf_device.LaunchFunc op is directly from @@ -105,9 +95,9 @@ llvm::Optional ParseInputSharding(const FuncOp func, llvm::Optional ParseReturnValueSharding(FuncOp func, const int output_index, const OpOperand& operand) { - if (auto sharding_op = - llvm::dyn_cast(operand.get().getDefiningOp())) { - return ParseShardingAttribute(sharding_op.getOperation()); + if (auto sharding_op = llvm::dyn_cast_or_null( + operand.get().getDefiningOp())) { + return tensorflow::ParseShardingAttribute(sharding_op.getOperation()); } return llvm::Optional(); @@ -153,8 +143,8 @@ void IdentifyXlaShardingForTPUComputation(tf_device::LaunchFuncOp launch_func) { if (!input_arg_sharding.hasValue()) continue; sharding_for_args[arg_index] = input_arg_sharding->str(); } - SetShardingConfigurationAsAttribute(launch_func, kInputShardingAttr, - sharding_for_args); + SetShardingConfigurationAsAttribute( + launch_func, tensorflow::kInputShardingAttr, sharding_for_args); // By default return values from logical core 0 is used if no sharding // configuration is defined. 
@@ -176,8 +166,8 @@ void IdentifyXlaShardingForTPUComputation(tf_device::LaunchFuncOp launch_func) { sharding_for_return_values[return_value_index] = return_val_sharding->str(); } - SetShardingConfigurationAsAttribute(launch_func, kOutputShardingAttr, - sharding_for_return_values); + SetShardingConfigurationAsAttribute( + launch_func, tensorflow::kOutputShardingAttr, sharding_for_return_values); } void TPUShardingIdentificationPass::runOnModule() { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc index 26d1f75b382..6e698c3ca5c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc @@ -442,8 +442,7 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate, if (!compile) return; auto compile_launch = llvm::dyn_cast(compile); if (!compile_launch || !compile_launch.WrapsSingleOp() || - compile_launch.GetBody().front().getName().getStringRef() != - "tf._TPUCompileMlir") + !llvm::isa(compile_launch.GetBody().front())) return; auto module = while_op.getParentOfType(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc b/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc index 912a6aa722f..27939cba63c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc @@ -25,8 +25,6 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project -#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project -#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/OpImplementation.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 3d5355ba92a..366403e0654 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -2093,7 +2093,7 @@ Status GraphDefImporter::GetControlRetsFromFunctionGraph( // Stateful helper class to import a TensorFlow model expressed in SavedModel // into an MLIR Module. -class SavedModelImporter : public ImporterBase { +class SavedModelObjectGraphImporter : public ImporterBase { public: // Main entry point: converts all functions in the given meta graph to an MLIR // Module. 
@@ -2102,7 +2102,7 @@ class SavedModelImporter : public ImporterBase { absl::Span exported_names, bool add_default_attributes); private: - explicit SavedModelImporter( + explicit SavedModelObjectGraphImporter( const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info, const GraphImportConfig& specs, mlir::ModuleOp module, std::unordered_map* tf_name_to_mlir_name, @@ -2799,7 +2799,7 @@ Status CreateSavedModelIR( return Status::OK(); } -StatusOr SavedModelImporter::Convert( +StatusOr SavedModelObjectGraphImporter::Convert( SavedModelV2Bundle* saved_model, mlir::MLIRContext* context, absl::Span exported_names, bool add_default_attributes) { GraphDebugInfo dummy_debug_info; @@ -2828,8 +2828,9 @@ StatusOr SavedModelImporter::Convert( ConvertGraphDefToGraph(options, preprocessed_graphdef, &graph)); NameUniquifier function_name_uniquifier(graph.flib_def()); - SavedModelImporter importer(graph.flib_def(), debug_info, specs, module.get(), - &tf_name_to_mlir_name, &function_name_uniquifier); + SavedModelObjectGraphImporter importer(graph.flib_def(), debug_info, specs, + module.get(), &tf_name_to_mlir_name, + &function_name_uniquifier); auto fn_names = graph.flib_def().ListFunctionNames(); for (const auto& fn_name : fn_names) { @@ -2870,20 +2871,20 @@ StatusOr SavedModelImporter::Convert( // A helper class to import a TensorFlow model expressed in SavedModel V1 into // an MLIR Module in SavedModel dialect. -class SavedModelV1Importer { +class SavedModelSignatureDefImporter { public: // Main entry point: converts all functions (specified by SignatureDefs) in // the given meta graph to an MLIR Module. static StatusOr Convert(const SavedModelBundle& bundle, mlir::MLIRContext* context) { - SavedModelV1Importer importer(bundle, context); + SavedModelSignatureDefImporter importer(bundle, context); return importer.ConvertSignatures(); } private: - SavedModelV1Importer(const SavedModelBundle& bundle, - mlir::MLIRContext* context) + SavedModelSignatureDefImporter(const SavedModelBundle& bundle, + mlir::MLIRContext* context) : bundle_(bundle), module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))) {} @@ -2919,7 +2920,8 @@ class SavedModelV1Importer { mlir::OwningModuleRef module_; }; -StatusOr SavedModelV1Importer::ConvertSignatures() { +StatusOr +SavedModelSignatureDefImporter::ConvertSignatures() { const auto& signatures = bundle_.GetSignatures(); const auto& graphdef = bundle_.meta_graph_def.graph_def(); PopulateTfVersions(module_.get(), graphdef.versions()); @@ -2958,7 +2960,7 @@ StatusOr SavedModelV1Importer::ConvertSignatures() { return std::move(module_); } -Status SavedModelV1Importer::ConvertSignature( +Status SavedModelSignatureDefImporter::ConvertSignature( const GraphDef& graphdef, const std::string& sig_def_key, const std::map& inputs_sorted, const std::map& outputs_sorted, @@ -3022,7 +3024,7 @@ Status SavedModelV1Importer::ConvertSignature( return Status::OK(); } -Status SavedModelV1Importer::LiftVariables() { +Status SavedModelSignatureDefImporter::LiftVariables() { llvm::SmallVector ops; bool contains_ref_variable = false; @@ -3047,7 +3049,7 @@ Status SavedModelV1Importer::LiftVariables() { return Status::OK(); } -void SavedModelV1Importer::LiftVariable(mlir::TF::VarHandleOp op) { +void SavedModelSignatureDefImporter::LiftVariable(mlir::TF::VarHandleOp op) { mlir::OpBuilder builder(&module_->getBodyRegion()); auto func_op = op.getParentOfType(); @@ -3077,7 +3079,7 @@ void SavedModelV1Importer::LiftVariable(mlir::TF::VarHandleOp op) { op.getOperation()->erase(); } -Status 
SavedModelV1Importer::ReadVariablesFromSession( +Status SavedModelSignatureDefImporter::ReadVariablesFromSession( const llvm::SmallVectorImpl& ops) { mlir::OpBuilder builder(&module_->getBodyRegion()); @@ -3140,7 +3142,7 @@ Status SavedModelV1Importer::ReadVariablesFromSession( return Status::OK(); } -GraphImportConfig::InputArrays SavedModelV1Importer::ParseInputArrays( +GraphImportConfig::InputArrays SavedModelSignatureDefImporter::ParseInputArrays( const std::map& inputs) { GraphImportConfig::InputArrays results; for (const auto& iter : inputs) { @@ -3162,7 +3164,7 @@ GraphImportConfig::InputArrays SavedModelV1Importer::ParseInputArrays( return results; } -std::vector SavedModelV1Importer::ParseOutputArrays( +std::vector SavedModelSignatureDefImporter::ParseOutputArrays( const std::map& outputs) { std::vector results; for (const auto& iter : outputs) { @@ -3217,13 +3219,13 @@ StatusOr ConvertGraphToMlir( StatusOr ConvertSavedModelToMlir( SavedModelV2Bundle* saved_model, mlir::MLIRContext* context, absl::Span exported_names, bool add_default_attributes) { - return SavedModelImporter::Convert(saved_model, context, exported_names, - add_default_attributes); + return SavedModelObjectGraphImporter::Convert( + saved_model, context, exported_names, add_default_attributes); } StatusOr ConvertSavedModelV1ToMlir( const SavedModelBundle& saved_model, mlir::MLIRContext* context) { - return SavedModelV1Importer::Convert(saved_model, context); + return SavedModelSignatureDefImporter::Convert(saved_model, context); } std::string MlirModuleToString(mlir::ModuleOp module, bool show_debug_info) { diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc index 5e958960d07..d5fcf86cc93 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc @@ -116,7 +116,7 @@ mlir::OwningModuleRef GraphdefToMlirTranslateFunction( return module_or.ConsumeValueOrDie(); } -mlir::OwningModuleRef SavedModelToMlirImport( +mlir::OwningModuleRef SavedModelObjectGraphToMlirImport( absl::string_view saved_model_dir, const std::unordered_set& tags, absl::Span exported_names, mlir::MLIRContext* context) { @@ -137,7 +137,7 @@ mlir::OwningModuleRef SavedModelToMlirImport( return module_or.ConsumeValueOrDie(); } -mlir::OwningModuleRef SavedModelV1ToMlirImport( +mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport( absl::string_view saved_model_dir, const std::unordered_set& tags, mlir::MLIRContext* context) { tensorflow::SavedModelBundle bundle; diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h index 0380e1165a7..76bada96845 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h @@ -52,7 +52,7 @@ mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction( // Converts a TensorFlow SavedModel stored in the directory with the given // `saved_model_dir` into a MLIR module. Creates MLIR entities into the // given MLIR `context`. 
-mlir::OwningModuleRef SavedModelToMlirImport( +mlir::OwningModuleRef SavedModelObjectGraphToMlirImport( absl::string_view saved_model_dir, const std::unordered_set& tags, absl::Span exported_names, mlir::MLIRContext* context); @@ -60,7 +60,7 @@ mlir::OwningModuleRef SavedModelToMlirImport( // Converts a TensorFlow V1 SavedModel stored in the directory with the given // `saved_model_dir` into a MLIR module. Creates MLIR entities into the // given MLIR `context`. -mlir::OwningModuleRef SavedModelV1ToMlirImport( +mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport( absl::string_view saved_model_dir, const std::unordered_set& tags, mlir::MLIRContext* context); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index 35a84481851..10aad0a03ff 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -67,6 +67,7 @@ Status ParseMlirModule(llvm::StringRef mlir_module_string, // Converts arg_shapes to xla::Shape's and store into xla_input_shapes. Status GetXlaInputShapes( mlir::ModuleOp module, llvm::ArrayRef arg_shapes, + bool use_tuple_args, const xla::CustomShapeRepresentationFn shape_representation_fn, std::vector* xla_input_shapes) { xla_input_shapes->clear(); @@ -88,8 +89,12 @@ Status GetXlaInputShapes( TF_ASSIGN_OR_RETURN(xla_shape, shape_representation_fn(arg_shapes[i], dtype)); } - xla_input_shapes->push_back( - xla::ShapeUtil::MakeTupleShape(individual_arg_shapes)); + if (use_tuple_args) { + xla_input_shapes->push_back( + xla::ShapeUtil::MakeTupleShape(individual_arg_shapes)); + } else { + *xla_input_shapes = individual_arg_shapes; + } return Status::OK(); } @@ -210,6 +215,7 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op, bool use_tuple_args, bool return_tuple) { mlir::PassManager tf2xla(module_op.getContext()); tf2xla.addNestedPass(mlir::createCanonicalizerPass()); + tf2xla.addPass(mlir::TF::CreateTensorListOpsDecompositionPass()); tf2xla.addPass(mlir::TF::CreateStackOpsDecompositionPass()); tf2xla.addPass(mlir::TFDevice::CreateDecomposeResourceOpsPass()); tf2xla.addPass(mlir::TF::CreatePromoteResourcesToArgsPass()); @@ -222,13 +228,17 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op, // and canonicalization opportunities that are necessary for the second // LegalizeTFPass(allow_partial_conversion=false) invocation. tf2xla.addNestedPass(mlir::xla_hlo::createLegalizeTFPass(true)); - tf2xla.addPass(mlir::tf_executor::CreateTFExecutorGraphPruningPass()); tf2xla.addNestedPass(mlir::createCanonicalizerPass()); tf2xla.addNestedPass( mlir::xla_hlo::createLegalizeTFPass(false)); - if (VLOG_IS_ON(1)) - tf2xla.enableIRPrinting(std::make_unique()); + if (VLOG_IS_ON(1)) { + // Print the whole module after each pass which requires disabling + // multi-threading as well. + tf2xla.disableMultithreading(); + tf2xla.enableIRPrinting(std::make_unique( + /*print_module_scope=*/true)); + } // Make sure we catch any error reported by MLIR and forward it to the TF // error reporting system. 
Report a generic error if pass manager failed @@ -252,6 +262,7 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op, Status CompileSerializedMlirToXlaHlo( llvm::StringRef mlir_module_string, llvm::ArrayRef arg_shapes, + bool use_tuple_args, const XlaCompiler::ShapeRepresentationFn shape_representation_fn, XlaCompiler::CompilationResult* compilation_result) { mlir::MLIRContext mlir_context; @@ -273,7 +284,7 @@ Status CompileSerializedMlirToXlaHlo( // Convert MLIR module to XLA HLO proto contained in XlaComputation. compilation_result->computation = std::make_shared(); TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation( - module_op, compilation_result->computation.get(), /*use_tuple_args=*/true, + module_op, compilation_result->computation.get(), use_tuple_args, /*return_tuple=*/true)); // Construct mapping from XlaComputation's arg to input edges of execute @@ -286,7 +297,7 @@ Status CompileSerializedMlirToXlaHlo( }; // Compute all input shapes. - TF_RETURN_IF_ERROR(GetXlaInputShapes(module_op, arg_shapes, + TF_RETURN_IF_ERROR(GetXlaInputShapes(module_op, arg_shapes, use_tuple_args, shape_representation_fn_no_fast_memory, &compilation_result->xla_input_shapes)); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h index ed25aaf929e..41fa8b90e4f 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h @@ -50,6 +50,7 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op, // metadata and stores them in CompilationResult. Status CompileSerializedMlirToXlaHlo( llvm::StringRef mlir_module_string, llvm::ArrayRef arg_shapes, + bool use_tuple_args, const XlaCompiler::ShapeRepresentationFn shape_representation_fn, XlaCompiler::CompilationResult* compilation_result); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc index 8e0f9cb2497..b258dd68ae1 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc @@ -41,30 +41,31 @@ TEST(CompileSerializedMlirToXlaHloTest, InvalidSerializedMlirModule) { std::vector arg_shapes; XlaCompiler::CompilationResult compilation_result; - Status s = CompileSerializedMlirToXlaHlo(invalid_mlir_module, arg_shapes, - TestShapeRepresentation, - &compilation_result); + Status s = CompileSerializedMlirToXlaHlo( + invalid_mlir_module, arg_shapes, + /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); EXPECT_EQ(s.code(), tensorflow::errors::Code::INVALID_ARGUMENT); EXPECT_EQ(s.ToString(), "Invalid argument: could not parse MLIR module: error: " "custom op 'totally' is unknown\n"); } -TEST(CompileSerializedMlirToXlaHloTest, Success) { - string mlir_module = R"( - module attributes {tf.versions = {producer = 179 : i32}} { - func @main(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "tf.AddV2"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", name = "add"} : (tensor, tensor) -> tensor - return %0 : tensor - } +constexpr llvm::StringRef kBinaryAddModule = R"( + module attributes {tf.versions = {producer = 179 : i32}} { + func @main(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "tf.AddV2"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", name = "add"} : (tensor, tensor) -> tensor + return %0 : tensor } - )"; + } +)"; +TEST(CompileSerializedMlirToXlaHloTest, TupleArgs) { std::vector 
arg_shapes(2, TensorShape()); XlaCompiler::CompilationResult compilation_result; Status s = CompileSerializedMlirToXlaHlo( - mlir_module, arg_shapes, TestShapeRepresentation, &compilation_result); + kBinaryAddModule, arg_shapes, + /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); ASSERT_TRUE(s.ok()); const xla::HloModuleConfig module_config( @@ -86,7 +87,7 @@ ENTRY %main.6 (arg_tuple.1: (f32[], f32[])) -> (f32[]) { EXPECT_EQ(expected_hlo_module_string, status_or_hlo_module.ValueOrDie()->ToString()); - // Expect an iota like input mapping. + // Expect an in order input mapping. EXPECT_EQ(compilation_result.input_mapping, std::vector({0, 1})); // Expect a single tuple-shape, containing two F32 scalars. @@ -116,6 +117,62 @@ ENTRY %main.6 (arg_tuple.1: (f32[], f32[])) -> (f32[]) { EXPECT_TRUE(compilation_result.resource_updates.empty()); } +TEST(CompileSerializedMlirToXlaHloTest, IndividualArgs) { + std::vector arg_shapes(2, TensorShape()); + XlaCompiler::CompilationResult compilation_result; + + Status s = CompileSerializedMlirToXlaHlo( + kBinaryAddModule, arg_shapes, + /*use_tuple_args=*/false, TestShapeRepresentation, &compilation_result); + ASSERT_TRUE(s.ok()); + + const xla::HloModuleConfig module_config( + compilation_result.computation->GetProgramShape().ValueOrDie()); + auto status_or_hlo_module = xla::HloModule::CreateFromProto( + compilation_result.computation->proto(), module_config); + ASSERT_TRUE(status_or_hlo_module.ok()); + string expected_hlo_module_string = R"(HloModule main.5 + +ENTRY %main.5 (Arg_0.1: f32[], Arg_1.2: f32[]) -> (f32[]) { + %Arg_0.1 = f32[] parameter(0) + %Arg_1.2 = f32[] parameter(1) + %add.3 = f32[] add(f32[] %Arg_0.1, f32[] %Arg_1.2) + ROOT %tuple.4 = (f32[]) tuple(f32[] %add.3) +} + +)"; + EXPECT_EQ(expected_hlo_module_string, + status_or_hlo_module.ValueOrDie()->ToString()); + + // Expect an in order input mapping. + EXPECT_EQ(compilation_result.input_mapping, std::vector({0, 1})); + + // Expect two inputs, each containing a F32 scalar. + EXPECT_EQ(compilation_result.xla_input_shapes.size(), 2); + xla::Shape expected_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {}); + EXPECT_EQ(compilation_result.xla_input_shapes[0], expected_input_shape); + EXPECT_EQ(compilation_result.xla_input_shapes[1], expected_input_shape); + + // Expect output shape is a tuple shape containing a single F32 Scalar type. + const xla::Shape output_shape = + xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {}); + const xla::Shape tuple_output_shape = + xla::ShapeUtil::MakeTupleShape({output_shape}); + EXPECT_EQ(compilation_result.xla_output_shape, tuple_output_shape); + + // Expect exactly 1 OutputDescription. + EXPECT_EQ(compilation_result.outputs.size(), 1); + const XlaCompiler::OutputDescription& output_desc = + compilation_result.outputs.front(); + EXPECT_EQ(output_desc.type, DataType::DT_FLOAT); + EXPECT_EQ(output_desc.shape, TensorShape()); + EXPECT_FALSE(output_desc.is_constant); + EXPECT_FALSE(output_desc.is_tensor_list); + + // Expect no resource updates from computation. + EXPECT_TRUE(compilation_result.resource_updates.empty()); +} + // Tests that foldable ops are constant-folded to enable legalization of ops // that require compile time constant operand. 
TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) { @@ -136,7 +193,8 @@ TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) { XlaCompiler::CompilationResult compilation_result; Status s = CompileSerializedMlirToXlaHlo( - mlir_module, arg_shapes, TestShapeRepresentation, &compilation_result); + mlir_module, arg_shapes, + /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); ASSERT_TRUE(s.ok()); const xla::HloModuleConfig module_config( @@ -174,7 +232,8 @@ TEST(CompileSerializedMlirToXlaHloTest, ShapeInference) { XlaCompiler::CompilationResult compilation_result; Status s = CompileSerializedMlirToXlaHlo( - mlir_module, arg_shapes, TestShapeRepresentation, &compilation_result); + mlir_module, arg_shapes, + /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); TF_ASSERT_OK(s); const xla::HloModuleConfig module_config( diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc index 7b0cbe6d5b5..84a8969a486 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc @@ -57,6 +57,18 @@ Status ConvertDataType(DataType dtype, Builder builder, Type* type) { case DT_INT64: *type = builder.getIntegerType(64); return Status::OK(); + case DT_UINT8: + *type = builder.getIntegerType(8, /*isSigned=*/false); + return Status::OK(); + case DT_UINT16: + *type = builder.getIntegerType(16, /*isSigned=*/false); + return Status::OK(); + case DT_UINT32: + *type = builder.getIntegerType(32, /*isSigned=*/false); + return Status::OK(); + case DT_UINT64: + *type = builder.getIntegerType(64, /*isSigned=*/false); + return Status::OK(); case DT_BFLOAT16: *type = builder.getBF16Type(); return Status::OK(); @@ -99,16 +111,16 @@ Status ConvertScalarTypeToDataType(Type type, DataType* dtype) { *dtype = DT_BOOL; return Status::OK(); case 8: - *dtype = DT_INT8; + *dtype = itype.isUnsigned() ? DT_UINT8 : DT_INT8; return Status::OK(); case 16: - *dtype = DT_INT16; + *dtype = itype.isUnsigned() ? DT_UINT16 : DT_INT16; return Status::OK(); case 32: - *dtype = DT_INT32; + *dtype = itype.isUnsigned() ? DT_UINT32 : DT_INT32; return Status::OK(); case 64: - *dtype = DT_INT64; + *dtype = itype.isUnsigned() ? DT_UINT64 : DT_INT64; return Status::OK(); default: return errors::Unimplemented( diff --git a/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc index e983f3e9c0c..9561d0a2f93 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc @@ -20,7 +20,9 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/Error.h" #include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/Regex.h" #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/Operation.h" // TF:llvm-project @@ -33,50 +35,124 @@ namespace tensorflow { constexpr char kDevicesAttr[] = "tf.devices"; -void AddDevicesToOp(mlir::Operation* op, const DeviceSet* device_set) { - if (!device_set) return; +namespace { - // Collect devices as strings in TensorFlow device name form. 
- llvm::SmallVector devices; - devices.reserve(device_set->devices().size()); - for (Device* device : device_set->devices()) - devices.push_back( - DeviceNameUtils::ParsedNameToString(device->parsed_name())); +// Parse GPU compute capability from physical device description. If compute +// capability is not found in device description, return an empty dictionary +// attribute. +mlir::DictionaryAttr ParseGpuDeviceMetadata(const Device& device, + mlir::Builder* builder) { + // Parse GPU device compute capability from physical device description. + static auto* r = new llvm::Regex("compute capability: ([0-9]+)\\.([0-9]+)"); - llvm::SmallVector device_refs(devices.begin(), - devices.end()); - mlir::Builder builder(op->getContext()); - op->setAttr(kDevicesAttr, builder.getStrArrayAttr(device_refs)); + llvm::SmallVector cc; + if (r->match(device.attributes().physical_device_desc(), &cc)) { + return mlir::TF::GpuDeviceMetadata::get( + builder->getI32IntegerAttr(std::stoi(cc[1].str())), + builder->getI32IntegerAttr(std::stoi(cc[2].str())), + builder->getContext()); + } + + return builder->getDictionaryAttr({}); } -mlir::LogicalResult GetDevicesFromOp( - mlir::Operation* op, - llvm::SmallVectorImpl* devices) { - auto devices_attr = op->getAttr(kDevicesAttr); - if (!devices_attr) return mlir::success(); +// Get devices from an array of string attributes. +// TODO(ezhulenev): Update all tests to use dictionary attribute for +// `tf.devices` and remove this function. +mlir::LogicalResult GetDevicesFromOp(mlir::Operation* op, + mlir::ArrayAttr array_attr, + mlir::TF::RuntimeDevices* devices) { + DeviceNameUtils::ParsedName device; - auto array_attr = devices_attr.dyn_cast(); - if (!array_attr) - return op->emitOpError( - llvm::formatv("bad '{0}' attribute, not an array", kDevicesAttr)); + for (auto& kv : llvm::enumerate(array_attr)) { + const int idx = kv.index(); - devices->resize(array_attr.size()); - for (auto attr_and_idx : llvm::enumerate(array_attr)) { - const int idx = attr_and_idx.index(); - auto string_attr = attr_and_idx.value().dyn_cast(); + auto string_attr = kv.value().dyn_cast(); if (!string_attr) return op->emitOpError(llvm::formatv( "bad '{0}' attribute at index {1}, not a string", kDevicesAttr, idx)); - if (!DeviceNameUtils::ParseFullName(string_attr.getValue().str(), - &(*devices)[idx])) + if (DeviceNameUtils::ParseFullName(string_attr.getValue().str(), &device)) { + devices->AddDevice(device); + } else { return op->emitOpError( - llvm::formatv("bad '{0}' attribute at index {1} with value '{2}', " - "not a valid device", - kDevicesAttr, idx, string_attr.getValue())); + llvm::formatv("bad '{0}' attribute, '{1}', not a valid device", + kDevicesAttr, string_attr.getValue())); + } } return mlir::success(); } +// Get devices from a dictionary attribute. +mlir::LogicalResult GetDevicesFromOp(mlir::Operation* op, + mlir::DictionaryAttr dict_attr, + mlir::TF::RuntimeDevices* devices) { + DeviceNameUtils::ParsedName device; + + // Parse device names and metadata from dictionary attribute. 
+ for (auto& kv : dict_attr) { + const mlir::Identifier name = kv.first; + const mlir::Attribute attr = kv.second; + + if (!DeviceNameUtils::ParseFullName(name.str(), &device)) + return op->emitOpError( + llvm::formatv("bad '{0}' attribute, '{1}', not a valid device", + kDevicesAttr, name.strref())); + + if (auto gpu_metadata = attr.dyn_cast()) { + devices->AddGpuDevice(device, gpu_metadata); + } else { + devices->AddDevice(device); + } + } + + return mlir::success(); +} + +} // namespace + +void AddDevicesToOp(mlir::Operation* op, const DeviceSet* device_set) { + if (!device_set) return; + + mlir::MLIRContext* ctx = op->getContext(); + mlir::Builder builder(ctx); + + // Collect devices with attached metadata. + llvm::SmallVector devices; + devices.reserve(device_set->devices().size()); + + // For device that do not have any metadata, or if we failed to parse metadata + // from the DeviceSet, we add empty dictionary to the `tf.devices` attribute. + for (Device* device : device_set->devices()) { + string name = DeviceNameUtils::ParsedNameToString(device->parsed_name()); + + if (device->device_type() == DEVICE_GPU) { + auto metadata = ParseGpuDeviceMetadata(*device, &builder); + devices.push_back(builder.getNamedAttr(name, metadata)); + } else { + auto metadata = builder.getDictionaryAttr({}); + devices.push_back(builder.getNamedAttr(name, metadata)); + } + } + + op->setAttr(kDevicesAttr, builder.getDictionaryAttr(devices)); +} + +mlir::LogicalResult GetDevicesFromOp(mlir::Operation* op, + mlir::TF::RuntimeDevices* devices) { + auto devices_attr = op->getAttr(kDevicesAttr); + if (!devices_attr) return mlir::success(); + + if (auto array_attr = devices_attr.dyn_cast()) { + return GetDevicesFromOp(op, array_attr, devices); + + } else if (auto dict_attr = devices_attr.dyn_cast()) { + return GetDevicesFromOp(op, dict_attr, devices); + } + + return op->emitOpError( + llvm::formatv("unsupported '{0}' attribute", kDevicesAttr)); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/device_util.h b/tensorflow/compiler/mlir/tensorflow/utils/device_util.h index 73ae18d2487..1cbf0517554 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/device_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/device_util.h @@ -19,22 +19,27 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "mlir/IR/Operation.h" // TF:llvm-project #include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { + // Collects all devices known to the system by name and adds them as a -// `tf.devices` array attribute of string attributes to an op. Device names -// added are in the following form: +// `tf.devices` dictionary attribute with a full device name as a key, and +// device metadata as a value. +// +// Device names added in full parsed device form: // /job:/replica:/task:/device:: +// +// Supported device metadata types: +// (1) GpuDeviceMetadata: GPU device compute capability. void AddDevicesToOp(mlir::Operation* op, const DeviceSet* device_set); -// Collects devices as DeviceNameUtils::ParsedName from an op `tf.devices` -// attribute. A failure will be returned if the attribute is not an -// ArrayAttr or the devices are invalid. 
-mlir::LogicalResult GetDevicesFromOp( - mlir::Operation* op, - llvm::SmallVectorImpl* devices); +// Collects devices information from an op `tf.devices` attributes. Returns +// failure if can't parse device metadata from the attribute. +mlir::LogicalResult GetDevicesFromOp(mlir::Operation* op, + mlir::TF::RuntimeDevices* devices); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc index cb25e000f7a..25e55e23c1a 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" #include -#include #include #include @@ -46,13 +45,15 @@ class FakeDevice : public Device { Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); } - static std::unique_ptr Make(const string& name) { + static std::unique_ptr Make(const string& name, + const string& desc = "") { DeviceNameUtils::ParsedName parsed_name; DeviceNameUtils::ParseFullName(name, &parsed_name); DeviceAttributes device_attributes; device_attributes.set_name(name); device_attributes.set_device_type(parsed_name.type); + device_attributes.set_physical_device_desc(desc); return std::make_unique(device_attributes); } }; @@ -62,26 +63,40 @@ TEST(DeviceUtilTest, AddDeviceToOp) { mlir::OwningModuleRef module_ref = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); - DeviceSet device_set; - llvm::SmallVector, 2> devices; - devices.push_back( - FakeDevice::Make("/job:worker/replica:0/task:0/device:CPU:0")); - devices.push_back( - FakeDevice::Make("/job:worker/replica:1/task:2/device:GPU:3")); - for (auto& device : devices) device_set.AddDevice(device.get()); + const std::string cpu0 = "/job:worker/replica:0/task:0/device:CPU:0"; + const std::string gpu0 = "/job:worker/replica:1/task:2/device:GPU:0"; + const std::string gpu1 = "/job:worker/replica:1/task:2/device:GPU:1"; + llvm::SmallVector, 2> devices; + devices.push_back(FakeDevice::Make(cpu0)); + devices.push_back(FakeDevice::Make(gpu0, "compute capability: 7.0")); + devices.push_back(FakeDevice::Make(gpu1)); + + DeviceSet device_set; + for (auto& device : devices) device_set.AddDevice(device.get()); AddDevicesToOp(*module_ref, &device_set); - auto devices_attr = module_ref->getAttrOfType("tf.devices"); + + auto devices_attr = + module_ref->getAttrOfType("tf.devices"); ASSERT_NE(devices_attr, nullptr); - ASSERT_EQ(devices_attr.size(), 2); - auto device_attr_0 = devices_attr.getValue()[0].dyn_cast(); - ASSERT_NE(device_attr_0, nullptr); - EXPECT_EQ(device_attr_0.getValue(), - "/job:worker/replica:0/task:0/device:CPU:0"); - auto device_attr_1 = devices_attr.getValue()[1].dyn_cast(); - ASSERT_NE(device_attr_1, nullptr); - EXPECT_EQ(device_attr_1.getValue(), - "/job:worker/replica:1/task:2/device:GPU:3"); + ASSERT_EQ(devices_attr.size(), 3); + + // CPU device added with an empty metadata. + auto device_meta_0 = devices_attr.get(cpu0).dyn_cast(); + ASSERT_NE(device_meta_0, nullptr); + ASSERT_EQ(device_meta_0.size(), 0); + + // GPU device successfully parsed compute capability from description. + auto device_meta_1 = + devices_attr.get(gpu0).dyn_cast(); + ASSERT_NE(device_meta_1, nullptr); + ASSERT_EQ(device_meta_1.cc_major().getInt(), 7); + ASSERT_EQ(device_meta_1.cc_minor().getInt(), 0); + + // If description is empty GPU devices added with an empty metadata. 
+ auto device_meta_2 = devices_attr.get(gpu1).dyn_cast(); + ASSERT_NE(device_meta_2, nullptr); + ASSERT_EQ(device_meta_2.size(), 0); } TEST(DeviceUtilTest, AddDeviceToOpNullDeviceSet) { @@ -98,7 +113,7 @@ TEST(DeviceUtilTest, GetDevicesFromOpNoDevicesAttribute) { mlir::OwningModuleRef module_ref = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); - llvm::SmallVector devices; + mlir::TF::RuntimeDevices devices; EXPECT_TRUE(mlir::succeeded(GetDevicesFromOp(*module_ref, &devices))); } @@ -109,7 +124,7 @@ TEST(DeviceUtilTest, GetDevicesFromOpBadDevicesAttributeType) { mlir::Builder builder(*module_ref); module_ref->setAttr("tf.devices", builder.getBoolAttr(false)); - llvm::SmallVector devices; + mlir::TF::RuntimeDevices devices; EXPECT_TRUE(mlir::failed(GetDevicesFromOp(*module_ref, &devices))); } @@ -120,7 +135,7 @@ TEST(DeviceUtilTest, GetDevicesFromOpBadDevicesAttributeArraySubtype) { mlir::Builder builder(*module_ref); module_ref->setAttr("tf.devices", builder.getI32ArrayAttr({8})); - llvm::SmallVector devices; + mlir::TF::RuntimeDevices devices; EXPECT_TRUE(mlir::failed(GetDevicesFromOp(*module_ref, &devices))); } @@ -129,9 +144,11 @@ TEST(DeviceUtilTest, GetDevicesFromOpBadDevicesInDevicesAttribute) { mlir::OwningModuleRef module_ref = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); mlir::Builder builder(*module_ref); - module_ref->setAttr("tf.devices", builder.getStrArrayAttr({"bad_device"})); + module_ref->setAttr("tf.devices", + builder.getDictionaryAttr(builder.getNamedAttr( + "bad_device", builder.getDictionaryAttr({})))); - llvm::SmallVector devices; + mlir::TF::RuntimeDevices devices; EXPECT_TRUE(mlir::failed(GetDevicesFromOp(*module_ref, &devices))); } @@ -140,16 +157,53 @@ TEST(DeviceUtilTest, GetDevicesFromOpValidDeviceInDevicesAttribute) { mlir::OwningModuleRef module_ref = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); mlir::Builder builder(*module_ref); - module_ref->setAttr( - "tf.devices", - builder.getStrArrayAttr({"/job:worker/replica:0/task:0/device:CPU:0"})); - llvm::SmallVector devices; + auto device_dict = builder.getDictionaryAttr( + {builder.getNamedAttr("/job:worker/replica:0/task:0/device:CPU:0", + builder.getDictionaryAttr({}))}); + module_ref->setAttr("tf.devices", device_dict); + + mlir::TF::RuntimeDevices devices; EXPECT_TRUE(mlir::succeeded(GetDevicesFromOp(*module_ref, &devices))); - ASSERT_EQ(devices.size(), 1); - EXPECT_EQ(DeviceNameUtils::ParsedNameToString(devices[0]), + + ASSERT_EQ(devices.NumDevices(), 1); + ASSERT_EQ(devices.device_names().size(), 1); + ASSERT_EQ(DeviceNameUtils::ParsedNameToString(devices.device_names()[0]), "/job:worker/replica:0/task:0/device:CPU:0"); } +TEST(DeviceUtilTest, GetGpuDeviceMetadata) { + mlir::MLIRContext context; + mlir::OwningModuleRef module_ref = + mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); + + mlir::Builder builder(*module_ref); + + const std::string gpu0 = "/job:worker/replica:0/task:0/device:GPU:0"; + const std::string gpu1 = "/job:worker/replica:0/task:0/device:GPU:1"; + + llvm::SmallVector metadata; + metadata.push_back(builder.getNamedAttr( + gpu0, mlir::TF::GpuDeviceMetadata::get(builder.getI32IntegerAttr(1), + builder.getI32IntegerAttr(2), + module_ref->getContext()))); + + module_ref->setAttr("tf.devices", builder.getDictionaryAttr(metadata)); + + mlir::TF::RuntimeDevices devices; + EXPECT_TRUE(mlir::succeeded(GetDevicesFromOp(*module_ref, &devices))); + + DeviceNameUtils::ParsedName parsed_name; + DeviceNameUtils::ParseFullName(gpu0, &parsed_name); + auto meta_0 = 
devices.GetGpuDeviceMetadata(parsed_name); + ASSERT_TRUE(meta_0.hasValue()); + ASSERT_EQ(meta_0->cc_major().getInt(), 1); + ASSERT_EQ(meta_0->cc_minor().getInt(), 2); + + DeviceNameUtils::ParseFullName(gpu1, &parsed_name); + auto meta_1 = devices.GetGpuDeviceMetadata(parsed_name); + ASSERT_FALSE(meta_1.hasValue()); +} + } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc index 1b8ae8403bf..36a59d12060 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc @@ -170,4 +170,16 @@ std::string GetDumpDirFromEnvVar() { return result; } +std::string DumpRawStringToFile(llvm::StringRef name, llvm::StringRef content, + llvm::StringRef dirname) { + std::unique_ptr os; + std::string filepath; + Status result = CreateFileForDumping(name, &os, &filepath, dirname); + if (!result.ok()) return result.error_message(); + + (*os) << content; + LOG(INFO) << "Outputted requested string to '" << filepath << "'"; + return filepath; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h index 14c0d1f0b6e..79c4961273a 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h @@ -56,6 +56,14 @@ std::string DumpMlirOpToFile(llvm::StringRef name, mlir::Operation* op, // cannot be determined and generates a warning message. std::string GetDumpDirFromEnvVar(); +// Dumps a raw string to a file and returns the file name used. +// +// This will create a file name via prefixing `name` with the value of the +// TF_DUMP_GRAPH_PREFIX environment variable if `dirname` is empty and +// suffixing `name` with ".mlir". 
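// A minimal usage sketch (illustrative only; the directory and module text
// below are made-up values, and the exact returned path may differ):
//
//   setenv("TF_DUMP_GRAPH_PREFIX", "/tmp/mlir_dumps", 1);
//   std::string path = DumpRawStringToFile("example", "module {\n}");
//   // `path` is expected to be roughly "/tmp/mlir_dumps/example.mlir", or an
//   // error string such as "(unavailable)" if the file could not be created.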
+std::string DumpRawStringToFile(llvm::StringRef name, llvm::StringRef content, + llvm::StringRef dirname = ""); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DUMP_MLIR_UTIL_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc index 947a0ef0af3..69e90de3cb6 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc @@ -69,5 +69,20 @@ TEST(DumpMlirModuleTest, Valid) { EXPECT_EQ(file_txt_module, expected_txt_module); } +TEST(DumpRawStringToFileTest, Valid) { + llvm::StringRef example = "module {\n}"; + setenv("TF_DUMP_GRAPH_PREFIX", testing::TmpDir().c_str(), 1); + + std::string filepath = DumpRawStringToFile("example", example); + ASSERT_NE(filepath, "(TF_DUMP_GRAPH_PREFIX not specified)"); + ASSERT_NE(filepath, "LOG(INFO)"); + ASSERT_NE(filepath, "(unavailable)"); + + Env* env = Env::Default(); + std::string file_txt_module; + TF_ASSERT_OK(ReadFileToString(env, filepath, &file_txt_module)); + EXPECT_EQ(file_txt_module, example); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/error_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/error_util_test.cc index 3f4947bec23..61214108957 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/error_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/error_util_test.cc @@ -58,7 +58,8 @@ TEST(ErrorUtilTest, StatusScopedDiagnosticHandler) { emitError(loc) << "Second diagnostic message reported"; return tensorflow::errors::Internal("Passed in error"); }; - Status s = StatusScopedDiagnosticHandler(&context).Combine(function()); + StatusScopedDiagnosticHandler ssdh(&context); + Status s = ssdh.Combine(function()); ASSERT_TRUE(tensorflow::errors::IsInternal(s)); EXPECT_THAT(s.error_message(), HasSubstr("Passed in error")); EXPECT_THAT(s.error_message(), HasSubstr("Diagnostic message reported")); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc index 33a09a6ddfb..6cf2781e48d 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc @@ -29,7 +29,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/FormatVariadic.h" -#include "tensorflow/compiler/xla/array3d.h" +#include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/types.h" @@ -39,9 +39,9 @@ limitations under the License. #include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { -// Device coordinates are defined as (x, y, core), thus resulting in a rank 3 +// Device coordinates are defined as (x, y, z, core), thus resulting in a rank 4 // topology. -constexpr int kTPUTopologyRank = 3; +constexpr int kTPUTopologyRank = 4; constexpr char kDeviceTPUSystem[] = "TPU_SYSTEM"; constexpr char kDeviceTPU[] = "TPU"; @@ -209,43 +209,43 @@ struct TaskAndDevice { }; // Checks if device coordinate is outside of topology mesh shape bounds. 
-bool DeviceCoordinateOutOfBound(int x, int y, int core, int bound_x, - int bound_y, int bound_core) { - return x < 0 || x >= bound_x || y < 0 || y >= bound_y || core < 0 || - core >= bound_core; +bool DeviceCoordinateOutOfBound(int x, int y, int z, int core, int bound_x, + int bound_y, int bound_z, int bound_core) { + return x < 0 || x >= bound_x || y < 0 || y >= bound_y || z < 0 || + z >= bound_z || core < 0 || core >= bound_core; } // Creates error message for an out of bound device coordinate. Status DeviceCoordinateErrorMsg(absl::string_view attribute, int x, int y, - int core, int bound_x, int bound_y, - int bound_core) { - return errors::InvalidArgument("device coordinate (", x, ", ", y, ", ", core, - ") in '", attribute, + int z, int core, int bound_x, int bound_y, + int bound_z, int bound_core) { + return errors::InvalidArgument("device coordinate (", x, ", ", y, ", ", z, + ", ", core, ") in '", attribute, "' is outside of mesh shape (", bound_x, ", ", - bound_y, ", ", bound_core, ")"); + bound_y, ", ", bound_z, ", ", bound_core, ")"); } // Creates error message for a duplicate device coordinate. Status DuplicateCoordinateErrorMsg(absl::string_view attribute, int x, int y, - int core) { + int z, int core) { return errors::InvalidArgument("'", attribute, "' has duplicate device coordinate (", x, ", ", - y, ", ", core, ")"); + y, ", ", z, ", ", core, ")"); } // Parses and validates topology (serialized string of TopologyProto), and maps -// device coordinate (x, y, core) to task and device (of available TPUs). +// device coordinate (x, y, z, core) to task and device (of available TPUs). // Topology attribute device coordinates are ordered by task then device (major // to minor). // // A valid TopologyProto must have: -// - a valid mesh shape (rank 3 with positive dimensions) +// - a valid mesh shape (rank 4 with positive dimensions) // - `num_tasks` and `num_tpu_devices_per_task` must match the number of // available TPU hosts and devices per host // - device coordinates within the mesh shape // - no duplicate device coordinates // - number of device coordinates (in tuple 3) match number of availabe TPUs -StatusOr> ParseTopologyAttr( +StatusOr> ParseTopologyAttr( llvm::StringRef topology_attr, int num_tasks, int num_tpus_per_task) { tpu::TopologyProto topology_proto; if (!topology_proto.ParseFromString(topology_attr.str())) @@ -288,22 +288,25 @@ StatusOr> ParseTopologyAttr( const int bound_x = topology_proto.mesh_shape(0); const int bound_y = topology_proto.mesh_shape(1); - const int bound_core = topology_proto.mesh_shape(2); + const int bound_z = topology_proto.mesh_shape(2); + const int bound_core = topology_proto.mesh_shape(3); - xla::Array3D topology(bound_x, bound_y, bound_core, {}); + xla::Array4D topology(bound_x, bound_y, bound_z, bound_core); int pos = 0; for (int task = 0; task < num_tasks; ++task) { for (int device = 0; device < num_tpus_per_task; ++device) { int x = topology_proto.device_coordinates(pos++); int y = topology_proto.device_coordinates(pos++); + int z = topology_proto.device_coordinates(pos++); int core = topology_proto.device_coordinates(pos++); - if (DeviceCoordinateOutOfBound(x, y, core, bound_x, bound_y, bound_core)) - return DeviceCoordinateErrorMsg(kTopologyAttr, x, y, core, bound_x, - bound_y, bound_core); + if (DeviceCoordinateOutOfBound(x, y, z, core, bound_x, bound_y, bound_z, + bound_core)) + return DeviceCoordinateErrorMsg(kTopologyAttr, x, y, z, core, bound_x, + bound_y, bound_z, bound_core); - auto& task_and_device = topology(x, y, core); + 
auto& task_and_device = topology(x, y, z, core); if (task_and_device.task != -1) - return DuplicateCoordinateErrorMsg(kTopologyAttr, x, y, core); + return DuplicateCoordinateErrorMsg(kTopologyAttr, x, y, z, core); task_and_device = {task, device}; } @@ -346,16 +349,18 @@ GetGeneralTPUExecutionDeviceAssignment( const int bound_x = topology.n1(); const int bound_y = topology.n2(); - const int bound_core = topology.n3(); + const int bound_z = topology.n3(); + const int bound_core = topology.n4(); // TPU XLA device ID is determined by its device coordinate, from major to - // minor coordinates (y, x, core). - auto location_to_id = [&](int x, int y, int core) { - return x * bound_core + y * bound_x * bound_core + core; + // minor coordinates (z, y, x, core). + auto location_to_id = [&](int x, int y, int z, int core) { + return (x + bound_x * (y + bound_y * z)) * bound_core + core; }; std::vector used_device_ids( - location_to_id(bound_x - 1, bound_y - 1, bound_core - 1), false); + location_to_id(bound_x - 1, bound_y - 1, bound_z - 1, bound_core - 1), + false); ExecutionDevices execution_devices( num_replicas, llvm::SmallVector(num_cores_per_replica, "")); @@ -366,22 +371,25 @@ GetGeneralTPUExecutionDeviceAssignment( ++logical_core) { int x = device_assignment_attr[pos++]; int y = device_assignment_attr[pos++]; + int z = device_assignment_attr[pos++]; int core = device_assignment_attr[pos++]; - if (DeviceCoordinateOutOfBound(x, y, core, bound_x, bound_y, bound_core)) - return DeviceCoordinateErrorMsg(kDeviceAssignmentAttr, x, y, core, - bound_x, bound_y, bound_core); + if (DeviceCoordinateOutOfBound(x, y, z, core, bound_x, bound_y, bound_z, + bound_core)) + return DeviceCoordinateErrorMsg(kDeviceAssignmentAttr, x, y, z, core, + bound_x, bound_y, bound_z, bound_core); - TaskAndDevice task_and_device = topology(x, y, core); + TaskAndDevice task_and_device = topology(x, y, z, core); const int task = task_and_device.task; const int device = task_and_device.device; if (task == -1 || device == -1) return errors::InvalidArgument( "no TPU device found for '", kDeviceAssignmentAttr, - "' device coordinate (", x, ", ", y, ", ", core, ")"); + "' device coordinate (", x, ", ", y, ", ", z, ", ", core, ")"); - const int device_id = location_to_id(x, y, core); + const int device_id = location_to_id(x, y, z, core); if (used_device_ids[device_id]) - return DuplicateCoordinateErrorMsg(kDeviceAssignmentAttr, x, y, core); + return DuplicateCoordinateErrorMsg(kDeviceAssignmentAttr, x, y, z, + core); used_device_ids[device_id] = true; device_assignment(replica, logical_core) = device_id; diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc index de7009b495f..87319f2adeb 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc @@ -129,6 +129,7 @@ std::string TopologyWithDeviceCoordinates( topology_proto.add_mesh_shape(2); topology_proto.add_mesh_shape(1); topology_proto.add_mesh_shape(1); + topology_proto.add_mesh_shape(1); topology_proto.set_num_tasks(2); topology_proto.set_num_tpu_devices_per_task(1); for (int device_coordinate : device_coordinates) @@ -155,89 +156,100 @@ INSTANTIATE_TEST_SUITE_P( "failed to parse 'topology' attribute to TopologyProto"), std::make_tuple(4, 2, TopologyWithMeshShape({0}), std::vector(), - "'topology' 'mesh_shape' must be rank 3, got rank 1"), + "'topology' 'mesh_shape' 
must be rank 4, got rank 1"), std::make_tuple( - 2, 1, TopologyWithMeshShape({2, 0, 2}), std::vector(), + 2, 1, TopologyWithMeshShape({2, 0, 1, 2}), std::vector(), "'topology' 'mesh_shape' dimension 1 must be positive, got 0"), - std::make_tuple(2, 1, TopologyWithMeshShapeAndTasks({1, 1, 1}, 1, 1), + std::make_tuple(2, 1, TopologyWithMeshShapeAndTasks({1, 1, 1, 1}, 1, 1), std::vector(), "number of tasks from available TPU devices must be " "'num_tasks' in 'topology' (1), got 2"), - std::make_tuple(2, 1, TopologyWithMeshShapeAndTasks({1, 1, 1}, 2, 2), + std::make_tuple(2, 1, TopologyWithMeshShapeAndTasks({1, 1, 1, 1}, 2, 2), std::vector(), "number of TPU devices available per task must be " "'num_tpu_devices_per_task' in 'topology' (2), got 1"), std::make_tuple( 2, 1, TopologyWithDeviceCoordinates({}), std::vector(), "length of 'device_coordinates' in 'topology' must be 'num_tasks' " - "* 'num_tpus_per_task' * 3 (2 * 1 * 3), got 0"), - std::make_tuple(2, 1, - TopologyWithDeviceCoordinates({-1, 0, 0, 1, 0, 0}), - std::vector(), - "device coordinate (-1, 0, 0) in 'topology' is outside " - "of mesh shape (2, 1, 1)"), - std::make_tuple(2, 1, TopologyWithDeviceCoordinates({2, 0, 0, 1, 0, 0}), - std::vector(), - "device coordinate (2, 0, 0) in 'topology' is outside " - "of mesh shape (2, 1, 1)"), - std::make_tuple(2, 1, - TopologyWithDeviceCoordinates({0, -1, 0, 1, 0, 0}), - std::vector(), - "device coordinate (0, -1, 0) in 'topology' is outside " - "of mesh shape (2, 1, 1)"), - std::make_tuple(2, 1, TopologyWithDeviceCoordinates({0, 1, 0, 1, 0, 0}), - std::vector(), - "device coordinate (0, 1, 0) in 'topology' is outside " - "of mesh shape (2, 1, 1)"), - std::make_tuple(2, 1, - TopologyWithDeviceCoordinates({0, 0, -1, 1, 0, 0}), - std::vector(), - "device coordinate (0, 0, -1) in 'topology' is outside " - "of mesh shape (2, 1, 1)"), - std::make_tuple(2, 1, TopologyWithDeviceCoordinates({0, 0, 1, 1, 0, 0}), - std::vector(), - "device coordinate (0, 0, 1) in 'topology' is outside " - "of mesh shape (2, 1, 1)"), + "* 'num_tpus_per_task' * 4 (2 * 1 * 4), got 0"), std::make_tuple( - 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 0, 0}), + 2, 1, TopologyWithDeviceCoordinates({-1, 0, 0, 0, 1, 0, 0, 0}), std::vector(), - "'topology' has duplicate device coordinate (0, 0, 0)"))); + "device coordinate (-1, 0, 0, 0) in 'topology' is outside " + "of mesh shape (2, 1, 1, 1)"), + std::make_tuple( + 2, 1, TopologyWithDeviceCoordinates({2, 0, 0, 0, 1, 0, 0, 0}), + std::vector(), + "device coordinate (2, 0, 0, 0) in 'topology' is outside " + "of mesh shape (2, 1, 1, 1)"), + std::make_tuple( + 2, 1, TopologyWithDeviceCoordinates({0, -1, 0, 0, 1, 0, 0, 0}), + std::vector(), + "device coordinate (0, -1, 0, 0) in 'topology' is outside " + "of mesh shape (2, 1, 1, 1)"), + std::make_tuple( + 2, 1, TopologyWithDeviceCoordinates({0, 1, 0, 0, 1, 0, 0, 0}), + std::vector(), + "device coordinate (0, 1, 0, 0) in 'topology' is outside " + "of mesh shape (2, 1, 1, 1)"), + std::make_tuple( + 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, -1, 1, 0, 0, 0}), + std::vector(), + "device coordinate (0, 0, 0, -1) in 'topology' is outside " + "of mesh shape (2, 1, 1, 1)"), + std::make_tuple( + 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 1, 1, 0, 0, 0}), + std::vector(), + "device coordinate (0, 0, 0, 1) in 'topology' is outside " + "of mesh shape (2, 1, 1, 1)"), + std::make_tuple( + 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 0, 0, 0, 0}), + std::vector(), + "'topology' has duplicate device coordinate (0, 0, 0, 0)"))); 
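// The expected strings above follow the rank 4 (x, y, z, core) coordinate
// scheme. Below is a minimal sketch of how such a coordinate maps to an XLA
// device ID, mirroring the location_to_id lambda in
// tpu_rewrite_device_util.cc (the mesh bounds and coordinates used in the
// comments are made-up values):
inline int LocationToIdSketch(int x, int y, int z, int core, int bound_x,
                              int bound_y, int bound_core) {
  // Core is the minor-most dimension, followed by x, y, then z.
  return (x + bound_x * (y + bound_y * z)) * bound_core + core;
}
// With mesh shape (2, 2, 1, 2): (0, 0, 0, 0) -> 0, (1, 0, 0, 1) -> 3, and
// (1, 1, 0, 1) -> (1 + 2 * (1 + 2 * 0)) * 2 + 1 = 7.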
INSTANTIATE_TEST_SUITE_P( BadGeneralDeviceAssignmentMetadata, ParameterizedMetadataTest, ::testing::Values( - std::make_tuple(2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 1, 0, 0}), + std::make_tuple(2, 1, + TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}), std::vector(), "length of 'device_assignment' must be 'num_replicas' " - "* 'num_cores_per_replica' * 3 (2 * 1 * 3), got 0"), - std::make_tuple(2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 1, 0, 0}), - std::vector{-1, 0, 0, 0, 0, 0}, - "device coordinate (-1, 0, 0) in 'device_assignment' " - "is outside of mesh shape (2, 1, 1)"), - std::make_tuple(2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 1, 0, 0}), - std::vector{2, 0, 0, 0, 0, 0}, - "device coordinate (2, 0, 0) in 'device_assignment' is " - "outside of mesh shape (2, 1, 1)"), - std::make_tuple(2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 1, 0, 0}), - std::vector{0, -1, 0, 0, 0, 0}, - "device coordinate (0, -1, 0) in 'device_assignment' " - "is outside of mesh shape (2, 1, 1)"), - std::make_tuple(2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 1, 0, 0}), - std::vector{0, 1, 0, 0, 0, 0}, - "device coordinate (0, 1, 0) in 'device_assignment' is " - "outside of mesh shape (2, 1, 1)"), - std::make_tuple(2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 1, 0, 0}), - std::vector{0, 0, -1, 0, 0, 0}, - "device coordinate (0, 0, -1) in 'device_assignment' " - "is outside of mesh shape (2, 1, 1)"), - std::make_tuple(2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 1, 0, 0}), - std::vector{0, 0, 1, 0, 0, 0}, - "device coordinate (0, 0, 1) in 'device_assignment' is " - "outside of mesh shape (2, 1, 1)"), + "* 'num_cores_per_replica' * 4 (2 * 1 * 4), got 0"), std::make_tuple( - 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 1, 0, 0}), - std::vector{0, 0, 0, 0, 0, 0}, - "'device_assignment' has duplicate device coordinate (0, 0, 0)"))); + 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}), + std::vector{-1, 0, 0, 0, 0, 0, 0, 0}, + "device coordinate (-1, 0, 0, 0) in 'device_assignment' " + "is outside of mesh shape (2, 1, 1, 1)"), + std::make_tuple( + 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}), + std::vector{2, 0, 0, 0, 0, 0, 0, 0}, + "device coordinate (2, 0, 0, 0) in 'device_assignment' is " + "outside of mesh shape (2, 1, 1, 1)"), + std::make_tuple( + 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}), + std::vector{0, -1, 0, 0, 0, 0, 0, 0}, + "device coordinate (0, -1, 0, 0) in 'device_assignment' " + "is outside of mesh shape (2, 1, 1, 1)"), + std::make_tuple( + 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}), + std::vector{0, 1, 0, 0, 0, 0, 0, 0}, + "device coordinate (0, 1, 0, 0) in 'device_assignment' is " + "outside of mesh shape (2, 1, 1, 1)"), + std::make_tuple( + 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}), + std::vector{0, 0, 0, -1, 0, 0, 0, 0}, + "device coordinate (0, 0, 0, -1) in 'device_assignment' " + "is outside of mesh shape (2, 1, 1, 1)"), + std::make_tuple( + 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}), + std::vector{0, 0, 0, 1, 0, 0, 0, 0}, + "device coordinate (0, 0, 0, 1) in 'device_assignment' is " + "outside of mesh shape (2, 1, 1, 1)"), + std::make_tuple(2, 1, + TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}), + std::vector{0, 0, 0, 0, 0, 0, 0, 0}, + "'device_assignment' has duplicate device coordinate " + "(0, 0, 0, 0)"))); std::vector MakeDeviceSet(int num_tasks, int num_devices_per_task) { @@ -270,15 +282,17 @@ TEST(TPURewriteDeviceUtilTest, 
topology_proto.add_mesh_shape(2); topology_proto.add_mesh_shape(1); topology_proto.add_mesh_shape(1); + topology_proto.add_mesh_shape(1); topology_proto.set_num_tasks(1); topology_proto.set_num_tpu_devices_per_task(1); topology_proto.add_device_coordinates(0); topology_proto.add_device_coordinates(0); topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(0); } std::string topology_attr = topology_proto.SerializeAsString(); - std::vector device_assignment_attr{1, 0, 0}; + std::vector device_assignment_attr{1, 0, 0, 0}; llvm::SmallVector devices; std::vector device_names = @@ -292,7 +306,7 @@ TEST(TPURewriteDeviceUtilTest, ASSERT_FALSE(status_or.ok()); EXPECT_EQ(status_or.status().error_message(), "no TPU device found for 'device_assignment' device coordinate (1, " - "0, 0)"); + "0, 0, 0)"); } TEST(TPURewriteDeviceUtilTest, ValidFullMeshDeviceAssignment) { @@ -342,6 +356,7 @@ TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh2x2x2) { { topology_proto.add_mesh_shape(2); topology_proto.add_mesh_shape(2); + topology_proto.add_mesh_shape(1); topology_proto.add_mesh_shape(2); topology_proto.set_num_tasks(2); topology_proto.set_num_tpu_devices_per_task(4); @@ -349,31 +364,40 @@ TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh2x2x2) { topology_proto.add_device_coordinates(0); topology_proto.add_device_coordinates(0); topology_proto.add_device_coordinates(0); - topology_proto.add_device_coordinates(1); - topology_proto.add_device_coordinates(0); - topology_proto.add_device_coordinates(1); - topology_proto.add_device_coordinates(1); topology_proto.add_device_coordinates(0); topology_proto.add_device_coordinates(1); topology_proto.add_device_coordinates(0); topology_proto.add_device_coordinates(0); topology_proto.add_device_coordinates(1); + topology_proto.add_device_coordinates(1); + topology_proto.add_device_coordinates(0); topology_proto.add_device_coordinates(0); topology_proto.add_device_coordinates(1); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(1); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(0); topology_proto.add_device_coordinates(1); topology_proto.add_device_coordinates(1); topology_proto.add_device_coordinates(1); topology_proto.add_device_coordinates(0); topology_proto.add_device_coordinates(1); + topology_proto.add_device_coordinates(0); topology_proto.add_device_coordinates(1); topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(1); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(0); topology_proto.add_device_coordinates(0); topology_proto.add_device_coordinates(1); } std::string topology_attr = topology_proto.SerializeAsString(); - std::vector device_assignment_attr{ - 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1}; + std::vector device_assignment_attr{0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, + 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, + 0, 1, 1, 1, 0, 0, 1, 1, 0, 1}; llvm::SmallVector devices; std::vector device_names = @@ -433,11 +457,12 @@ TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh2x2x2) { EXPECT_EQ(computation_device_1.replica_device_ids(3), 7); } -TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh1x2x3) { +TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh1x2x1x3) { tpu::TopologyProto topology_proto; { topology_proto.add_mesh_shape(1); 
topology_proto.add_mesh_shape(2); + topology_proto.add_mesh_shape(1); topology_proto.add_mesh_shape(3); topology_proto.set_num_tasks(3); topology_proto.set_num_tpu_devices_per_task(2); @@ -445,25 +470,31 @@ TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh1x2x3) { topology_proto.add_device_coordinates(0); topology_proto.add_device_coordinates(0); topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(0); topology_proto.add_device_coordinates(1); topology_proto.add_device_coordinates(0); topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(0); topology_proto.add_device_coordinates(1); + topology_proto.add_device_coordinates(0); topology_proto.add_device_coordinates(1); topology_proto.add_device_coordinates(0); topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(0); topology_proto.add_device_coordinates(1); topology_proto.add_device_coordinates(0); topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(0); topology_proto.add_device_coordinates(2); topology_proto.add_device_coordinates(0); topology_proto.add_device_coordinates(1); + topology_proto.add_device_coordinates(0); topology_proto.add_device_coordinates(2); } std::string topology_attr = topology_proto.SerializeAsString(); - std::vector device_assignment_attr{0, 0, 1, 0, 1, 1, 0, 0, 2, - 0, 1, 2, 0, 0, 0, 0, 1, 0}; + std::vector device_assignment_attr{ + 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 2, 0, 1, 0, 2, 0, 0, 0, 0, 0, 1, 0, 0}; llvm::SmallVector devices; std::vector device_names = diff --git a/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc index dff47861419..6aeead516e8 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc @@ -41,20 +41,22 @@ mlir::LogicalResult ExtractTfVersions(mlir::ModuleOp module, auto version_attr = module.getAttrOfType("tf.versions"); if (!version_attr) return mlir::failure(); - auto producer = version_attr.get("producer").dyn_cast(); + auto producer = + version_attr.get("producer").dyn_cast_or_null(); if (!producer) return mlir::failure(); versions->set_producer(producer.getInt()); auto min_consumer = - version_attr.get("min_consumer").dyn_cast(); - if (!min_consumer) return mlir::failure(); - versions->set_min_consumer(min_consumer.getInt()); + version_attr.get("min_consumer").dyn_cast_or_null(); + if (min_consumer) versions->set_min_consumer(min_consumer.getInt()); auto bad_consumers = - version_attr.get("bad_consumers").dyn_cast(); - if (!bad_consumers) return mlir::failure(); + version_attr.get("bad_consumers").dyn_cast_or_null(); + if (!bad_consumers) return mlir::success(); + for (auto bad_consumer : bad_consumers) { - auto bad_consumer_int_attr = bad_consumer.dyn_cast(); + auto bad_consumer_int_attr = + bad_consumer.dyn_cast_or_null(); if (!bad_consumer_int_attr) return mlir::failure(); versions->mutable_bad_consumers()->Add(bad_consumer_int_attr.getInt()); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc new file mode 100644 index 00000000000..bbe91054b3b --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc @@ -0,0 +1,204 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/FormatVariadic.h" +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace tensorflow { + +const char* const kXlaShardingAttrName = "_XlaSharding"; +const char* const kInputShardingAttr = "input_sharding_configuration"; +const char* const kOutputShardingAttr = "output_sharding_configuration"; + +llvm::Optional ParseShardingAttribute( + mlir::Operation* operation) { + const auto& sharding_attr = + operation->getAttrOfType(kXlaShardingAttrName); + if (!sharding_attr) return llvm::Optional(); + return sharding_attr.getValue(); +} + +llvm::SmallVector, 4> +ExtractInputsForLogicalDevices(int num_logical_cores, + mlir::tf_device::LaunchFuncOp launch_func) { + // Initialize the input list for each logical device. + llvm::SmallVector, 4> input_list; + input_list.reserve(num_logical_cores); + for (int i = 0; i < num_logical_cores; ++i) + input_list.emplace_back(llvm::SmallVector()); + + llvm::SmallVector launch_func_inputs( + launch_func.getOperands()); + auto sharding_attrs = + launch_func.getOperation()->getAttrOfType( + kInputShardingAttr); + // If the sharding attribute does not exist, then all inputs are placed on + // the 0th logical core by default. + if (!sharding_attrs) { + input_list[0] = launch_func_inputs; + return input_list; + }
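  // For illustration (hypothetical shardings, not taken from the original
  // code): with two logical cores and input sharding configuration
  // [REPLICATED, MAXIMAL on device 1], the loop below leaves core 0 with
  // {operand 0} and core 1 with {operand 0, operand 1}.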
+ + // Enumerate the sharding configuration for each input. If an input has + // replicated sharding then every logical device takes the value as an input. + // If an input has maximal sharding then only the specified logical device + // takes the value as an input. + for (const auto& sharding_attr_and_index : llvm::enumerate(sharding_attrs)) { + const auto& sharding_attr = sharding_attr_and_index.value(); + const auto input_index = sharding_attr_and_index.index(); + const auto& input_value = launch_func_inputs[input_index]; + + xla::OpSharding sharding; + sharding.ParseFromString( + sharding_attr.cast().getValue().str()); + + const auto input_sharing_type = sharding.type(); + if (input_sharing_type == xla::OpSharding::OTHER) + launch_func.emitError( + "tiled inputs are not yet supported for model parallelism"); + + if (input_sharing_type == xla::OpSharding::REPLICATED) { + for (auto& inputs : input_list) inputs.emplace_back(input_value); + } else { + assert(input_sharing_type == xla::OpSharding::MAXIMAL); + const int logical_device_id = sharding.tile_assignment_devices(0); + input_list[logical_device_id].emplace_back(input_value); + } + } + return input_list; +} + +mlir::LogicalResult ParseAndValidateOutputSharding( + mlir::tf_device::LaunchFuncOp launch_func, + mlir::SmallVector* output_sharding_list) { + output_sharding_list->reserve(launch_func.getNumResults()); + + const auto output_sharding_attrs = + launch_func.getOperation()->getAttrOfType( + kOutputShardingAttr); + if (!output_sharding_attrs) + return launch_func.emitError( + "output_sharding_configuration missing from launch func"); + + if (output_sharding_attrs.size() != launch_func.getNumResults()) + return launch_func.emitError("incorrect number of output sharding"); + + for (auto output_sharding_and_index : + llvm::enumerate(output_sharding_attrs)) { + const auto& output_sharding = output_sharding_and_index.value(); + const int sharding_index = output_sharding_and_index.index(); + if (!output_sharding.isa()) + return launch_func.emitError(llvm::formatv( + "non-string output sharding at index {0}", sharding_index)); + + xla::OpSharding sharding; + if (!sharding.ParseFromString( + output_sharding.cast().getValue().str())) + return launch_func.emitError("incorrect sharding format for outputs"); + + const auto output_sharing_type = sharding.type(); + if (output_sharing_type == xla::OpSharding::OTHER) + return launch_func.emitError( + "tiled outputs are not yet supported for model parallelism"); + + output_sharding_list->emplace_back(std::move(sharding)); + } + return mlir::success(); +} + +namespace { + +bool IsAssignedToLogicalDevice(const int core_id, + const xla::OpSharding& sharding) { + return sharding.type() == xla::OpSharding::MAXIMAL && + sharding.tile_assignment_devices(0) == core_id; +} + +// Returns the index of the return value in the region of +// `tf_device.parallel_execute` that represents the launch func output at +// index |launch_func_output_index|. Regions of parallel_execute may +// have different return values depending on the output sharding +// configuration.
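// For example (hypothetical shardings): with output shardings
// [MAXIMAL on device 0, REPLICATED, MAXIMAL on device 1] and core_id 1, the
// launch func output at index 2 maps to region output index 1, because among
// the outputs preceding it only the replicated output at index 1 is present
// in core 1's region.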
+int MapLaunchOutputIndexWithRegionOutputIndex( + llvm::ArrayRef output_sharding_config, const int core_id, + const int launch_func_output_index) { + int region_output_index = 0; + for (int output_index = 0; output_index < launch_func_output_index; + ++output_index) { + const auto& sharding = output_sharding_config[output_index]; + if (sharding.type() == xla::OpSharding::REPLICATED || + IsAssignedToLogicalDevice(core_id, sharding)) + region_output_index++; + } + + return region_output_index; +} + +} // namespace + +mlir::SmallVector GetOutputTypesForLogicalDeviceComputation( + const int logical_device_id, + llvm::ArrayRef output_sharding_config, + mlir::tf_device::LaunchFuncOp launch_func) { + mlir::SmallVector output_types; + output_types.reserve(launch_func.getNumResults()); + + for (auto result_and_index : llvm::enumerate(launch_func.getResults())) { + const auto output_index = result_and_index.index(); + const auto& output_sharding = output_sharding_config[output_index]; + const auto output_sharding_type = output_sharding.type(); + const auto& launch_func_output = result_and_index.value(); + + if (output_sharding_type == xla::OpSharding::REPLICATED || + IsAssignedToLogicalDevice(logical_device_id, output_sharding)) + output_types.emplace_back(launch_func_output.getType()); + } + + return output_types; +} + +void RemapOutputsFromLogicalDevices( + llvm::ArrayRef output_sharding_config, + mlir::tf_device::LaunchFuncOp launch_func, + mlir::tf_device::ParallelExecuteOp parallel_execute) { + for (auto result_and_index : llvm::enumerate(launch_func.getResults())) { + const auto output_index = result_and_index.index(); + const auto& launch_func_output = result_and_index.value(); + const auto& output_sharding = output_sharding_config[output_index]; + const auto output_sharing_type = output_sharding.type(); + + int logical_device_id = 0; + if (output_sharing_type == xla::OpSharding::MAXIMAL) + logical_device_id = output_sharding.tile_assignment_devices(0); + + // For maximal sharding configuration, correctly remap outputs from + // parallel_execute region to users of the launch func. + const int region_output_index = MapLaunchOutputIndexWithRegionOutputIndex( + output_sharding_config, logical_device_id, output_index); + + const auto output_from_logical_device = parallel_execute.GetRegionOutputs( + logical_device_id)[region_output_index]; + + launch_func_output.replaceAllUsesWith(output_from_logical_device); + } +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h new file mode 100644 index 00000000000..4f548ca95aa --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h @@ -0,0 +1,68 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_XLA_SHARDING_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_XLA_SHARDING_UTIL_H_ + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/Value.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace tensorflow { + +extern const char* const kXlaShardingAttrName; +extern const char* const kInputShardingAttr; +extern const char* const kOutputShardingAttr; + +// Parses the "_XlaSharding" attribute from the operation, if it exists. +llvm::Optional ParseShardingAttribute( + mlir::Operation* operation); + +// Parses the "input_sharding_configuration" attribute and returns a list where +// the i-th element is a list of mlir::Value's which represent inputs for the +// TPU computation corresponding to the i-th logical device. If the attribute +// does not exist, all inputs are placed on logical core 0. +llvm::SmallVector, 4> +ExtractInputsForLogicalDevices(int num_logical_cores, + mlir::tf_device::LaunchFuncOp launch_func); + +// Extracts a list of OpSharding that represents the output sharding +// configuration of `tf_device.launch`. +mlir::LogicalResult ParseAndValidateOutputSharding( + mlir::tf_device::LaunchFuncOp launch_func, + mlir::SmallVector* output_sharding_list); + +// Retrieves output types for the TPUExecute op representing execution for the +// provided logical device id. TPUExecute ops for different logical devices may +// have different outputs depending on the output sharding configuration. +mlir::SmallVector GetOutputTypesForLogicalDeviceComputation( + const int logical_device_id, + llvm::ArrayRef output_sharding_config, + mlir::tf_device::LaunchFuncOp launch_func); + +// Remaps outputs of the `tf_device.parallel_execute` op that represents the +// concurrent execution of `tf_device.launch_func` to the launch func's users.
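// Replicated outputs are taken from the region of logical core 0, while
// maximal outputs are taken from the region of their assigned core. For
// example (hypothetical shardings): with output shardings
// [REPLICATED, MAXIMAL on device 1], uses of result 0 are rewired to region
// 0's first output and uses of result 1 to region 1's second output (region 1
// also returns its own copy of the replicated result).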
+void RemapOutputsFromLogicalDevices( + llvm::ArrayRef output_sharding_config, + mlir::tf_device::LaunchFuncOp launch_func, + mlir::tf_device::ParallelExecuteOp parallel_execute); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_XLA_SHARDING_UTIL_H_ diff --git a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc index f5fc56556ec..29f9ec7eb46 100644 --- a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc @@ -49,15 +49,17 @@ static llvm::cl::opt splitInputFile( llvm::cl::init(false)); // NOLINTNEXTLINE -static llvm::cl::opt import_saved_model( - "savedmodel-to-mlir", - llvm::cl::desc("Import a saved model to its MLIR representation"), +static llvm::cl::opt import_saved_model_object_graph( + "savedmodel-objectgraph-to-mlir", + llvm::cl::desc( + "Import a saved model's object graph to its MLIR representation"), llvm::cl::value_desc("dir")); // NOLINTNEXTLINE -static llvm::cl::opt import_saved_model_v1( - "savedmodel-v1-to-mlir", - llvm::cl::desc("Import a saved model V1 to its MLIR representation"), +static llvm::cl::opt import_saved_model_signature_defs( + "savedmodel-signaturedefs-to-mlir", + llvm::cl::desc( + "Import a saved model's SignatureDefs to to their MLIR representation"), llvm::cl::value_desc("dir")); // NOLINTNEXTLINE @@ -83,11 +85,12 @@ int main(int argc, char** argv) { llvm::cl::ParseCommandLineOptions(argc, argv, "TF MLIR translation driver\n"); - if (!import_saved_model && !import_saved_model_v1 && !requested_translation) { + if (!import_saved_model_object_graph && !import_saved_model_signature_defs && + !requested_translation) { llvm::errs() << "error: need to specify one translation to perform\n"; return 1; - } else if (import_saved_model && import_saved_model_v1 && - requested_translation) { + } else if (import_saved_model_object_graph && + import_saved_model_signature_defs && requested_translation) { llvm::errs() << "error: cannot specify more than one translation to perform\n"; return 1; @@ -100,26 +103,26 @@ int main(int argc, char** argv) { return 1; } - if (import_saved_model) { + if (import_saved_model_object_graph) { std::unordered_set tags = absl::StrSplit(saved_model_tags, ','); std::vector exported_names = absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty()); mlir::MLIRContext context; - auto module = tensorflow::SavedModelToMlirImport( + auto module = tensorflow::SavedModelObjectGraphToMlirImport( input_filename, tags, absl::Span(exported_names), &context); if (!module) return 1; module->print(output->os()); - } else if (import_saved_model_v1) { + } else if (import_saved_model_signature_defs) { std::unordered_set tags = absl::StrSplit(saved_model_tags, ','); mlir::MLIRContext context; - auto module = - tensorflow::SavedModelV1ToMlirImport(input_filename, tags, &context); + auto module = tensorflow::SavedModelSignatureDefsToMlirImport( + input_filename, tags, &context); if (!module) return 1; module->print(output->os()); diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 5bf056af832..72126a7ef8f 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -15,11 +15,13 @@ package_group( "//learning/brain/experimental/swift_mlir/...", "//learning/brain/google/xla/kernels/...", "//learning/brain/swift/swift_mlir/...", + "//platforms/xla/...", "//tensorflow/compiler/mlir/...", "//tensorflow/compiler/tf2xla/...", 
"//tensorflow/compiler/xla/...", "//third_party/iree/...", "//third_party/mlir_edge/...", + "//third_party/tf_runtime/tools/tf_kernel_gen/...", ], ) @@ -28,11 +30,27 @@ exports_files(["ir/hlo_ops.td"]) filegroup( name = "hlo_ops_td_files", srcs = [ + "ir/hlo_client_ops.td", "ir/hlo_ops.td", "ir/hlo_ops_base.td", "ir/hlo_utils.td", "ir/lhlo_ops.td", "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", + "@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td", + ], +) + +gentbl( + name = "hlo_client_ops_inc_gen", + tbl_outs = [ + ("-gen-op-decls", "ir/hlo_client_ops.h.inc"), + ("-gen-op-defs", "ir/hlo_client_ops.cc.inc"), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "ir/hlo_client_ops.td", + td_srcs = [ + ":hlo_ops_td_files", ], ) @@ -47,7 +65,10 @@ gentbl( tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/hlo_ops.td", td_includes = ["ir/hlo_utils.td"], - td_srcs = [":hlo_ops_td_files"], + td_srcs = [ + ":hlo_ops_td_files", + "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", + ], ) gentbl( @@ -130,28 +151,62 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "xla_legalize_tf_with_tf2xla", + srcs = [ + "transforms/legalize_tf_with_tf2xla.cc", + ], + deps = [ + ":hlo", + ":mlir_hlo_builder", + "//tensorflow/compiler/jit:xla_cpu_device", + "//tensorflow/compiler/jit:xla_cpu_jit", + "//tensorflow/compiler/mlir:op_or_arg_name_mapper", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:convert_type", + "//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op", + "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_inc_gen", + "//tensorflow/compiler/mlir/tensorflow:translate_utils", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:core_cpu_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:session_options", + "//tensorflow/stream_executor:timer", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], + alwayslink = 1, +) + cc_library( name = "map_xla_to_scalar_op", - srcs = [], hdrs = ["transforms/map_xla_to_scalar_op.h"], deps = [ ":hlo", ":lhlo", + ":map_hlo_to_lhlo_op", "@llvm-project//llvm:support", "@llvm-project//mlir:StandardOps", ], ) cc_library( - name = "hlo_shape_derivation", - srcs = [], - hdrs = ["transforms/hlo_shape_derivation.h"], + name = "map_hlo_to_lhlo_op", + hdrs = ["transforms/map_hlo_to_lhlo_op.h"], deps = [ ":hlo", ":lhlo", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:StandardOps", - "@llvm-project//mlir:Transforms", ], ) @@ -173,6 +228,23 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "lhlo_legalize_to_parallel_loops", + srcs = ["transforms/lhlo_legalize_to_parallel_loops.cc"], + deps = [ + ":lhlo", + "@com_google_absl//absl/memory", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LinalgOps", + "@llvm-project//mlir:LoopOps", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Transforms", + ], + alwayslink = 1, +) + cc_library( name = "xla_legalize_to_linalg", srcs = 
["transforms/xla_legalize_to_linalg.cc"], @@ -226,13 +298,26 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "lhlo_copy_removal", + srcs = ["transforms/lhlo_copy_removal.cc"], + deps = [ + ":lhlo", + "@com_google_absl//absl/memory", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", + ], + alwayslink = 1, +) + cc_library( name = "hlo_legalize_to_lhlo", srcs = ["transforms/hlo_legalize_to_lhlo.cc"], deps = [ ":hlo", - ":hlo_shape_derivation", ":lhlo", + ":map_hlo_to_lhlo_op", "@com_google_absl//absl/memory", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", @@ -342,7 +427,9 @@ cc_library( ], deps = [ ":hlo", + "@llvm-project//llvm:support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Transforms", ], ) @@ -368,12 +455,14 @@ cc_library( cc_library( name = "hlo", srcs = [ + "ir/hlo_client_ops.cc", "ir/hlo_ops.cc", "ir/hlo_ops.cc.inc", "ir/hlo_ops.h.inc", "ir/hlo_utils.cc", ], hdrs = [ + "ir/hlo_client_ops.h", "ir/hlo_ops.h", "ir/hlo_utils.h", "transforms/passes.h", @@ -382,6 +471,7 @@ cc_library( includes = ["include"], deps = [ ":convert_op_folder", + ":hlo_client_ops_inc_gen", ":hlo_ops_base_inc_gen", ":hlo_ops_inc_gen", ":xla_canonicalize_inc_gen", @@ -389,7 +479,9 @@ cc_library( "@llvm-project//llvm:support", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SideEffects", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", @@ -398,6 +490,31 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "mlir_hlo_builder", + srcs = [ + "ir/mlir_hlo_builder.cc", + ], + hdrs = [ + "ir/mlir_hlo_builder.h", + ], + deps = [ + ":hlo", + ":hlo_utils", + ":type_to_shape", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:shape_inference", + "//tensorflow/core/platform:types", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/container:flat_hash_map", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + ], +) + cc_library( name = "lhlo", srcs = [ @@ -418,6 +535,7 @@ cc_library( "@llvm-project//mlir:Analysis", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SideEffects", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", @@ -598,6 +716,8 @@ tf_native_cc_binary( genrule( name = "operator_writer_inc", srcs = [ + "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", + "@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td", "@llvm-project//mlir:include/mlir/IR/OpBase.td", ":ir/hlo_ops.td", ":ir/hlo_ops_base.td", diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc index 2c20e113956..fa029bd50d0 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc @@ -237,11 +237,9 @@ StatusOr HloFunctionImporter::ImportInstruction( case HloOpcode::kBroadcast: { // Note that the HLO broadcast is more powerful than the XLA broadcast op. // BroadcastInDim offers a superset of the HLO op's functionality. 
- if (!instruction->dimensions().empty()) { - attributes.push_back(builder_->getNamedAttr( - "broadcast_dimensions", - ConvertDimensions(instruction->dimensions()))); - } + attributes.push_back( + builder_->getNamedAttr("broadcast_dimensions", + ConvertDimensions(instruction->dimensions()))); MakeAndReturn(BroadcastInDimOp); } #define MakeAndReturnBatchNormOp(batch_norm_op) \ diff --git a/tensorflow/compiler/mlir/xla/hlo_utils.cc b/tensorflow/compiler/mlir/xla/hlo_utils.cc index 8fa8b25255a..a0ce8a796cb 100644 --- a/tensorflow/compiler/mlir/xla/hlo_utils.cc +++ b/tensorflow/compiler/mlir/xla/hlo_utils.cc @@ -41,16 +41,23 @@ template type, llvm::makeArrayRef(data_span.data(), data_span.size())); } -llvm::SmallVector GetPermutationIfAvailable( +StatusOr> GetPermutationIfAvailable( const Shape& shape, mlir::Builder builder) { - if (!shape.has_layout() || shape.layout().minor_to_major().empty()) { - return {}; + if (!shape.has_layout() || + LayoutUtil::IsMonotonicWithDim0Major(shape.layout())) { + return llvm::SmallVector{}; } - llvm::SmallVector permutation; + if (!shape.is_static()) { + return tensorflow::errors::Internal( + "Permutations for dynamic shapes are not yet supported"); + } + llvm::SmallVector permuted_sizes; for (auto dim : llvm::reverse(shape.layout().minor_to_major())) { - permutation.push_back(dim); + permuted_sizes.push_back(shape.dimensions(dim)); } - return {AffineMap::getPermutationMap(permutation, builder.getContext())}; + return llvm::SmallVector{AffineMap::get( + permuted_sizes.size(), 0, + makeCanonicalStridedLayoutExpr(permuted_sizes, builder.getContext()))}; } } // namespace @@ -64,8 +71,10 @@ StatusOr ConvertTensorShapeToMemRefType( using mlir::MemRefType; auto dimensions = shape.dimensions(); llvm::SmallVector array(dimensions.begin(), dimensions.end()); + auto permutation_or = GetPermutationIfAvailable(shape, builder); + if (!permutation_or.ok()) return permutation_or.status(); return MemRefType::get(array, element_type_or.ValueOrDie(), - GetPermutationIfAvailable(shape, builder)); + permutation_or.ValueOrDie()); } StatusOr CreateDenseElementsAttrFromLiteral( diff --git a/tensorflow/compiler/mlir/xla/ir/dialect_registration.cc b/tensorflow/compiler/mlir/xla/ir/dialect_registration.cc index f5e5b0ad257..bafbc1ac9a9 100644 --- a/tensorflow/compiler/mlir/xla/ir/dialect_registration.cc +++ b/tensorflow/compiler/mlir/xla/ir/dialect_registration.cc @@ -13,9 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/mlir/xla/ir/hlo_client_ops.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" // Static initialization for XLA dialect registration. static mlir::DialectRegistration xla_hlo_ops; +static mlir::DialectRegistration + xla_hlo_client_ops; static mlir::DialectRegistration xla_lhlo_ops; diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_client_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_client_ops.cc new file mode 100644 index 00000000000..9056f532715 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/ir/hlo_client_ops.cc @@ -0,0 +1,127 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/xla/ir/hlo_client_ops.h" + +#include "mlir/IR/TypeUtilities.h" // TF:llvm-project + +namespace mlir { +namespace xla_hlo_client { + +template +static LogicalResult Verify(T op) { + return success(); +} + +//===----------------------------------------------------------------------===// +// BinaryOps +//===----------------------------------------------------------------------===// + +namespace { +// Gets the resulting type from a broadcast between two types. +static Type GetBroadcastType(Builder* builder, Type x, Type y, + Type element_type, + DenseIntElementsAttr broadcast_dimensions) { + auto x_ranked = x.dyn_cast(); + auto y_ranked = y.dyn_cast(); + if (!x_ranked || !y_ranked) { + return UnrankedTensorType::get(element_type); + } + + auto shape_x = x_ranked.getShape(); + auto shape_y = y_ranked.getShape(); + + if (shape_x.size() == shape_y.size()) { + llvm::SmallVector out_shape(shape_x.size()); + for (int i = 0; i < shape_x.size(); i++) { + auto x_val = shape_x[i]; + auto y_val = shape_y[i]; + if (x_val == -1 || y_val == -1) { + out_shape[i] = -1; + } else { + out_shape[i] = std::max(x_val, y_val); + } + } + return RankedTensorType::get(out_shape, element_type); + } + + // Return unranked tensor for invalid broadcast dimensions. + if (!broadcast_dimensions) return UnrankedTensorType::get(element_type); + + auto shape_large = shape_x.size() > shape_y.size() ? shape_x : shape_y; + auto shape_small = shape_x.size() <= shape_y.size() ? shape_x : shape_y; + + llvm::SmallVector out_shape(shape_large.begin(), + shape_large.end()); + + // Update according to the broadcast dimensions. 
+ for (auto index_pair : llvm::enumerate(broadcast_dimensions.getIntValues())) { + auto old_value = out_shape[index_pair.value().getSExtValue()]; + auto new_value = shape_small[index_pair.index()]; + if (old_value != -1 && (new_value == -1 || new_value > old_value)) { + out_shape[index_pair.value().getSExtValue()] = new_value; + } + } + + return RankedTensorType::get(out_shape, element_type); +} +} // namespace + +#define BINARY_BUILDER(Op) \ + void Op::build(Builder* builder, OperationState& result, Value left, \ + Value right, DenseIntElementsAttr broadcast_dimensions) { \ + auto type = GetBroadcastType(builder, left.getType().cast(), \ + right.getType().cast(), \ + getElementTypeOrSelf(right.getType()), \ + broadcast_dimensions); \ + return Op::build(builder, result, type, left, right, \ + broadcast_dimensions); \ + } + +BINARY_BUILDER(AddOp); +BINARY_BUILDER(AndOp); +BINARY_BUILDER(Atan2Op); +BINARY_BUILDER(DivOp); +BINARY_BUILDER(MaxOp); +BINARY_BUILDER(MinOp); +BINARY_BUILDER(MulOp); +BINARY_BUILDER(OrOp); +BINARY_BUILDER(PowOp); +BINARY_BUILDER(RemOp); +BINARY_BUILDER(ShiftLeftOp); +BINARY_BUILDER(ShiftRightArithmeticOp); +BINARY_BUILDER(ShiftRightLogicalOp); +BINARY_BUILDER(SubOp); +BINARY_BUILDER(XorOp); + +#undef BINARY_BUILDER + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/xla/ir/hlo_client_ops.cc.inc" + +//===----------------------------------------------------------------------===// +// xla_hlo_client Dialect Constructor +//===----------------------------------------------------------------------===// + +XlaHloClientDialect::XlaHloClientDialect(MLIRContext* context) + : Dialect(getDialectNamespace(), context) { + addOperations< +#define GET_OP_LIST +#include "tensorflow/compiler/mlir/xla/ir/hlo_client_ops.cc.inc" + >(); +} + +} // namespace xla_hlo_client +} // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_client_ops.h b/tensorflow/compiler/mlir/xla/ir/hlo_client_ops.h new file mode 100644 index 00000000000..541ab0ebe3f --- /dev/null +++ b/tensorflow/compiler/mlir/xla/ir/hlo_client_ops.h @@ -0,0 +1,44 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_CLIENT_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_CLIENT_OPS_H_ + +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Dialect.h" // TF:llvm-project +#include "mlir/IR/DialectImplementation.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/OpDefinition.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/Interfaces/SideEffects.h" // TF:llvm-project + +namespace mlir { +namespace xla_hlo_client { + +class XlaHloClientDialect : public Dialect { + public: + explicit XlaHloClientDialect(MLIRContext *context); + static StringRef getDialectNamespace() { return "xla_hlo_client"; } +}; + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/xla/ir/hlo_client_ops.h.inc" + +} // namespace xla_hlo_client +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_CLIENT_OPS_H_ diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_client_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_client_ops.td new file mode 100644 index 00000000000..2048604915d --- /dev/null +++ b/tensorflow/compiler/mlir/xla/ir/hlo_client_ops.td @@ -0,0 +1,134 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Defines "client" aligned HLO ops. +// These ops are not necessarily orthogonal or optimized for transformation but +// for ease of expression in certain cases deemed important for client +// libraries (i.e. implicit broadcasting, helper ops, etc). +// This dialect is considered to exist in addition to augment the xla_hlo +// dialect for ergonomic needs, not duplicate/replace it. +// +// The typical use of this dialect is for client libraries to be able to emit +// less constrained ops and rely on the conversion framework to lower any +// xla_hlo_client ops to canonical xla_hlo ops. +// +// See: https://www.tensorflow.org/xla/operation_semantics + +#ifndef HLO_CLIENT_OPS +#define HLO_CLIENT_OPS + +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffects.td" +include "tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td" + +def HLOClient_Dialect : Dialect { + let name = "xla_hlo_client"; + let cppNamespace = "xla_hlo_client"; +} + +class HLOClient_Op traits> : + Op { + // TODO(b/129012527) Much of this custom verification should be expressed as + // type constraints. + let verifier = [{ return Verify(*this); }]; +} + +//===----------------------------------------------------------------------===// +// XLA binary elementwise op definitions. +// From the client perspective, each of these support both explicit rank +// broadcasting (via the broadcast_dimensions attribute) and implicit degenerate +// shape broadcasting. 
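As a rough usage sketch for the client-level binary ops described above (not part of this patch; the helper name EmitBroadcastedAdd, the rank-1/rank-2 operand shapes, and the choice of dimension 1 are assumptions made for the example), C++ code holding an OpBuilder could create an explicitly broadcast add like this:

#include <cstdint>

#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_client_ops.h"

// Hypothetical helper: broadcast a rank-1 `lhs` along dimension 1 of a
// rank-2 `rhs` and add them. The custom builder declared for these ops
// infers the broadcasted result type from the operand types.
static mlir::Value EmitBroadcastedAdd(mlir::OpBuilder &builder,
                                      mlir::Location loc, mlir::Value lhs,
                                      mlir::Value rhs) {
  // broadcast_dimensions = dense<[1]> : tensor<1xi64>
  auto dims_type =
      mlir::RankedTensorType::get({1}, builder.getIntegerType(64));
  int64_t broadcast_dim = 1;
  auto broadcast_dims = mlir::DenseIntElementsAttr::get(
      dims_type, llvm::makeArrayRef(broadcast_dim));
  return builder
      .create<mlir::xla_hlo_client::AddOp>(loc, lhs, rhs, broadcast_dims)
      .getResult();
}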
+// +// These have 1:1 correspondance with same-named ops in the xla_hlo dialect; +// however, those operations do not support broadcasting. +// +// See: +// https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations +// https://www.tensorflow.org/xla/broadcasting +//===----------------------------------------------------------------------===// + +class HLOClient_BinaryElementwiseOp traits> : + HLOClient_Op { + let arguments = (ins + HLO_Tensor:$lhs, + HLO_Tensor:$rhs, + OptionalAttr:$broadcast_dimensions + ); + + let builders = [OpBuilder< + "Builder *builder, OperationState &result, Value left, Value right, " + "DenseIntElementsAttr broadcast_dimensions" + >]; + + let results = (outs HLO_Tensor); + let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }]; + let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }]; +} + +def HLOClient_AddOp : HLOClient_BinaryElementwiseOp<"add", + [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_AddOp; + +def HLOClient_Atan2Op : HLOClient_BinaryElementwiseOp<"atan2", + [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_Atan2Op; + +def HLOClient_DivOp : HLOClient_BinaryElementwiseOp<"div", + [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_DivOp; + +def HLOClient_MaxOp : HLOClient_BinaryElementwiseOp<"max", + [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MaxOp; + +def HLOClient_MinOp : HLOClient_BinaryElementwiseOp<"min", + [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MinOp; + +def HLOClient_MulOp : HLOClient_BinaryElementwiseOp<"mul", + [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MulOp; + +def HLOClient_PowOp : HLOClient_BinaryElementwiseOp<"pow", + [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_PowOp; + +def HLOClient_RemOp : HLOClient_BinaryElementwiseOp<"remainder", + [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_RemOp; + +def HLOClient_ShiftLeftOp : HLOClient_BinaryElementwiseOp<"shift_left", + [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ShiftLeftOp; + +def HLOClient_ShiftRightArithmeticOp : HLOClient_BinaryElementwiseOp<"shift_right_arithmetic", + [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ShiftRightArithmeticOp; + +def HLOClient_ShiftRightLogicalOp : HLOClient_BinaryElementwiseOp<"shift_right_logical", + [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ShiftRightLogicalOp; + +def HLOClient_SubOp : HLOClient_BinaryElementwiseOp<"sub", + [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_SubOp; + +//===----------------------------------------------------------------------===// +// XLA binary elementwise op definitions. +// The same description as the arithmetic binary elementwise ops applies. 
+//===----------------------------------------------------------------------===// + +class HLOClient_BinaryLogicalElementwiseOp : + HLOClient_BinaryElementwiseOp { + let arguments = (ins + HLO_PredOrIntTensor:$lhs, + HLO_PredOrIntTensor:$rhs, + OptionalAttr:$broadcast_dimensions + ); +} + +def HLOClient_AndOp: HLOClient_BinaryLogicalElementwiseOp<"and">, BASE_HLO_AndOp; +def HLOClient_OrOp: HLOClient_BinaryLogicalElementwiseOp<"or">, BASE_HLO_OrOp; +def HLOClient_XorOp : HLOClient_BinaryLogicalElementwiseOp<"xor">, BASE_HLO_XorOp; + +#endif // HLO_CLIENT_OPS diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc index f44bb9da758..023ab46a66f 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc @@ -30,6 +30,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/MathExtras.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project @@ -437,8 +438,8 @@ static LogicalResult Verify(BroadcastInDimOp op) { operandRank)); } - auto dimensions = *op.broadcast_dimensions(); - auto dimensionsType = op.broadcast_dimensions()->getType(); + auto dimensions = op.broadcast_dimensions(); + auto dimensionsType = op.broadcast_dimensions().getType(); auto dimensionsRank = dimensionsType.getRank(); if (dimensionsRank != 1) { return op.emitOpError(llvm::formatv( @@ -878,6 +879,12 @@ static LogicalResult Verify(RecvOp op) { return success(); } +//===----------------------------------------------------------------------===// +// CopyOp +//===----------------------------------------------------------------------===// + +OpFoldResult CopyOp::fold(ArrayRef operands) { return getOperand(); } + //===----------------------------------------------------------------------===// // ReshapeOp //===----------------------------------------------------------------------===// @@ -971,6 +978,47 @@ static LogicalResult Verify(SelectOp op) { return success(); } +// Makes it such that a SelectOp that is a non-root operation in a DRR infers +// the return type based on operand type. +LogicalResult SelectOp::inferReturnTypes( + MLIRContext*, Optional location, ValueRange operands, + ArrayRef attributes, RegionRange regions, + SmallVectorImpl& inferredReturnTypes) { + auto x_type = operands[1].getType(); + auto y_type = operands[2].getType(); + auto x_tensor = x_type.cast(); + auto y_tensor = y_type.cast(); + + // Check for type compatibility in the select op. This requires that the two + // non-predicate operands: + // (a) have the same element type + // (b) have compatible shapes (i.e. the same shape and/or at least one + // dynamic shape) + if (x_tensor.getElementType() != y_tensor.getElementType() || + failed(mlir::verifyCompatibleShape(x_type, y_type))) { + return emitOptionalError(location, "incompatible operand types: ", x_type, + " and ", y_type); + } + + // TODO(lucyfox): Support output shape inference when operands have compatible + // shapes. (The output shape should be the most general of the operand shapes + // at each dimension.) For now, handle the straightforward cases and fail + // otherwise. When this is fully implemented, this logic should move into + // reusable functionality in MLIR Core. 
+ Type output_type; + if (x_type == y_type || !x_tensor.hasRank()) { + output_type = x_type; + } else if (!y_tensor.hasRank()) { + output_type = y_type; + } else { + return emitOptionalError(location, + "currently unsupported operand types: ", x_type, + " and ", y_type); + } + inferredReturnTypes.assign({output_type}); + return success(); +} + //===----------------------------------------------------------------------===// // PadOp //===----------------------------------------------------------------------===// @@ -1488,5 +1536,40 @@ void XlaHloDialect::printType(Type type, DialectAsmPrinter& os) const { os << ""; } +//===----------------------------------------------------------------------===// +// Shape inference +//===----------------------------------------------------------------------===// + +LogicalResult deriveShapeFromFirstOperand( + OpBuilder* builder, Operation* op, + SmallVectorImpl* reifiedReturnShapes) { + Value operand = op->getOperand(0); + ShapedType operand_type = operand.getType().dyn_cast(); + if (!operand_type) { + op->emitOpError() << "first operand is not a shaped type"; + return failure(); + } + auto loc = op->getLoc(); + SmallVector shape_values; + shape_values.reserve(operand_type.getRank()); + auto shape_scalar_type = builder->getIntegerType(64); + for (auto element : llvm::enumerate(operand_type.getShape())) { + if (element.value() == ShapedType::kDynamicSize) { + Value dim = builder->create(loc, operand, element.index()); + shape_values.push_back( + builder->create(loc, dim, shape_scalar_type)); + } else { + shape_values.push_back(builder->create( + loc, builder->getI64IntegerAttr(element.value()))); + } + } + *reifiedReturnShapes = + SmallVector{builder->create( + loc, + RankedTensorType::get({operand_type.getRank()}, shape_scalar_type), + shape_values)}; + return success(); +} + } // namespace xla_hlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.h b/tensorflow/compiler/mlir/xla/ir/hlo_ops.h index d0bc9619db9..1a864507253 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.h +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.h @@ -28,6 +28,8 @@ limitations under the License. #include "mlir/IR/Operation.h" // TF:llvm-project #include "mlir/IR/StandardTypes.h" // TF:llvm-project #include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // TF:llvm-project +#include "mlir/Interfaces/SideEffects.h" // TF:llvm-project #include "mlir/Support/Functional.h" // TF:llvm-project namespace mlir { @@ -72,6 +74,22 @@ class TokenType : public Type::TypeBase { static bool kindof(unsigned kind) { return kind == HLOTypes::Token; } }; +// Shape derivation function that computes the shape of the result based on +// the first argument. For a 2-dimensional input tensor, this produces IR of +// the form +// +// %0 = dim %arg0, 0 : memref +// %1 = index_cast %0 : index to i64 +// %2 = dim %arg0, 1 : memref +// %3 = index_cast %2 : index to i64 +// %4 = "xla_hlo.scalars_to_dimension_tensor"(%1, %3) +// : (i64, i64) -> tensor<2xi64> +// +// and returns %4 as the shape value. 
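For orientation, a minimal sketch of calling the helper declared just below (not part of this patch; the wrapper name ReifyShapeFromFirstOperand is invented, and it assumes `op` is an operation whose result shape follows its first operand):

#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Builders.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"

// Hypothetical wrapper: materialize the shape of `op`'s first operand as a
// 1-D i64 tensor value inserted right before `op`, or return a null Value
// if the derivation fails.
static mlir::Value ReifyShapeFromFirstOperand(mlir::Operation *op) {
  mlir::OpBuilder builder(op);
  llvm::SmallVector<mlir::Value, 1> shapes;
  if (mlir::failed(
          mlir::xla_hlo::deriveShapeFromFirstOperand(&builder, op, &shapes)))
    return mlir::Value();
  return shapes.front();
}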
+LogicalResult deriveShapeFromFirstOperand( + OpBuilder *builder, Operation *op, + SmallVectorImpl *reifiedReturnShapes); + #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h.inc" diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td index 8fe7bb9f58c..d85a44eca10 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td @@ -13,12 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This is the operation definition file for XLA. +// This is the operation definition file for XLA HLO ops which map to the +// traditional definition in xla_data.proto (or are aligned with the goals +// thereof). +// See: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/xla_data.proto #ifndef HLO_OPS #define HLO_OPS include "mlir/IR/OpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffects.td" include "tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td" include "tensorflow/compiler/mlir/xla/ir/hlo_utils.td" @@ -37,65 +42,12 @@ class HLO_Op traits> : let verifier = [{ return Verify(*this); }]; } -//===----------------------------------------------------------------------===// -// XLA type definitions. -//===----------------------------------------------------------------------===// - -// Token type. -def HLO_Token : Type()">, "token">; - -// Any integer tensor types -def HLO_IntTensor : TensorOf<[HLO_Int]>; - -// Any floating-point tensor types -def HLO_FpTensor : TensorOf<[AnyFloat]>; - -def HLO_PredTensor : TensorOf<[HLO_Pred]>; - -def HLO_Tensor : TensorOf<[AnyFloat, AnySignlessInteger, AnyComplex]>; - -def HLO_ComplexTensor : TensorOf<[AnyComplex]>; - -def HLO_Tuple : NestedTupleOf<[HLO_Tensor, HLO_Token]>; - -def HLO_TensorOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Tuple]>; - -// Dynamic representation of a shape vector as a tensor. Ideally this would be -// an index type (as it stores indices) but that is currently disallowed in -// MLIR. -def HLO_DimensionTensor : ShapedContainerType< - [AnySignlessInteger], And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>, - "a 1D tensor of dimensions">; - -// In general, static shaped tensor constraints should be avoided unless -// it is for a legacy op which is only correct with static shapes. -def HLO_StaticShapeTensor : StaticShapeTensorOf<[ - AnyFloat, AnySignlessInteger, AnyComplex]>; - -//===----------------------------------------------------------------------===// -// XLA combined type definitions. -//===----------------------------------------------------------------------===// - -// Any integer or floating-point tensor types -def HLO_IntOrFpTensor : TensorOf<[HLO_Int, AnyFloat]>; - -// Any integer or predicate tensor types -def HLO_PredOrIntTensor : TensorOf<[HLO_Pred, HLO_Int]>; - -// Any floating-point or complex tensor types -def HLO_FpOrComplexTensor : TensorOf<[AnyFloat, AnyComplex]>; - -// Any int, floating-point or complex tensor types -def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, AnyFloat, AnyComplex]>; - -// Any pred, int or floating-point tensor types -def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, AnyFloat]>; - //===----------------------------------------------------------------------===// // XLA nullary op definitions. 
//===----------------------------------------------------------------------===// -def HLO_ConstOp : HLO_Op<"constant", [NoSideEffect]>, BASE_HLO_ConstOp { +def HLO_ConstOp : HLO_Op<"constant", [ConstantLike, NoSideEffect]>, + BASE_HLO_ConstOp { let arguments = (ins ElementsAttr:$value ); @@ -126,15 +78,41 @@ def HLO_IotaOp : HLO_Op<"iota", [NoSideEffect]>, BASE_HLO_IotaOp { let hasCustomHLOConverter = 1; } +def HLO_CreateTokenOp : HLO_Op<"create_token", [NoSideEffect]> { + string summary = "Create Token operator"; + + string description = [{ + Produces a HLO token. Tokens are used for ordering side-effecting perations. + This is exported to HLO as an AfterAll operation with no operands to + generate a token. + }]; + + let results = (outs HLO_Token:$output); +} + //===----------------------------------------------------------------------===// // XLA unary elementwise op definitions. //===----------------------------------------------------------------------===// // See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions class HLO_UnaryElementwiseOp traits, - Type TensorType>: HLO_Op { - + Type TensorType>: HLO_Op { let arguments = (ins TensorType:$operand); let results = (outs TensorType); + let extraClassDeclaration = [{ + static LogicalResult inferReturnTypeComponents( + MLIRContext* context, Optional location, + ValueRange operands, ArrayRef attributes, + RegionRange regions, + SmallVectorImpl& inferedReturnShapes) { + return failure(); + } + LogicalResult reifyReturnTypeShapes( + OpBuilder& builder, SmallVectorImpl& reifiedReturnShapes) { + return deriveShapeFromFirstOperand(&builder, getOperation(), + &reifiedReturnShapes); + } + }]; } // Abs supports complex to real, so element type is not guaranteed to match. @@ -273,11 +251,11 @@ def HLO_RealOp: HLO_Op< // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations class HLO_BinaryElementwiseOp traits> : - HLO_Op { + HLO_Op { let arguments = (ins HLO_Tensor:$lhs, HLO_Tensor:$rhs, - BroadcastDimAttr:$broadcast_dimensions + OptionalAttr:$broadcast_dimensions ); let builders = [OpBuilder< @@ -285,6 +263,20 @@ class HLO_BinaryElementwiseOp traits> : "DenseIntElementsAttr broadcast_dimensions" >]; + let extraClassDeclaration = [{ + static LogicalResult inferReturnTypeComponents( + MLIRContext* context, Optional location, ValueRange operands, + ArrayRef attributes, RegionRange regions, + SmallVectorImpl& inferedReturnShapes) { + return failure(); + } + LogicalResult reifyReturnTypeShapes( + OpBuilder& builder, SmallVectorImpl& reifiedReturnShapes) { + return deriveShapeFromFirstOperand(&builder, getOperation(), + &reifiedReturnShapes); + } + }]; + let results = (outs HLO_Tensor); let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }]; let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }]; @@ -336,7 +328,7 @@ class HLO_BinaryLogicalElementwiseOp : let arguments = (ins HLO_PredOrIntTensor:$lhs, HLO_PredOrIntTensor:$rhs, - BroadcastDimAttr:$broadcast_dimensions + OptionalAttr:$broadcast_dimensions ); } @@ -619,7 +611,7 @@ def HLO_CompareOp: HLO_Op<"compare", let arguments = (ins HLO_Tensor:$lhs, HLO_Tensor:$rhs, - BroadcastDimAttr:$broadcast_dimensions, + OptionalAttr:$broadcast_dimensions, HLO_ComparisonDirectionAttr:$comparison_direction ); let builders = [OpBuilder< @@ -809,7 +801,7 @@ def HLO_DynamicBroadcastInDimOp : HLO_Op<"dynamic_broadcast_in_dim", let arguments = (ins HLO_Tensor:$operand, 
HLO_DimensionTensor:$output_dimensions, - BroadcastDimAttr:$broadcast_dimensions + OptionalAttr:$broadcast_dimensions ); let results = (outs HLO_Tensor); @@ -908,6 +900,7 @@ def HLO_ConvOp : HLO_Op<"conv", [NoSideEffect]>, BASE_HLO_ConvOp { def HLO_CopyOp: HLO_Op<"copy", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_CopyOp { let arguments = (ins HLO_Tensor); let results = (outs HLO_Tensor); + let hasFolder = 1; } def HLO_CrossReplicaSumOp : HLO_Op<"cross-replica-sum", @@ -942,10 +935,11 @@ def HLO_DotOp: HLO_Op<"dot", [NoSideEffect]>, BASE_HLO_DotOp { } def DotDimensionNumbers : StructAttr<"DotDimensionNumbers", HLO_Dialect, [ - StructFieldAttr<"lhs_batching_dimensions", ElementsAttr>, - StructFieldAttr<"rhs_batching_dimensions", ElementsAttr>, - StructFieldAttr<"lhs_contracting_dimensions", ElementsAttr>, - StructFieldAttr<"rhs_contracting_dimensions", ElementsAttr>] > { + StructFieldAttr<"lhs_batching_dimensions", I64ElementsAttr>, + StructFieldAttr<"rhs_batching_dimensions", I64ElementsAttr>, + StructFieldAttr<"lhs_contracting_dimensions", I64ElementsAttr>, + StructFieldAttr<"rhs_contracting_dimensions", I64ElementsAttr> + ]> { let description = "Structure of dimension information for dot product"; } @@ -1087,7 +1081,7 @@ def HLO_ScatterOp: HLO_Op<"scatter", [NoSideEffect]>, BASE_HLO_ScatterOp { } // TODO(jpienaar): Add broadcastable trait. -def HLO_SelectOp: HLO_Op<"select", [NoSideEffect]>, BASE_HLO_SelectOp { +def HLO_SelectOp: HLO_Op<"select", [NoSideEffect, DeclareOpInterfaceMethods]>, BASE_HLO_SelectOp { let arguments = (ins HLO_PredTensor:$pred, HLO_Tensor:$on_true, diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td index 64303e86fe0..8dee4d0eb69 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td @@ -18,14 +18,68 @@ limitations under the License. include "mlir/IR/OpBase.td" -def HLO_Int : IntOfWidths<[8, 16, 32, 64]>; +def HLO_Int : SignlessIntOfWidths<[8, 16, 32, 64]>; def HLO_Pred : TypeAlias; // The broadcasting dimensions correspond to a tuple that describes how a // smaller rank shape is broadcast into a larger rank shape. For example, // given a 2x3x4 cuboid and a 3x4 matrix, a broadcasting tuple (1,2) means // matching the matrix to dimensions 1 and 2 of the cuboid. -def BroadcastDimAttr : OptionalAttr; +defvar BroadcastDimAttr = I64ElementsAttr; + +//===----------------------------------------------------------------------===// +// XLA on tensors type definitions. +//===----------------------------------------------------------------------===// + +// Token type. +def HLO_Token : Type()">, "token">; + +// Any integer tensor types +def HLO_IntTensor : TensorOf<[HLO_Int]>; + +// Any floating-point tensor types +def HLO_FpTensor : TensorOf<[AnyFloat]>; + +def HLO_PredTensor : TensorOf<[HLO_Pred]>; + +def HLO_Tensor : TensorOf<[AnyFloat, AnySignlessInteger, AnyComplex]>; + +def HLO_ComplexTensor : TensorOf<[AnyComplex]>; + +def HLO_Tuple : NestedTupleOf<[HLO_Tensor, HLO_Token]>; + +def HLO_TensorOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Tuple]>; + +// Dynamic representation of a shape vector as a tensor. Ideally this would be +// an index type (as it stores indices) but that is currently disallowed in +// MLIR. 
+def HLO_DimensionTensor : ShapedContainerType< + [AnySignlessInteger], And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>, + "a 1D tensor of dimensions">; + +// In general, static shaped tensor constraints should be avoided unless +// it is for a legacy op which is only correct with static shapes. +def HLO_StaticShapeTensor : StaticShapeTensorOf<[ + AnyFloat, AnySignlessInteger, AnyComplex]>; + +//===----------------------------------------------------------------------===// +// XLA on tensors combined type definitions. +//===----------------------------------------------------------------------===// + +// Any integer or floating-point tensor types +def HLO_IntOrFpTensor : TensorOf<[HLO_Int, AnyFloat]>; + +// Any integer or predicate tensor types +def HLO_PredOrIntTensor : TensorOf<[HLO_Pred, HLO_Int]>; + +// Any floating-point or complex tensor types +def HLO_FpOrComplexTensor : TensorOf<[AnyFloat, AnyComplex]>; + +// Any int, floating-point or complex tensor types +def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, AnyFloat, AnyComplex]>; + +// Any pred, int or floating-point tensor types +def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, AnyFloat]>; //===----------------------------------------------------------------------===// // XLA nullary op definitions. @@ -703,8 +757,8 @@ class BASE_HLO_AllToAllOp { AllToAll is a collective operation that sends data from all cores to all cores. It has two phases: - The scatter phase. On each core, the operand is split into `split_count` - number of blocks along the `split_dimensions`, and the blocks are - scattered to all cores, e.g., the i-th block is send to the i-th core. + number of blocks along the `split_dimension`, and the blocks are + scattered to all cores, e.g., the i-th block is sent to the i-th core. - The gather phase. Each core concatenates the received blocks along the `concat_dimension`. @@ -718,7 +772,7 @@ class BASE_HLO_AllToAllOp { will be concatenated in the same order of 1, 2, 3. Then, another AllToAll will be applied within replicas 4, 5, 0, and the concatenation order is also 4, 5, 0. If `replica_groups` is empty, all replicas belong to one - group, in the concatenation order of their appearance. + group, and the concatenation order is the numerical order (0, 1, 2, ...). Prerequisites: - The dimension size of the operand on the split_dimension is divisible by diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_utils.cc b/tensorflow/compiler/mlir/xla/ir/hlo_utils.cc index 130acaf1acb..0143e781549 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_utils.cc +++ b/tensorflow/compiler/mlir/xla/ir/hlo_utils.cc @@ -22,14 +22,16 @@ limitations under the License. namespace mlir { namespace xla { -DenseIntElementsAttr getBroadcastDimensionsAttr(Builder *b, Value x, Value y) { +DenseIntElementsAttr getBroadcastDimensionsAttr(Builder *b, Value x, Value y, + bool allow_empty) { TensorType xType = x.getType().dyn_cast(); TensorType yType = y.getType().dyn_cast(); - if (xType == yType || !xType || !yType) return {}; + if (!xType || !yType) return {}; + if (allow_empty && xType == yType) return {}; // If the shapes have the same rank, then there is nothing to do. auto xRank = xType.getRank(), yRank = yType.getRank(); - if (xRank == yRank) return {}; + if (allow_empty && xRank == yRank) return {}; // Otherwise if the ranks of the inputs don't match, TensorFlow automatically // reshapes the smaller by padding with dimensions of size 1 as a prefix. 
In diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_utils.h b/tensorflow/compiler/mlir/xla/ir/hlo_utils.h index 3e3570f5b54..84ea3a1e1a8 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_utils.h +++ b/tensorflow/compiler/mlir/xla/ir/hlo_utils.h @@ -28,9 +28,12 @@ namespace xla { // Computes the broadcast dimensions attr for an elementwise binary operator // between two ranked tensors. +// If `allow_empty` is true, then null can be returned to mean that the +// broadcast is an "identity". mlir::DenseIntElementsAttr getBroadcastDimensionsAttr(mlir::Builder* b, mlir::Value x, - mlir::Value y); + mlir::Value y, + bool allow_empty = true); /// Get a constant splat for the given value type. template diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_utils.td b/tensorflow/compiler/mlir/xla/ir/hlo_utils.td index 97b29bf0851..c6ea1fe9749 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_utils.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_utils.td @@ -32,6 +32,9 @@ def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">; def BinBroadcastDimensions : NativeCodeCall< "xla::getBroadcastDimensionsAttr(&$_builder, $0, $1)">; +def BinBroadcastDimensionsNonEmpty : NativeCodeCall< + "xla::getBroadcastDimensionsAttr(&$_builder, $0, $1, /*allow_empty=*/false)">; + // Here, the element type can be any integer or float type. But, note that only // 32 bit integers are supported for the value. class GetScalarOfType : NativeCodeCall< diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.h b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.h index 1a07b1a45f3..f9cb2284526 100644 --- a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.h +++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.h @@ -27,6 +27,7 @@ limitations under the License. #include "mlir/IR/Operation.h" // TF:llvm-project #include "mlir/IR/StandardTypes.h" // TF:llvm-project #include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/Interfaces/SideEffects.h" // TF:llvm-project #include "mlir/Support/Functional.h" // TF:llvm-project namespace mlir { diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td index 3a675f20d92..a37c530532d 100644 --- a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td @@ -19,6 +19,7 @@ limitations under the License. #define LHLO_OPS include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffects.td" include "tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td" def LHLO_Dialect : Dialect { @@ -97,6 +98,8 @@ def LHLO_NegOp: LHLO_UnaryElementwiseOp<"neg">, BASE_HLO_NegOp; def LHLO_RsqrtOp: LHLO_UnaryElementwiseOp<"rsqrt">, BASE_HLO_RsqrtOp; +def LHLO_SqrtOp: LHLO_UnaryElementwiseOp<"sqrt">, BASE_HLO_SqrtOp; + def LHLO_SignOp: LHLO_UnaryElementwiseOp<"sign">, BASE_HLO_SignOp; def LHLO_TanhOp: LHLO_UnaryElementwiseOp<"tanh">, BASE_HLO_TanhOp; @@ -111,7 +114,7 @@ class LHLO_BinaryElementwiseOp traits> : LHLO_Buffer:$lhs, LHLO_Buffer:$rhs, LHLO_Buffer:$out, - BroadcastDimAttr:$broadcast_dimensions + OptionalAttr:$broadcast_dimensions ); } @@ -152,6 +155,29 @@ def LHLO_ReduceOp: LHLO_Op<"reduce", [ let regions = (region SizedRegion<1>:$body); } + +def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [ + NoSideEffect, + SingleBlockImplicitTerminator<"TerminatorOp"> + ]>, BASE_HLO_ReduceWindowOp { + + let arguments = (ins + LHLO_Buffer:$operand, + LHLO_Buffer:$init_value, + LHLO_Buffer:$out, + I64ElementsAttr:$window_dimensions, + // If strides or dilations attributes are missing then the default value is + // one for each of the input dimensions. 
Similarly, padding values are zero + // for both low and high in each of the dimensions, if not specified. + OptionalAttr:$window_strides, + OptionalAttr:$base_dilations, + OptionalAttr:$window_dilations, + OptionalAttr:$padding + ); + + let regions = (region SizedRegion<1>:$body); +} + //===----------------------------------------------------------------------===// // XLA tuple op definitions. //===----------------------------------------------------------------------===// @@ -175,7 +201,7 @@ def LHLO_CompareOp: LHLO_Op<"compare", []>, BASE_HLO_CompareOp { LHLO_Buffer:$lhs, LHLO_Buffer:$rhs, LHLO_PredBuffer:$out, - BroadcastDimAttr:$broadcast_dimensions, + OptionalAttr:$broadcast_dimensions, HLO_ComparisonDirectionAttr:$comparison_direction ); } @@ -313,6 +339,21 @@ def LHLO_SelectOp: LHLO_Op<"select", []>, BASE_HLO_SelectOp { ); } +def LHLO_SelectAndScatterOp: LHLO_Op<"select_and_scatter", + [NoSideEffect]>, BASE_HLO_SelectAndScatterOp { + let arguments = (ins + LHLO_Buffer:$operand, + LHLO_Buffer:$source, + LHLO_Buffer:$init_value, + LHLO_Buffer:$out, + OptionalAttr:$window_dimensions, + OptionalAttr:$window_strides, + OptionalAttr:$padding + ); + + let regions = (region SizedRegion<1>:$select, SizedRegion<1>:$scatter); +} + def LHLO_ReverseOp: LHLO_Op<"reverse", []>, BASE_HLO_ReverseOp { let arguments = (ins LHLO_Buffer:$operand, diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc new file mode 100644 index 00000000000..1573810bc90 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc @@ -0,0 +1,139 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/xla/hlo_utils.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" +#include "tensorflow/compiler/mlir/xla/type_to_shape.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" + +namespace xla { + +static std::string GetMlirOpName(HloOpcode opcode) { + std::string op_name = HloOpcodeString(opcode); + absl::c_replace(op_name, '-', '_'); + return mlir::xla_hlo::XlaHloDialect::getDialectNamespace().str() + "." + + op_name; +} + +static std::string ToString(mlir::Type ty) { + std::string str; + llvm::raw_string_ostream sstream(str); + ty.print(sstream); + sstream.flush(); + return str; +} + +// Returns 1D 64-bit dense elements attribute with the given values. 
+static mlir::DenseIntElementsAttr GetI64ElementsAttr( + absl::Span values, mlir::Builder* builder) { + auto ty = mlir::RankedTensorType::get({static_cast(values.size())}, + builder->getIntegerType(64)); + llvm::SmallVector mlir_values; + mlir_values.reserve(values.size()); + for (const auto& value : values) { + mlir_values.push_back(value); + } + return mlir::DenseIntElementsAttr::get(ty, mlir_values); +} + +MlirHloBuilder::~MlirHloBuilder() = default; + +StatusOr MlirHloBuilder::MakeXlaOp(mlir::Value val) { + mlir::Type ty = val.getType(); + auto shape = std::make_unique(TypeToShape(ty)); + if (shape->element_type() == PrimitiveType::PRIMITIVE_TYPE_INVALID) { + return InvalidArgument("unsupported type: %s", ToString(ty).c_str()); + } + + int64 handle = reinterpret_cast(val.getAsOpaquePointer()); + handle_to_shape_[handle] = std::move(shape); + return XlaOp(handle, this); +} + +StatusOr MlirHloBuilder::ReshapeInternal(const Shape& shape, + XlaOp operand, + int64 inferred_dimension) { + TF_RETURN_IF_ERROR(first_error()); + + if (inferred_dimension != -1) + return Unimplemented("inferred_dimension not yet supported for Reshape op"); + TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( + shape, builder_)); + mlir::Value value = GetValue(operand); + auto op = builder_.create(loc_, ty, value); + return MakeXlaOp(op.getResult()); +} + +StatusOr MlirHloBuilder::InDimBroadcast( + const Shape& shape, XlaOp operand, + absl::Span broadcast_dimensions) { + TF_RETURN_IF_ERROR(first_error()); + TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( + shape, builder_)); + mlir::Value value = GetValue(operand); + auto op = builder_.create( + loc_, ty, value, GetI64ElementsAttr(broadcast_dimensions, &builder_)); + return MakeXlaOp(op.getResult()); +} + +XlaOp MlirHloBuilder::BinaryOpNoBroadcast( + HloOpcode binop, const Shape& shape, XlaOp lhs, XlaOp rhs, + absl::optional direction) { + return ReportErrorOrReturn([&]() -> StatusOr { + if (direction.has_value()) + return Unimplemented("direction attribute not yet supported"); + return CreateOp(GetMlirOpName(binop), shape, {lhs, rhs}, /*attributes=*/{}); + }); +} + +StatusOr MlirHloBuilder::AddOpWithShape( + HloOpcode opcode, const Shape& shape, absl::Span operands) { + return CreateOp(GetMlirOpName(opcode), shape, + llvm::makeArrayRef(operands.data(), operands.size()), + /*attributes=*/{}); +} + +StatusOr MlirHloBuilder::CreateOp( + const std::string& op_name, const Shape& shape, + llvm::ArrayRef operands, + llvm::ArrayRef attributes) { + llvm::SmallVector operand_values; + operand_values.reserve(operands.size()); + for (XlaOp xla_op : operands) { + operand_values.push_back(GetValue(xla_op)); + } + TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( + shape, builder_)); + mlir::OperationState state(loc_, op_name, operand_values, {ty}, attributes); + mlir::Operation* op = builder_.createOperation(state); + return MakeXlaOp(op->getResult(0)); +} + +StatusOr MlirHloBuilder::GetShapePtr(XlaOp op) const { + TF_RETURN_IF_ERROR(first_error()); + TF_RETURN_IF_ERROR(CheckOpBuilder(op)); + auto it = handle_to_shape_.find(op.handle()); + if (it == handle_to_shape_.end()) { + return InvalidArgument("No XlaOp with handle %d", op.handle()); + } + return it->second.get(); +} + +} // namespace xla diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h new file mode 100644 index 00000000000..9bebbc025a5 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h @@ -0,0 +1,118 @@ +/* Copyright 
2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_MLIR_HLO_BUILDER_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_IR_MLIR_HLO_BUILDER_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/Value.h" // TF:llvm-project +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace xla { + +// Provides a way to construct xla_hlo dialect ops in MLIR using XlaBuilder +// interface. +// +// Requires that all XlaOp arguments are either returned by any of the builder +// method or constructed using MakeXlaOp method in this builder. +// +// TODO(hinsu): Support more ops and utility functions to set special attributes +// like OpMetadata and Sharding. +class MlirHloBuilder : public XlaBuilder { + public: + // Constructs builder for the given function. New operations are added to the + // beginning of the function, if it is non empty and has a block. + explicit MlirHloBuilder(mlir::FuncOp func) + : XlaBuilder(func.getName().str()), + builder_(&func.getBody()), + loc_(builder_.getUnknownLoc()) {} + + // TODO(hinsu): Add a constructor to build a new MLIR function from scratch + // and override Build methods. + + MlirHloBuilder(const MlirHloBuilder&) = delete; + MlirHloBuilder& operator=(const MlirHloBuilder&) = delete; + + ~MlirHloBuilder() override; + + // Wraps the given MLIR value under an XlaOp instance. Note that all HLO + // operations returns exactly one result therefore each op has an XlaOp + // wrapping result of the op. + // + // Returns an error if the HLO dialect doesn't support type of the given + // value. + StatusOr MakeXlaOp(mlir::Value val); + + // Returns value corresponding to the given op. + // + // Requires that the op was created by this builder. + mlir::Value GetValue(XlaOp op) { + void* ptr = reinterpret_cast(op.handle()); + return mlir::Value::getFromOpaquePointer(ptr); + } + + // Sets location for newly built ops, until reset. + void SetLocation(mlir::Location loc) { loc_ = loc; } + + // Update insertion point so that newly built ops are inserted before the + // given op in order, until reset. + void setInsertionPoint(mlir::Operation* op) { + builder_.setInsertionPoint(op); + } + + // Returns the shape of the given op. 
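A minimal usage sketch for this builder (not part of this patch): it assumes `func` already has two tensor block arguments with identical static shapes, and it elides error handling for brevity.

#include "mlir/IR/Function.h"
#include "mlir/IR/Value.h"
#include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

// Hypothetical example: drive ordinary XlaBuilder client code through
// MlirHloBuilder so that it emits xla_hlo ops into `func`.
static mlir::Value EmitAddIntoFunc(mlir::FuncOp func) {
  xla::MlirHloBuilder builder(func);
  auto lhs = builder.MakeXlaOp(func.getArgument(0));
  auto rhs = builder.MakeXlaOp(func.getArgument(1));
  if (!lhs.ok() || !rhs.ok()) return mlir::Value();
  // xla::Add is the usual XlaBuilder client API; routed through this builder
  // it should materialize an xla_hlo.add op at the beginning of the function.
  xla::XlaOp sum = xla::Add(lhs.ValueOrDie(), rhs.ValueOrDie());
  return builder.GetValue(sum);
}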
+ StatusOr GetShapePtr(XlaOp op) const override; + + private: + StatusOr ReshapeInternal(const Shape& shape, XlaOp operand, + int64 inferred_dimension) override; + + StatusOr InDimBroadcast( + const Shape& shape, XlaOp operand, + absl::Span broadcast_dimensions) override; + + XlaOp BinaryOpNoBroadcast( + HloOpcode binop, const Shape& shape, XlaOp lhs, XlaOp rhs, + absl::optional direction) override; + + StatusOr AddOpWithShape(HloOpcode opcode, const Shape& shape, + absl::Span operands) override; + + // Creates HLO dialect op and returns the result as an XlaOp. + StatusOr CreateOp(const std::string& op_name, const Shape& shape, + llvm::ArrayRef operands, + llvm::ArrayRef attributes); + + mlir::OpBuilder builder_; + mlir::Location loc_; + + absl::flat_hash_map> handle_to_shape_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_MLIR_HLO_BUILDER_H_ diff --git a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir index b73cfcfa538..18a29968600 100644 --- a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir @@ -74,3 +74,11 @@ func @extract_scalars_to_tensor(%arg0: i32, %arg1: i32) -> i32 { // CHECK: return %[[ARG0]] return %2 : i32 } + +// CHECK-LABEL: func @fold_copy +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @fold_copy(%arg : tensor<1x4xf32>) -> tensor<1x4xf32> { + // CHECK: return [[ARG]] + %0 = "xla_hlo.copy"(%arg) : (tensor<1x4xf32>) -> tensor<1x4xf32> + return %0 : tensor<1x4xf32> +} diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir index be6f0e6a949..2aeb5f1041d 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt -hlo-legalize-to-lhlo -lhlo-redundant-copies-removal -split-input-file %s -o - | FileCheck %s -dump-input-on-failure +// RUN: tf-opt -hlo-legalize-to-lhlo %s -o - | FileCheck %s --dump-input-on-failure // CHECK-LABEL: func @attrs func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { @@ -6,69 +6,48 @@ func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_result = "xla_hlo.exp"(%tensor_operand) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>} : (tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: "xla_lhlo.exp"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>} + // CHECK: "xla_lhlo.exp"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>} tensor_store %tensor_result, %result : memref<2x2xf32> return } // ----- -// CHECK-LABEL: func @func_op -func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>) - %0 = xla_hlo.max %arg0, %arg1 {name = "maximum.47"} : tensor<4xf32> - // CHECK-NEXT: "xla_lhlo.max"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[RESULT]]) - return %0 : tensor<4xf32> - // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () -} - -// ----- - // CHECK-LABEL: func @func_op_long func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>) + // CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() {temp = true} : memref<4xf32> // CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() {temp = true} : 
memref<4xf32> // CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() {temp = true} : memref<4xf32> // CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() {temp = true} : memref<4xf32> // CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() {temp = true} : memref<4xf32> - %1 = xla_hlo.max %arg0, %arg1 {name = "maximum.47"} : tensor<4xf32> + %1 = xla_hlo.max %arg0, %arg1 : tensor<4xf32> // CHECK-NEXT: "xla_lhlo.max"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]]) - %2 = xla_hlo.add %arg0, %1 {name = "maximum.47"} : tensor<4xf32> + %2 = xla_hlo.add %arg0, %1 : tensor<4xf32> // CHECK-NEXT: "xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]]) - %3 = xla_hlo.min %arg0, %arg1 {name = "maximum.47"} : tensor<4xf32> + %3 = xla_hlo.min %arg0, %arg1 : tensor<4xf32> // CHECK-NEXT: "xla_lhlo.min"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]]) - %4 = xla_hlo.sub %arg1, %3 {name = "maximum.47"} : tensor<4xf32> + %4 = xla_hlo.sub %arg1, %3 : tensor<4xf32> // CHECK-NEXT: "xla_lhlo.sub"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]]) - %5 = xla_hlo.mul %2, %4 {name = "maximum.47"} : tensor<4xf32> - // CHECK-NEXT: "xla_lhlo.mul"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[RESULT]]) + %5 = xla_hlo.mul %2, %4 : tensor<4xf32> + // CHECK-NEXT: "xla_lhlo.mul"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]]) // CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32> // CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32> // CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32> // CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32> + // CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> () + // CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32> return %5 : tensor<4xf32> // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () } // ----- -// CHECK-LABEL: func @remove_lhlo_copy_op_created_from_tensor_store -func @remove_lhlo_copy_op_created_from_tensor_store(%arg0: tensor, %arg1: tensor, %arg2: memref) { - %0 = "xla_hlo.max"(%arg0, %arg1) : (tensor, tensor) -> tensor - tensor_store %0, %arg2 : memref - return -} -// CHECK: (%[[NEW_ARG0:.*]]: memref, %[[NEW_ARG1:.*]]: memref, %[[RESULT:.*]]: memref) -// CHECK-NOT: %[[ALLOC_OPERAND:.*]] = alloc() {temp = true} : memref -// CHECK: "xla_lhlo.max"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[RESULT]]) : (memref, memref, memref) -> () -// CHECK-NOT: "xla_lhlo.copy"(%[[ALLOC_OPERAND]], %[[RESULT]]) : (memref, memref) -> () -// CHECK-NOT: dealloc %[[ALLOC_OPERAND]] : memref -// CHECK: "xla_lhlo.terminator"() : () -> () - -// ----- - // CHECK-LABEL: func @fusion func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>, %summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) { + // CHECK: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}}, %[[RESULT:.*]]: {{.*}}) + // CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() {temp = true} : memref<2x2xf32> // CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() {temp = true} : memref<2x2xf32> %tensor_summand_1 = tensor_load %summand_1 : memref<2x2xf32> %tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32> @@ -78,9 +57,11 @@ func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>, %tensor_multiplier = tensor_load %multiplier : memref<2x2xf32> %tensor_result = "xla_hlo.mul"(%sum, %tensor_multiplier) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: "xla_lhlo.mul"(%[[ADD_RESULT]], %{{.*}}, %{{.*}}) + // CHECK-NEXT: "xla_lhlo.mul"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]]) + // CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) tensor_store %tensor_result, %result : memref<2x2xf32> // CHECK-NEXT: dealloc 
%[[ADD_RESULT]] : memref<2x2xf32> + // CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<2x2xf32> // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () "xla_lhlo.terminator"() : () -> () } @@ -92,7 +73,7 @@ func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "xla_hlo.copy"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: "xla_lhlo.copy"(%{{.*}}, %{{.*}}) + // CHECK: "xla_lhlo.copy"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -104,7 +85,19 @@ func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "xla_hlo.exp"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: "xla_lhlo.exp"(%{{.*}}, %{{.*}}) + // CHECK: "xla_lhlo.exp"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref<2x2xf32> + return +} + +// ----- + +// CHECK-LABEL: func @log +func @log(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { + %tensor_operand = tensor_load %operand : memref<2x2xf32> + %tensor_result = "xla_hlo.log"(%tensor_operand) + : (tensor<2x2xf32>) -> tensor<2x2xf32> + // CHECK: "xla_lhlo.log"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -119,7 +112,7 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>, %tensor_rhs = tensor_load %rhs : memref<2x2xf32> %tensor_result = "xla_hlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs) : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: "xla_lhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) + // CHECK: "xla_lhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -133,7 +126,7 @@ func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2x %tensor_result = "xla_hlo.compare"(%tensor_lhs, %tensor_rhs) {comparison_direction = "EQ"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1> - // CHECK-NEXT: "xla_lhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"} + // CHECK: "xla_lhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"} tensor_store %tensor_result, %result : memref<2x2xi1> return } @@ -146,7 +139,7 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) { %tensor_result = "xla_hlo.broadcast_in_dim"(%tensor_operand) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<5xf32>) -> tensor<10x5xf32> - // CHECK-NEXT: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} tensor_store %tensor_result, %result : memref<10x5xf32> return } @@ -183,7 +176,7 @@ func @dyn_broadcast(%operand: memref) { func @iota(%result: memref<10xi32>) { %tensor_result = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<10xi32> - // CHECK-NEXT: "xla_lhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64} + // CHECK: "xla_lhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64} tensor_store %tensor_result, %result : memref<10xi32> return } @@ -195,7 +188,7 @@ func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "xla_hlo.abs"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: "xla_lhlo.abs"(%{{.*}}, %{{.*}}) + // CHECK: "xla_lhlo.abs"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, 
%result : memref<2x2xf32> return } @@ -207,7 +200,7 @@ func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "xla_hlo.ceil"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: "xla_lhlo.ceil"(%{{.*}}, %{{.*}}) + // CHECK: "xla_lhlo.ceil"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -231,7 +224,7 @@ func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "xla_hlo.cos"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: "xla_lhlo.cos"(%{{.*}}, %{{.*}}) + // CHECK: "xla_lhlo.cos"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -243,7 +236,19 @@ func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "xla_hlo.neg"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: "xla_lhlo.neg"(%{{.*}}, %{{.*}}) + // CHECK: "xla_lhlo.neg"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref<2x2xf32> + return +} + +// ----- + +// CHECK-LABEL: func @rsqrt +func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { + %tensor_operand = tensor_load %operand : memref<2x2xf32> + %tensor_result = "xla_hlo.rsqrt"(%tensor_operand) + : (tensor<2x2xf32>) -> tensor<2x2xf32> + // CHECK: "xla_lhlo.rsqrt"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -255,7 +260,19 @@ func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "xla_hlo.sign"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: "xla_lhlo.sign"(%{{.*}}, %{{.*}}) + // CHECK: "xla_lhlo.sign"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref<2x2xf32> + return +} + +// ----- + +// CHECK-LABEL: func @sqrt +func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { + %tensor_operand = tensor_load %operand : memref<2x2xf32> + %tensor_result = "xla_hlo.sqrt"(%tensor_operand) + : (tensor<2x2xf32>) -> tensor<2x2xf32> + // CHECK: "xla_lhlo.sqrt"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -267,7 +284,7 @@ func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "xla_hlo.tanh"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: "xla_lhlo.tanh"(%{{.*}}, %{{.*}}) + // CHECK: "xla_lhlo.tanh"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -280,7 +297,7 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x %tensor_rhs = tensor_load %rhs : memref<2x2xf32> %tensor_result = "xla_hlo.remainder"(%tensor_lhs, %tensor_rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: "xla_lhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}}) + // CHECK: "xla_lhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir index 61add8c4389..1f4c9c6ea6c 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir @@ -77,6 +77,17 @@ func 
@integer_remainder(%lhs: tensor<2x2xi32>, // ----- +// CHECK-LABEL: func @float_rsqrt +func @float_rsqrt(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> { + %tensor_result = "xla_hlo.rsqrt"(%operand) + : (tensor<2x2xf32>) -> tensor<2x2xf32> + // CHECK: linalg.generic + // CHECK: rsqrt + return %tensor_result : tensor<2x2xf32> +} + +// ----- + // CHECK-LABEL: func @float_sub func @float_sub(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { @@ -121,6 +132,16 @@ func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- +// CHECK-LABEL: func @float_log +func @float_log(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + // CHECK: linalg.generic + // CHECK: log + %0 = "xla_hlo.log"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + // CHECK-LABEL: func @float_ceil func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic @@ -192,13 +213,12 @@ func @int_cmp(%lhs: tensor<2x2xi32>, // ----- // CHECK-LABEL: func @copy +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> { %0 = "xla_hlo.copy"(%input) : (tensor<2x4x8xf32>) -> (tensor<2x4x8xf32>) return %0 : tensor<2x4x8xf32> } -// CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32): -// CHECK-NEXT: linalg.yield %[[OPERAND_IN]] : f32 +// CHECK: return [[ARG]] : tensor<2x4x8xf32> // ----- @@ -231,7 +251,7 @@ func @broadcast(%operand: tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> { // ----- -// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2) -> (0)> +// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2) -> ()> // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-LABEL: func @broadcast_scalar func @broadcast_scalar(%operand: tensor) -> tensor<7x10x6xf32> { diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir new file mode 100644 index 00000000000..f2dff2c9956 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir @@ -0,0 +1,91 @@ +// RUN: tf-opt -xla-legalize-tf-with-tf2xla=device-type=XLA_CPU %s | FileCheck %s --dump-input-on-failure + +// INVALID_DEVICE: tf-opt -xla-legalize-tf-with-tf2xla=device-type=INVALID_DEVICE %s | FileCheck %s --dump-input-on-failure + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { + +// CHECK-LABEL: abs +// expected-error@+1 {{unsupported device}} +func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: %[[RESULT:.*]] = "xla_hlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + %0 = "tf.Abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + + // return %[[RESULT]] + return %0 : tensor<2xf32> +} + +// CHECK-LABEL: unknown_op +func @unknown_op(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: tf.CustomTestOp + // expected-remark@+1 {{constant 20}} + %0 = "tf.CustomTestOp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + + return %0 : tensor<2xf32> +} + +// CHECK-LABEL: dynamic_operand +func @dynamic_operand(%arg0: tensor) -> tensor { + // CHECK: tf.Abs + // expected-remark@+1 {{lowering requires static shaped operands}} + %0 = "tf.Abs"(%arg0) : (tensor) -> tensor + + return %0 : tensor +} + +// CHECK-LABEL: multiple_dialect_ops +func @multiple_dialect_ops(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: xla_hlo.neg + %0 = "xla_hlo.neg"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: xla_hlo.abs + %1 = "tf.Abs"(%0) : (tensor<2xf32>) -> tensor<2xf32> + + 
return %1 : tensor<2xf32> +} + +// CHECK-LABEL: binary_op +func @binary_op(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: xla_hlo.atan2 %arg0, %arg1 : tensor<2xf32> + %0 = "tf.Atan2"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + +// CHECK-LABEL: binary_op_broadcast +func @binary_op_broadcast(%arg0: tensor<4x1xf32>, %arg1: tensor<4x1x4xf32>) -> tensor<4x4x4xf32> { + // CHECK: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<4x1xf32>) -> tensor<4x4x1xf32> + // CHECK: %[[RESHAPE0:.*]] = "xla_hlo.reshape"(%[[BROADCAST0]]) : (tensor<4x4x1xf32>) -> tensor<4x4xf32> + // CHECK: %[[UPDATED_ARG0:.*]] = "xla_hlo.broadcast_in_dim"(%[[RESHAPE0]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<4x4x4xf32> + + // CHECK: %[[RESHAPE1:.*]] = "xla_hlo.reshape"(%arg1) : (tensor<4x1x4xf32>) -> tensor<4x4xf32> + // CHECK: %[[UPDATED_ARG1:.*]] = "xla_hlo.broadcast_in_dim"(%[[RESHAPE1]]) {broadcast_dimensions = dense<[0, 2]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<4x4x4xf32> + + // CHECK: %[[RESULT:.*]] = xla_hlo.atan2 %[[UPDATED_ARG0]], %[[UPDATED_ARG1]] : tensor<4x4x4xf32> + // CHECK: return %[[RESULT]] : tensor<4x4x4xf32> + + %0 = "tf.Atan2"(%arg0, %arg1) : (tensor<4x1xf32>, tensor<4x1x4xf32>) -> tensor<4x4x4xf32> + return %0: tensor<4x4x4xf32> +} + +// CHECK-LABEL: func @ternary_op +func @ternary_op(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { + // CHECK: "xla_hlo.select"(%arg0, %arg1, %arg2) + %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + return %0: tensor<2xi32> +} + +// CHECK-LABEL: func @convert +func @convert(%arg0: tensor<2xi32>) -> tensor<2xf32> { + // CHECK: "xla_hlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32> + %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<2xi32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + +// TODO(hinsu): Add a test with variant type once one of the ops supporting +// the type is whitelisted. It should be rejected with unsupported type remark. + +// TODO(hinsu): Add a test with uint8 type once one of the ops supporting the +// type is whitelisted. Unsigned types are not yet added to the HLO dialect so +// it should return an error. See b/130356985 + +// TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is +// available but doesn't support this instance. 
+} diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index d80722e2865..b759fe593c2 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -523,17 +523,17 @@ func @broadcast_shift_right(%arg0: tensor<4xi32>, %arg1: tensor<2x4xi32>) -> ten } // CHECK-LABEL: func @shift_right_unsigned -func @shift_right_unsigned(%arg0: tensor<4x!tf.uint8>, %arg1: tensor<4x!tf.uint8>) -> tensor<4x!tf.uint8> { +func @shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<4xui8>) -> tensor<4xui8> { // CHECK: tf.RightShift - %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4x!tf.uint8>, tensor<4x!tf.uint8>) -> tensor<4x!tf.uint8> - return %0 : tensor<4x!tf.uint8> + %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xui8>, tensor<4xui8>) -> tensor<4xui8> + return %0 : tensor<4xui8> } // CHECK-LABEL: func @broadcast_shift_right_unsigned -func @broadcast_shift_right_unsigned(%arg0: tensor<4x!tf.uint8>, %arg1: tensor<2x4x!tf.uint8>) -> tensor<2x4x!tf.uint8> { +func @broadcast_shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<2x4xui8>) -> tensor<2x4xui8> { // CHECK: tf.RightShift - %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4x!tf.uint8>, tensor<2x4x!tf.uint8>) -> tensor<2x4x!tf.uint8> - return %0 : tensor<2x4x!tf.uint8> + %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xui8>, tensor<2x4xui8>) -> tensor<2x4xui8> + return %0 : tensor<2x4xui8> } // CHECK-LABEL: func @and @@ -1133,8 +1133,8 @@ func @preventgradient(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK-LABEL: func @infeed_dequeue_tuple func @infeed_dequeue_tuple() -> (tensor<3xi32>, tensor<4xf32>) { -// CHECK: [[AFTER_ALL:%.*]] = "xla_hlo.after_all"() : () -> !xla_hlo.token -// CHECK: [[INFEED:%.*]] = "xla_hlo.infeed"([[AFTER_ALL]]) {infeed_config = ""} : (!xla_hlo.token) -> tuple, tensor<4xf32>>, !xla_hlo.token> +// CHECK: [[TOKEN:%.*]] = "xla_hlo.create_token"() : () -> !xla_hlo.token +// CHECK: [[INFEED:%.*]] = "xla_hlo.infeed"([[TOKEN]]) {infeed_config = ""} : (!xla_hlo.token) -> tuple, tensor<4xf32>>, !xla_hlo.token> // CHECK: [[INFEED_VAL:%.*]] = "xla_hlo.get_tuple_element"([[INFEED]]) {index = 0 : i32} : (tuple, tensor<4xf32>>, !xla_hlo.token>) -> tuple, tensor<4xf32>> // CHECK: [[RES_1:%.*]] = "xla_hlo.get_tuple_element"([[INFEED_VAL]]) {index = 0 : i32} : (tuple, tensor<4xf32>>) -> tensor<3xi32> // CHECK: [[RES_2:%.*]] = "xla_hlo.get_tuple_element"([[INFEED_VAL]]) {index = 1 : i32} : (tuple, tensor<4xf32>>) -> tensor<4xf32> @@ -1228,6 +1228,80 @@ func @test_sparse_mat_mul(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) -> ten return %0: tensor<3x5xf32> } +//===----------------------------------------------------------------------===// +// MatrixBandPart op legalizations. 
+//===----------------------------------------------------------------------===// + +// CHECK-LABEL: matrix_band_part +// CHECK-SAME: (%[[INPUT:.*]]: tensor<64x64xbf16>, %[[LOWER:.*]]: tensor, %[[UPPER:.*]]: tensor) +func @matrix_band_part(%arg0: tensor<64x64xbf16>, %arg1: tensor, %arg2: tensor) -> tensor<64x64xbf16> { + // CHECK: %[[M:.*]] = xla_hlo.constant dense<64> : tensor + // CHECK: %[[N:.*]] = xla_hlo.constant dense<64> : tensor + + // CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0> : tensor + // CHECK: %[[A:.*]] = "xla_hlo.compare"(%[[LOWER]], %[[ZERO]]) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + // CHECK: %[[B:.*]] = "xla_hlo.select"(%[[A]], %[[M]], %[[LOWER]]) : (tensor, tensor, tensor) -> tensor + + // CHECK: %[[C:.*]] = "xla_hlo.compare"(%[[UPPER]], %[[ZERO]]) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + // CHECK: %[[D:.*]] = "xla_hlo.select"(%[[C]], %[[N]], %[[UPPER]]) : (tensor, tensor, tensor) -> tensor + + // CHECK: %[[E:.*]] = "xla_hlo.convert"(%[[B]]) : (tensor) -> tensor + // CHECK: %[[F:.*]] = "xla_hlo.neg"(%[[E]]) : (tensor) -> tensor + + // CHECK: %[[X:.*]] = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<64x64xbf16> + // CHECK: %[[Y:.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<64x64xbf16> + // CHECK: %[[OFFSET:.*]] = xla_hlo.sub %[[X]], %[[Y]] : tensor<64x64xbf16> + // CHECK: %[[G:.*]] = "xla_hlo.compare"(%[[F]], %[[OFFSET]]) {comparison_direction = "LE"} : (tensor, tensor<64x64xbf16>) -> tensor<*xi1> + + // CHECK: %[[H:.*]] = "xla_hlo.convert"(%[[D]]) : (tensor) -> tensor + // CHECK: %[[I:.*]] = "xla_hlo.compare"(%[[OFFSET]], %[[H]]) {comparison_direction = "LE"} : (tensor<64x64xbf16>, tensor) -> tensor<*xi1> + + // CHECK: %[[J:.*]] = xla_hlo.and %[[G]], %[[I]] : tensor<*xi1> + + // CHECK: %[[ZERO2:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<64x64xbf16> + // CHECK: %[[R:.*]] = "xla_hlo.select"(%[[J]], %[[INPUT]], %[[ZERO2]]) + // CHECK: return %[[R]] + %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64xbf16>, tensor, tensor) -> tensor<64x64xbf16> + return %0 : tensor<64x64xbf16> +} + +// CHECK-LABEL: matrix_band_part_2 +// CHECK-SAME: (%[[INPUT:.*]]: tensor<12x24x48xbf16>, %[[LOWER:.*]]: tensor, %[[UPPER:.*]]: tensor) +func @matrix_band_part_2(%arg0: tensor<12x24x48xbf16>, %arg1: tensor, %arg2: tensor) -> tensor<12x24x48xbf16> { + // CHECK: %[[X:.*]] = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<24x48xbf16> + // CHECK: %[[Y:.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<24x48xbf16> + // CHECK: %[[OFFSET:.*]] = xla_hlo.sub %[[X]], %[[Y]] : tensor<24x48xbf16> + + // CHECK: %[[G:.*]] = "xla_hlo.compare"(%[[F]], %[[OFFSET]]) {comparison_direction = "LE"} : (tensor, tensor<24x48xbf16>) -> tensor<*xi1> + + // CHECK: %[[H:.*]] = "xla_hlo.convert"(%[[D]]) : (tensor) -> tensor + // CHECK: %[[I:.*]] = "xla_hlo.compare"(%[[OFFSET]], %[[H]]) {comparison_direction = "LE"} : (tensor<24x48xbf16>, tensor) -> tensor<*xi1> + // CHECK: %[[J:.*]] = xla_hlo.and %[[G]], %[[I]] {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : tensor<*xi1> + + // CHECK: %[[ZERO2:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<12x24x48xbf16> + // CHECK: %[[R:.*]] = "xla_hlo.select"(%[[J]], %[[INPUT]], %[[ZERO2]]) + // CHECK: return %[[R]] + %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<12x24x48xbf16>, tensor, tensor) -> tensor<12x24x48xbf16> + return %0 : tensor<12x24x48xbf16> +} + +// CHECK-LABEL: matrix_band_part_3 +// CHECK-SAME: (%[[INPUT:.*]]: tensor<*xbf16>, 
%[[LOWER:.*]]: tensor, %[[UPPER:.*]]: tensor) +func @matrix_band_part_3(%arg0: tensor<*xbf16>, %arg1: tensor, %arg2: tensor) -> tensor<*xbf16> { + // CHECK: "tf.MatrixBandPart" + %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<*xbf16>, tensor, tensor) -> tensor<*xbf16> + return %0 : tensor<*xbf16> +} + +// CHECK-LABEL: matrix_band_part_4 +// CHECK-SAME: (%[[INPUT:.*]]: tensor<24x48xbf16>, %[[LOWER:.*]]: tensor, %[[UPPER:.*]]: tensor) +func @matrix_band_part_4(%arg0: tensor<24x48xbf16>, %arg1: tensor, %arg2: tensor) -> tensor<24x48xbf16> { + // This one should lower. + // CHECK-NOT: "tf.MatrixBandPart" + %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<24x48xbf16>, tensor, tensor) -> tensor<24x48xbf16> + return %0 : tensor<24x48xbf16> +} + //===----------------------------------------------------------------------===// // MaxPool op legalizations. //===----------------------------------------------------------------------===// @@ -1319,8 +1393,8 @@ func @one_hot(%indices: tensor<3xi32>, %on_value: tensor, %off_value: tenso // CHECK-SAME: [[VAL_0:%.*]]: tensor<3xi32>, [[VAL_1:%.*]]: tensor<4xf32>) func @outfeed_enqueue_tuple(%data_1: tensor<3xi32>, %data_2: tensor<4xf32>) -> () { // CHECK: [[TUPLE:%.*]] = "xla_hlo.tuple"([[VAL_0]], [[VAL_1]]) : (tensor<3xi32>, tensor<4xf32>) -> tuple, tensor<4xf32>> -// CHECK: [[AFTER_ALL:%.*]] = "xla_hlo.after_all"() : () -> !xla_hlo.token -// CHECK: "xla_hlo.outfeed"([[TUPLE]], [[AFTER_ALL]]) {outfeed_config = ""} : (tuple, tensor<4xf32>>, !xla_hlo.token) -> !xla_hlo.token +// CHECK: [[TOKEN:%.*]] = "xla_hlo.create_token"() : () -> !xla_hlo.token +// CHECK: "xla_hlo.outfeed"([[TUPLE]], [[TOKEN]]) {outfeed_config = ""} : (tuple, tensor<4xf32>>, !xla_hlo.token) -> !xla_hlo.token "tf.OutfeedEnqueueTuple"(%data_1, %data_2) : (tensor<3xi32>, tensor<4xf32>) -> () return } @@ -2415,6 +2489,25 @@ func @strided_slice_new_axis_mask(%input: tensor<2x4x8x16x32x64xf32>) { return } +// CHECK-LABEL: strided_slice_implicit_ellipsis_mask( +// CHECK-SAME: [[INPUT:%.*]]: tensor<10x16x2xf32> +func @strided_slice_implicit_ellipsis_mask(%input: tensor<10x16x2xf32>) -> tensor<2x16x2xf32> { + // StridedSlice gets input[8:10], which is same as input[8:10, ...] + // The start_indices, limit_indices, and strides attribute of xla_hlo.slice + // reflect the canonicalized slice. + %begin = "tf.Const"() {value = dense<8> : tensor<1xi32>} : () -> tensor<1xi32> + %end = "tf.Const"() {value = dense<10> : tensor<1xi32>} : () -> tensor<1xi32> + %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: [[SLICE:%.*]] = "xla_hlo.slice"([[INPUT]]) + // CHECK-DAG-SAME: limit_indices = dense<[10, 16, 2]> : tensor<3xi64> + // CHECK-DAG-SAME: start_indices = dense<[8, 0, 0]> : tensor<3xi64> + // CHECK-DAG-SAME: strides = dense<1> : tensor<3xi64> + // CHECK: [[RESHAPE:%.*]] = "xla_hlo.reshape"([[SLICE]]) : (tensor<2x16x2xf32>) -> tensor<2x16x2xf32> + %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = f32} : (tensor<10x16x2xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2x16x2xf32> + // CHECK: return [[RESHAPE]] : tensor<2x16x2xf32> + return %0 : tensor<2x16x2xf32> +} + //===----------------------------------------------------------------------===// // Reduction op legalizations. 
@@ -3247,6 +3340,121 @@ func @strided_slice_grad(%grad: tensor<4x16x1022xf32>) -> tensor<4x128x1024xf32>
   return %0: tensor<4x128x1024xf32>
 }
 
+// CHECK-LABEL: strided_slice_grad_shrink_axis_mask
+// CHECK-SAME: [[GRAD:%.*]]: tensor<8xf32>
+func @strided_slice_grad_shrink_axis_mask(%grad: tensor<8xf32>) -> tensor<4x8xf32> {
+  // Input to StridedSlice was of shape 4x8xf32
+  // Strided slice gets input[2:3, 0:8]
+  // shrink_axis_mask is 1, denoting that dim #0 is shrunk. So the output is
+  // 8xf32, which is the shape of the gradient.
+  // StridedSliceGrad would reshape the gradient to 1x8xf32 and
+  // then pad it to match the input shape 4x8xf32.
+
+  %shape = "tf.Const"() {value = dense<[4, 8]> : tensor<2xi32>} : () -> (tensor<2xi32>)
+  %begin = "tf.Const"() {value = dense<[2, 0]> : tensor<2xi32>} : () -> (tensor<2xi32>)
+  %end = "tf.Const"() {value = dense<[3, 8]> : tensor<2xi32>} : () -> (tensor<2xi32>)
+  %strides = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> (tensor<2xi32>)
+
+  // CHECK: [[RESHAPE:%.*]] = "xla_hlo.reshape"([[GRAD]]) : (tensor<8xf32>) -> tensor<1x8xf32>
+  // CHECK: [[ZEROS:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
+  // CHECK: [[PAD:%.*]] = "xla_hlo.pad"([[RESHAPE]], [[ZEROS]])
+  // CHECK-DAG-SAME: edge_padding_low = dense<[2, 0]> : tensor<2xi64>
+  // CHECK-DAG-SAME: edge_padding_high = dense<[1, 0]> : tensor<2xi64>
+  // CHECK-DAG-SAME: interior_padding = dense<0> : tensor<2xi64>
+  %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 0, end_mask = 0, shrink_axis_mask = 1} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<8xf32>) -> tensor<4x8xf32>
+
+  // CHECK: return [[PAD]] : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
+}
+
+// CHECK-LABEL: strided_slice_grad_new_axis_mask
+// CHECK-SAME: [[GRAD:%.*]]: tensor<1x2xf32>
+func @strided_slice_grad_new_axis_mask(%grad: tensor<1x2xf32>) -> tensor<8xf32> {
+  // Input to StridedSlice was of shape 8xf32
+  // Strided slice gets input[tf.new_axis, 2:4]
+  // new_axis_mask is 1, denoting that a new axis is inserted at dim #0. So the
+  // output is 1x2xf32, which is the shape of the gradient.
+  // StridedSliceGrad would reshape the gradient to 2xf32 and
+  // then pad it to match the input shape 8xf32.
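+  // In other words, the slice covers positions 2 and 3 of the 8-element
+  // input, so the reshaped 2xf32 gradient is expected to be padded below with
+  // edge_padding_low = 2 and edge_padding_high = 8 - 4 = 4, and no interior
+  // padding.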
+
+  %shape = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
+  %begin = "tf.Const"() {value = dense<[0, 2]> : tensor<2xi32>} : () -> (tensor<2xi32>)
+  %end = "tf.Const"() {value = dense<[0, 4]> : tensor<2xi32>} : () -> (tensor<2xi32>)
+  %strides = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> (tensor<2xi32>)
+
+  // CHECK: [[RESHAPE:%.*]] = "xla_hlo.reshape"([[GRAD]]) : (tensor<1x2xf32>) -> tensor<2xf32>
+  // CHECK: [[ZEROS:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
+  // CHECK: [[PAD:%.*]] = "xla_hlo.pad"([[RESHAPE]], [[ZEROS]])
+  // CHECK-DAG-SAME: edge_padding_low = dense<2> : tensor<1xi64>
+  // CHECK-DAG-SAME: edge_padding_high = dense<4> : tensor<1xi64>
+  // CHECK-DAG-SAME: interior_padding = dense<0> : tensor<1xi64>
+  %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 0, end_mask = 0, new_axis_mask = 1} : (tensor<1xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<1x2xf32>) -> tensor<8xf32>
+
+  // CHECK: return [[PAD]] : tensor<8xf32>
+  return %0 : tensor<8xf32>
+}
+
+// CHECK-LABEL: strided_slice_grad_ellipsis_mask
+// CHECK-SAME: [[GRAD:%.*]]: tensor<2x4x8xf32>
+func @strided_slice_grad_ellipsis_mask(%grad: tensor<2x4x8xf32>) -> tensor<4x4x8xf32> {
+  // Input to StridedSlice was of shape 4x4x8xf32
+  // Strided slice gets input[2:4, ...]
+  // ellipsis_mask is 2, denoting that the slice contains all elements in
+  // dim #1 and dim #2, ignoring begin and end indices for these dimensions.
+  // So the output is 2x4x8xf32, which is the shape of the gradient.
+  // StridedSliceGrad would pad the gradient to match the shape of
+  // input 4x4x8xf32.
+
+  %shape = "tf.Const"() {value = dense<[4, 4, 8]> : tensor<3xi32>} : () -> (tensor<3xi32>)
+  %begin = "tf.Const"() {value = dense<[2, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>)
+  %end = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi32>} : () -> (tensor<2xi32>)
+  %strides = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> (tensor<2xi32>)
+
+  // CHECK: [[RESHAPE:%.*]] = "xla_hlo.reshape"([[GRAD]]) : (tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
+  // CHECK: [[ZEROS:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
+  // CHECK: [[PAD:%.*]] = "xla_hlo.pad"([[RESHAPE]], [[ZEROS]])
+  // CHECK-DAG-SAME: edge_padding_low = dense<[2, 0, 0]> : tensor<3xi64>
+  // CHECK-DAG-SAME: edge_padding_high = dense<0> : tensor<3xi64>
+  // CHECK-DAG-SAME: interior_padding = dense<0> : tensor<3xi64>
+  %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 0, end_mask = 0, ellipsis_mask = 2} : (tensor<3xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2x4x8xf32>) -> tensor<4x4x8xf32>
+
+  // CHECK: return [[PAD]] : tensor<4x4x8xf32>
+  return %0 : tensor<4x4x8xf32>
+}
+
+
+// CHECK-LABEL: strided_slice_grad_all_masks
+// CHECK-SAME: [[GRAD:%.*]]: tensor<1x4x8x8x10x2x1xf32>
+func @strided_slice_grad_all_masks(%grad: tensor<1x4x8x8x10x2x1xf32>) -> tensor<2x4x8x16x32x64xf32> {
+  // For StridedSlice input[1, tf.new_axis, ..., 8:, :10, 2:6:2, tf.new_axis]
+  // the new axis mask is at indices 1 and 6 of the sparse spec, so
+  // new_axis_mask = 2^1 + 2^6 = 66
+  // The ellipsis mask is applied to dim #1 and dim #2 of the input, i.e., we
+  // get the canonicalized slice input[1, :, :, 8:, :10, 2:6:2]
+  // The StridedSliceGrad op would propagate the gradient for the sliced tensor
+  // to the original input tensor by padding with zeroes.
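+  // For reference, the remaining mask bits used below map onto the same
+  // 7-entry sparse spec: shrink_axis_mask = 1 (bit 0, the scalar index 1),
+  // ellipsis_mask = 4 (bit 2, the "..."), end_mask = 8 (bit 3, "8:" has no
+  // end index) and begin_mask = 16 (bit 4, ":10" has no begin index).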
+ + %shape = "tf.Const"() {value = dense<[2, 4, 8, 16, 32, 64]> : tensor<6xi32>} : () -> (tensor<6xi32>) + %begin = "tf.Const"() {value = dense<[1, 0, 0, 8, 1, 2, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>) + %end = "tf.Const"() {value = dense<[2, 0, 0, 10, 10, 6, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>) + %strides = "tf.Const"() {value = dense<[1, 1, 1, 1, 1, 2, 1]> : tensor<7xi32>} : () -> (tensor<7xi32>) + + // Remove 2 new axes (at index 1 and 6) and 1 shrink axis (at index 0) + // CHECK: [[RESHAPE:%.*]] = "xla_hlo.reshape"([[GRAD]]) : (tensor<1x4x8x8x10x2x1xf32>) -> tensor<1x4x8x8x10x2xf32> + // CHECK: [[ZERO:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor + // The edge_padding_low, edge_padding_high and interior_padding attributes of + // xla_hlo.pad would reflect the padding required to get the shape of the + // input of StridedSlice op. + // CHECK: [[PAD:%.*]] = "xla_hlo.pad"([[RESHAPE]], [[ZERO]]) + // CHECK-DAG-SAME: edge_padding_low = dense<[1, 0, 0, 8, 0, 2]> : tensor<6xi64> + // CHECK-DAG-SAME: edge_padding_high = dense<[0, 0, 0, 0, 22, 59]> : tensor<6xi64> + // CHECK-DAG-SAME: interior_padding = dense<[0, 0, 0, 0, 0, 1]> : tensor<6xi64> + %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 16, end_mask = 8, shrink_axis_mask = 1, ellipsis_mask = 4, new_axis_mask = 66} : (tensor<6xi32>, tensor<7xi32>, tensor<7xi32>, tensor<7xi32>, tensor<1x4x8x8x10x2x1xf32>) -> tensor<2x4x8x16x32x64xf32> + + // CHECK: return [[PAD]] : tensor<2x4x8x16x32x64xf32> + return %0 : tensor<2x4x8x16x32x64xf32> +} + // CHECK-LABEL: @tensor_scatter_update func @tensor_scatter_update(%tensor: tensor, %indices: tensor, %updates: tensor) -> tensor { // CHECK: "xla_hlo.scatter"(%arg0, %arg1, %arg2) ( { @@ -3435,3 +3643,115 @@ func @xla_dynamic_update_slice(%arg0: tensor<4x16xf32>, %arg1: tensor<2x4xf32>, %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<4x16xf32>, tensor<2x4xf32>, tensor<2xi32>) -> tensor<4x16xf32> return %0 : tensor<4x16xf32> } + +// CHECK-LABEL: xla_dynamic_update_slice2 +func @xla_dynamic_update_slice2(%arg0: tensor<4xf32>, %arg1: tensor<2xf32>, %arg2: tensor<1xi32>) -> tensor<4xf32> { + // CHECK: [[SLICE0:%.+]] = "xla_hlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<1xi32>) -> tensor<1xi32> + // CHECK: [[RESHAPE0:%.+]] = "xla_hlo.reshape"([[SLICE0]]) : (tensor<1xi32>) -> tensor + // CHECK: [[DUS:%.+]] = "xla_hlo.dynamic-update-slice"(%arg0, %arg1, [[RESHAPE0]]) : (tensor<4xf32>, tensor<2xf32>, tensor) -> tensor<4xf32> + // CHECK: return [[DUS]] + %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<2xf32>, tensor<1xi32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +//===----------------------------------------------------------------------===// +// Cumsum op legalizations. 
+//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @cumsum_static +// CHECK-SAME: [[X:%.*]]: tensor<4xf32> +func @cumsum_static(%arg0: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: [[AXIS:%.*]] = xla_hlo.constant dense<0> : tensor + // CHECK: [[CONVERT_X:%.*]] = "xla_hlo.convert"([[X]]) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: [[INIT:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor + // CHECK: [[REDUCE:%.*]] = "xla_hlo.reduce_window"([[CONVERT_X]], [[INIT]]) ( { + // CHECK: ^bb0([[A:%.*]]: tensor, [[B:%.*]]: tensor): + // CHECK: [[SUM:%.*]] = xla_hlo.add [[A]], [[B]] : tensor + // CHECK: "xla_hlo.return"([[SUM]]) : (tensor) -> () + // CHECK: }) {padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK: [[CONVERT_REDUCE:%.*]] = "xla_hlo.convert"([[REDUCE]]) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: return [[CONVERT_REDUCE]] + %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor} : () -> tensor + %1 = "tf.Cumsum"(%arg0, %0) {exclusive = false, reverse = false} : (tensor<4xf32>, tensor) -> tensor<4xf32> + return %1 : tensor<4xf32> +} + +// CHECK-LABEL: func @cumsum_exclusive +func @cumsum_exclusive(%arg0: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: "tf.Cumsum" + %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor} : () -> tensor + %1 = "tf.Cumsum"(%arg0, %0) {exclusive = true, reverse = false} : (tensor<4xf32>, tensor) -> tensor<4xf32> + return %1 : tensor<4xf32> +} + +// CHECK-LABEL: func @cumsum_reverse +func @cumsum_reverse(%arg0: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: "tf.Cumsum" + %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor} : () -> tensor + %1 = "tf.Cumsum"(%arg0, %0) {exclusive = false, reverse = true} : (tensor<4xf32>, tensor) -> tensor<4xf32> + return %1 : tensor<4xf32> +} + +// CHECK-LABEL: func @cumsum_dynamic +func @cumsum_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "tf.Cumsum" + %0 = "tf.Cumsum"(%arg0, %arg1) : (tensor, tensor) -> tensor + return %0 : tensor +} + +//===----------------------------------------------------------------------===// +// tf.BatchMatMulV2 op legalizations. 
+//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @batchmatmulv2_broadcast_singleton_dimension +func @batchmatmulv2_broadcast_singleton_dimension(%arg0: tensor<1x4x2xf32>, %arg1: tensor<3x2x4xf32>) -> tensor<3x4x4xf32> { + // CHECK: [[BLHS:%.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x2xf32>) -> tensor<3x4x2xf32> + // CHECK: [[BRHS:%.+]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x4xf32>) -> tensor<3x2x4xf32> + // CHECK: [[BDST:%.+]] = "xla_hlo.dot_general"([[BLHS]], [[BRHS]]) {dot_dimension_numbers = { + // CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>, + // CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>, + // CHECK-SAME: rhs_batching_dimensions = dense<0> : tensor<1xi64>, + // CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64> + // CHECK-SAME: }} : (tensor<3x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32> + // CHECK: return [[BDST]] : tensor<3x4x4xf32> + %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<1x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32> + return %0 : tensor<3x4x4xf32> +} + +// CHECK-LABEL: func @batchmatmulv2_lhs_batch +func @batchmatmulv2_lhs_batch(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<3x4x4xf32> { + // CHECK: [[BLHS:%.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x4x2xf32>) -> tensor<3x4x2xf32> + // CHECK: [[BRHS:%.+]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<2x4xf32>) -> tensor<3x2x4xf32> + // CHECK: [[BDST:%.+]] = "xla_hlo.dot_general"([[BLHS]], [[BRHS]]) {dot_dimension_numbers = { + // CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>, + // CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>, + // CHECK-SAME: rhs_batching_dimensions = dense<0> : tensor<1xi64>, + // CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64> + // CHECK-SAME: }} : (tensor<3x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32> + // CHECK: return [[BDST]] : tensor<3x4x4xf32> + %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<3x4x2xf32>, tensor<2x4xf32>) -> tensor<3x4x4xf32> + return %0 : tensor<3x4x4xf32> +} + +// CHECK-LABEL: func @batchmatmulv2_rhs_batch +func @batchmatmulv2_rhs_batch(%arg0: tensor<4x2xf32>, %arg1: tensor<3x2x4xf32>) -> tensor<3x4x4xf32> { + // CHECK: [[BLHS:%.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<4x2xf32>) -> tensor<3x4x2xf32> + // CHECK: [[BRHS:%.+]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x4xf32>) -> tensor<3x2x4xf32> + // CHECK: [[BDST:%.+]] = "xla_hlo.dot_general"([[BLHS]], [[BRHS]]) {dot_dimension_numbers = { + // CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>, + // CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>, + // CHECK-SAME: rhs_batching_dimensions = dense<0> : tensor<1xi64>, + // CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64> + // CHECK-SAME: }} : (tensor<3x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32> + // CHECK: return [[BDST]] : tensor<3x4x4xf32> + %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : 
(tensor<4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32> + return %0 : tensor<3x4x4xf32> +} + +// CHECK-LABEL: func @batchmatmulv2_dynamic +func @batchmatmulv2_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "tf.BatchMatMulV2" + %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor, tensor) -> tensor + return %0 : tensor +} + diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-copy-removal.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-copy-removal.mlir new file mode 100644 index 00000000000..35546594ccb --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-copy-removal.mlir @@ -0,0 +1,93 @@ +// RUN: tf-opt -lhlo-copy-removal %s -o - | FileCheck %s --dump-input-on-failure + +// CHECK-LABEL: func @remove_simple +func @remove_simple(%arg0: memref<2x2xf32>) { + %0 = alloc() {temp = true} : memref<2x2xf32> + "xla_lhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + dealloc %0 : memref<2x2xf32> + // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () + "xla_lhlo.terminator"() : () -> () +} + +// ----- + +// CHECK-LABEL: func @remove_without_dealloc +func @remove_without_dealloc(%arg0: memref<2x2xf32>) { + %0 = alloc() {temp = true} : memref<2x2xf32> + "xla_lhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () + "xla_lhlo.terminator"() : () -> () +} + +// ----- + +// CHECK-LABEL: func @replace_dependency +func @replace_dependency(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) { + %0 = alloc() {temp = true} : memref<2x2xf32> + "xla_lhlo.exp"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "xla_lhlo.exp"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.copy"(%0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () + dealloc %0 : memref<2x2xf32> + // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () + "xla_lhlo.terminator"() : () -> () +} + +// ----- + +// CHECK-LABEL: func @keep_copies +func @keep_copies(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) { + // CHECK-NEXT: "xla_lhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () + "xla_lhlo.terminator"() : () -> () +} + +// ----- + +// CHECK-LABEL: func @must_not_be_removed +func @must_not_be_removed(%arg0: memref<2x2xf32>, + %arg1: memref<2x2xf32>, + %arg2: memref<2x2xf32>) { + // CHECK-NEXT: %[[ALLOC:.*]] = alloc() {temp = true} : memref<2x2xf32> + %0 = alloc() {temp = true} : memref<2x2xf32> + // CHECK-NEXT: "xla_lhlo.exp"(%arg0, %[[ALLOC]]) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.exp"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "xla_lhlo.exp"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.exp"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "xla_lhlo.copy"(%[[ALLOC]], %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + dealloc %0 : memref<2x2xf32> + "xla_lhlo.terminator"() : () -> () +} + +// ----- + +// CHECK-LABEL: func @must_be_removed_first +func @must_be_removed_first(%arg0: memref<2x2xf32>, + %arg1: memref<2x2xf32>, + %arg2: memref<2x2xf32>) { + %0 = alloc() {temp = true} : memref<2x2xf32> + // CHECK-NEXT: "xla_lhlo.exp"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.exp"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: 
"xla_lhlo.exp"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.exp"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + dealloc %0 : memref<2x2xf32> + "xla_lhlo.terminator"() : () -> () +} + +// ----- + +// CHECK-LABEL: func @must_be_removed_second +func @must_be_removed_second(%arg0: memref<2x2xf32>, + %arg1: memref<2x2xf32>, + %arg2: memref<2x2xf32>) { + %0 = alloc() {temp = true} : memref<2x2xf32> + // CHECK-NEXT: "xla_lhlo.exp"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.exp"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "xla_lhlo.exp"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.exp"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + dealloc %0 : memref<2x2xf32> + "xla_lhlo.terminator"() : () -> () +} diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir index 7f7e37ebe66..0a48cbd372f 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir @@ -125,3 +125,55 @@ func @fusion_of_three(%arg0: memref<100x10xf32>, // PLOOP: subf // PLOOP: linalg.generic // PLOOP: exp + +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +#pointwise_4d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} +func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>, + %summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) { + %temp_result = alloc() {temp = true} : memref<6x6x6x6xf32> + linalg.generic #pointwise_4d_trait %summand_1, %summand_2, %temp_result { + ^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32): + %out = addf %summand_1_in, %summand_2_in : f32 + linalg.yield %out : f32 + } : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>, memref<6x6x6x6xf32> + linalg.generic #pointwise_4d_trait %temp_result, %multiplier, %result { + ^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32): + %out = mulf %temp_result_in, %multiplier_in : f32 + linalg.yield %out : f32 + } : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>, memref<6x6x6x6xf32> + dealloc %temp_result : memref<6x6x6x6xf32> + "xla_lhlo.terminator"() : () -> () +} +// CHECK-LABEL: func @fusion_4d +// CHECK: %[[C1:.*]] = constant 1 +// CHECK-NOT: linalg.generic +// CHECK: loop.for {{.*}} step %[[C1]] +// CHECK: loop.for {{.*}} step %[[C1]] +// CHECK: loop.for {{.*}} step %[[C1]] +// CHECK: loop.for {{.*}} step %[[C1]] +// CHECK-NOT: loop.for +// CHECK: linalg.generic +// CHECK: addf +// CHECK: linalg.generic +// CHECK: mulf + +// TILED-LABEL: func @fusion_4d +// TILED-DAG: %[[C2:.*]] = constant 2 +// TILED-DAG: %[[C3:.*]] = constant 3 +// TILED-NOT: linalg.generic +// TILED: loop.for {{.*}} step %[[C2]] +// TILED: loop.for {{.*}} step %[[C3]] +// TILED-NOT: loop.for +// TILED: linalg.generic +// TILED: addf +// TILED: linalg.generic +// TILED: mulf + +// PLOOP-LABEL: func @fusion_4d +// PLOOP-NOT: linalg.generic +// PLOOP: loop.parallel +// PLOOP-NOT: loop.parallel +// PLOOP: linalg.generic +// PLOOP: addf +// PLOOP: linalg.generic +// PLOOP: mulf diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir index d43ca3b6bb2..5d0c767a716 
100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir @@ -102,6 +102,20 @@ func @exp(%input: memref<2x2xf32>, // ----- +// CHECK-LABEL: func @log +func @log(%input: memref<2x2xf32>, + %result: memref<2x2xf32>) { + "xla_lhlo.log"(%input, %result) + : (memref<2x2xf32>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]): +// CHECK-NEXT: %[[RESULT:.*]] = log %[[OPERAND_IN]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + // CHECK-LABEL: func @copy func @copy(%input: memref<2x4x8xf32>, %result: memref<2x4x8xf32>) { @@ -210,6 +224,7 @@ func @broadcast(%operand: memref<5x7x1xf32>, %result: memref<7x10x6x4x5xf32>) { // ----- +// CHECK-DAG: #[[RESULT_MAP_0:.*]] = affine_map<(d0, d1, d2) -> ()> // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-LABEL: func @broadcast_scalar func @broadcast_scalar(%operand: memref, %result: memref<7x10x6xf32>) { @@ -218,9 +233,8 @@ func @broadcast_scalar(%operand: memref, %result: memref<7x10x6xf32>) { : (memref, memref<7x10x6xf32>) -> () return } -// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[RESULT_MAP]]] -// CHECK-NEXT: ^bb0(%[[RESULT:.*]]: f32): -// CHECK-NEXT: %[[CONST:.*]] = load %{{.*}} : memref +// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[RESULT_MAP_0]], #[[RESULT_MAP]]] +// CHECK-NEXT: ^bb0(%[[CONST:.*]]: f32, %[[RESULT:.*]]: f32): // CHECK-NEXT: linalg.yield %[[CONST]] : f32 // ----- @@ -401,6 +415,20 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, // ----- +// CHECK-LABEL: func @rsqrt +func @rsqrt(%input: memref<2x2xf32>, + %result: memref<2x2xf32>) { + "xla_lhlo.rsqrt"(%input, %result) + : (memref<2x2xf32>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]): +// CHECK-NEXT: %[[RESULT:.*]] = rsqrt %[[OPERAND_IN]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + // CHECK-LABEL: func @sign func @sign(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { @@ -416,6 +444,20 @@ func @sign(%input: memref<2x2xf32>, // ----- +// CHECK-LABEL: func @sqrt +func @sqrt(%input: memref<2x2xf32>, + %result: memref<2x2xf32>) { + "xla_lhlo.sqrt"(%input, %result) + : (memref<2x2xf32>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]): +// CHECK-NEXT: %[[RESULT:.*]] = sqrt %[[OPERAND_IN]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + // CHECK-LABEL: func @tanh func @tanh(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir new file mode 100644 index 00000000000..3317d24d820 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir @@ -0,0 +1,127 @@ +// RUN: tf-opt %s -lhlo-legalize-to-parallel-loops -canonicalize -split-input-file | FileCheck %s --dump-input-on-failure + +func @reduce(%arg: memref<100x10x5xf32>, + %init: memref, + %result: memref<100x5xf32>) { + "xla_lhlo.reduce"(%arg, %init, %result) ( { + ^bb0(%lhs: memref, %rhs: memref, %res: memref): + "xla_lhlo.add"(%lhs, %rhs, %res) + : (memref, memref, memref) -> () + "xla_lhlo.terminator"() : () -> () + } ) {dimensions = dense<[1]> : tensor<1xi64>} + : (memref<100x10x5xf32>, memref, 
memref<100x5xf32>) -> () + return +} +// CHECK-LABEL: func @reduce( +// CHECK-SAME: [[ARG_BUF:%.*]]: memref<100x10x5xf32>, +// CHECK-SAME: [[INIT_BUF:%.*]]: memref, +// CHECK-SAME: [[RESULT_BUF:%.*]]: memref<100x5xf32>) { +// CHECK-DAG: [[C0:%.*]] = constant 0 : index +// CHECK-DAG: [[C1:%.*]] = constant 1 : index +// CHECK-DAG: [[C5:%.*]] = constant 5 : index +// CHECK-DAG: [[C10:%.*]] = constant 10 : index +// CHECK-DAG: [[C100:%.*]] = constant 100 : index +// CHECK: [[INIT:%.*]] = load [[INIT_BUF]] +// CHECK: loop.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]]) +// CHECK-SAME: to ([[C100]], [[C5]]) step ([[C1]], [[C1]]) { +// CHECK: [[REDUCTION_RESULT:%.*]] = loop.parallel ([[J:%.*]]) = +// CHECK-SAME: ([[C0]]) to ([[C10]]) step ([[C1]]) init ([[INIT]]) -> f32 { +// CHECK: [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]] +// CHECK-SAME: {{\[}}[[I]], [[J]], [[K]]] : memref<100x10x5xf32> +// CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 { +// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): +// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref +// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref +// CHECK: [[ACC_BUF:%.*]] = alloc() : memref +// CHECK: store [[ACC]], [[ACC_BUF]][] : memref +// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_BUF]]) +// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_BUF]][] : memref +// CHECK: loop.reduce.return [[ACC_RESULT]] : f32 +// CHECK: } +// CHECK: loop.yield +// CHECK: } +// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]] +// CHECK: loop.yield + +// ----- + +func @reduce_no_outer_loop(%arg: memref<100xf32>, + %init: memref, + %result: memref<1xf32>) { + "xla_lhlo.reduce"(%arg, %init, %result) ( { + ^bb0(%lhs: memref, %rhs: memref, %res: memref): + "xla_lhlo.add"(%lhs, %rhs, %res) + : (memref, memref, memref) -> () + "xla_lhlo.terminator"() : () -> () + } ) {dimensions = dense<[0]> : tensor<1xi64>} + : (memref<100xf32>, memref, memref<1xf32>) -> () + return +} +// CHECK-LABEL: func @reduce_no_outer_loop( +// CHECK-SAME: [[ARG_BUF:%.*]]: memref<100xf32>, +// CHECK-SAME: [[ELEM_TO_REDUCE_BUF:%.*]]: memref, +// CHECK-SAME: [[RESULT_BUF:%.*]]: memref<1xf32>) { +// CHECK-DAG: [[C0:%.*]] = constant 0 : index +// CHECK-DAG: [[C1:%.*]] = constant 1 : index +// CHECK-DAG: [[C100:%.*]] = constant 100 : index +// CHECK: [[INIT:%.*]] = load [[INIT_BUF]] +// CHECK: [[REDUCTION_RESULT:%.*]] = loop.parallel ([[I:%.*]]) = ([[C0]]) +// CHECK-SAME: to ([[C100]]) step ([[C1]]) init ([[INIT]]) -> f32 { +// CHECK: [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]]{{\[}}[[I]]{{\]}} +// CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 { +// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): +// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref +// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref +// CHECK: [[ACC_BUF:%.*]] = alloc() : memref +// CHECK: store [[ACC]], [[ACC_BUF]][] : memref +// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_BUF]]) +// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_BUF]][] : memref +// CHECK: loop.reduce.return [[ACC_RESULT]] +// CHECK: } +// CHECK: loop.yield +// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[C0]]] + +// ----- + +func @dynamic_reduce(%arg: memref, + %init: memref, + %result: memref) { + "xla_lhlo.reduce"(%arg, %init, %result) ( { + ^bb0(%lhs: memref, %rhs: memref, %res: memref): + "xla_lhlo.add"(%lhs, %rhs, %res) + : (memref, memref, memref) -> () + "xla_lhlo.terminator"() : () -> () + } ) {dimensions = dense<[1]> : tensor<1xi64>} + : (memref, memref, memref) -> () + return +} +// CHECK-LABEL: func @dynamic_reduce( +// CHECK-SAME: 
[[ARG_BUF:%.*]]: memref, +// CHECK-SAME: [[INIT_BUF:%.*]]: memref, +// CHECK-SAME: [[RESULT_BUF:%.*]]: memref) { +// CHECK-DAG: [[C0:%.*]] = constant 0 : index +// CHECK-DAG: [[C1:%.*]] = constant 1 : index +// CHECK: [[DIM0:%.*]] = dim [[ARG_BUF]], 0 : memref +// CHECK: [[DIM1:%.*]] = dim [[ARG_BUF]], 1 : memref +// CHECK: [[DIM2:%.*]] = dim [[ARG_BUF]], 2 : memref +// CHECK: [[INIT:%.*]] = load [[INIT_BUF]] +// CHECK: loop.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]]) +// CHECK-SAME: to ([[DIM0]], [[DIM2]]) step ([[C1]], [[C1]]) { +// CHECK: [[REDUCTION_RESULT:%.*]] = loop.parallel ([[J:%.*]]) = +// CHECK-SAME: ([[C0]]) to ([[DIM1]]) step ([[C1]]) init ([[INIT]]) -> f32 { +// CHECK: [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]] +// CHECK-SAME: {{\[}}[[I]], [[J]], [[K]]] : memref +// CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 { +// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): +// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref +// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref +// CHECK: [[ACC_BUF:%.*]] = alloc() : memref +// CHECK: store [[ACC]], [[ACC_BUF]][] : memref +// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_BUF]]) +// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_BUF]][] : memref +// CHECK: loop.reduce.return [[ACC_RESULT]] : f32 +// CHECK: } +// CHECK: loop.yield +// CHECK: } +// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]] +// CHECK: loop.yield diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir index 9f181d574c0..2953fc84d71 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir @@ -146,7 +146,7 @@ func @broadcast_in_dim_memref(%arg0: memref<1x2xi32>, %out: memref<1x2x2xi32>) - // CHECK-LABEL: func @broadcast_in_dim_zero_rank_memref func @broadcast_in_dim_zero_rank_memref(%arg0: memref, %out: memref<1x2x3xi32>) -> () { - "xla_lhlo.broadcast_in_dim"(%arg0, %out) : (memref, memref<1x2x3xi32>) -> () + "xla_lhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (memref, memref<1x2x3xi32>) -> () return } diff --git a/tensorflow/compiler/mlir/xla/tests/ops.mlir b/tensorflow/compiler/mlir/xla/tests/ops.mlir index ce70c7896c1..037eded9ba6 100644 --- a/tensorflow/compiler/mlir/xla/tests/ops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/ops.mlir @@ -102,7 +102,7 @@ func @broadcast_in_dim(%arg0: tensor<1x2xi32>) -> tensor<1x2x2xi32> { // CHECK-LABEL: func @broadcast_in_dim_zero_rank func @broadcast_in_dim_zero_rank(%arg0: tensor) -> tensor<1x2x3xi32> { - %0 = "xla_hlo.broadcast_in_dim"(%arg0) : (tensor) -> tensor<1x2x3xi32> + %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor) -> tensor<1x2x3xi32> return %0 : tensor<1x2x3xi32> } @@ -446,7 +446,7 @@ func @recv_non_token_second_result(%token: !xla_hlo.token) -> tuple>, %sigma: tensor) -> tensor<2x3x5xf32> { %shape = xla_hlo.constant dense<[2, 3, 5]> : tensor<3xi64> - // expected-error@+1 {{must be tensor of pred (AKA boolean or 1-bit integer) or 8/16/32/64-bit integer or floating-point values, but got 'tensor>'}} + // expected-error@+1 {{must be tensor of pred (AKA boolean or 1-bit integer) or 8/16/32/64-bit signless integer or floating-point values, but got 'tensor>'}} %0 = "xla_hlo.rng_uniform"(%mu, %sigma, %shape) : (tensor>, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> return %0 : tensor<2x3x5xf32> } @@ -477,6 +477,31 @@ func @select_scalar_pred(%arg0: tensor, %arg1: tensor<2x3xi32>, %arg2: tenso // 
----- +// CHECK-LABEL: func @select_cast_compatible_types +func @select_cast_compatible_types(%arg0: tensor, %arg1: tensor<*xi32>, %arg2: tensor<2x3xi32>) -> tensor<*xi32> { + %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor<*xi32>, tensor<2x3xi32>) -> tensor<*xi32> + return %0 : tensor<*xi32> +} + +// ----- + +func @select_cast_compatible_types(%arg0: tensor, %arg1: tensor<2x?xi32>, %arg2: tensor) -> tensor { + // TODO(lucyfox): Update once this is supported. + // expected-error@+1 {{currently unsupported operand types: 'tensor<2x?xi32>' and 'tensor'}} + %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor<2x?xi32>, tensor) -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @select_scalar_x_y +func @select_scalar_x_y(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// ----- + func @select_bad_pred_type(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { // expected-error@+1 {{must be tensor of pred (AKA boolean or 1-bit integer) values}} %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi32>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> @@ -485,18 +510,16 @@ func @select_bad_pred_type(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>, %arg2: // ----- -// TODO(jpienaar): Re-enable post updating select function verify. func @select_bad_shape_mismatch(%arg0: tensor<3xi1>, %arg1: tensor<2x4xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { - // should-be-error@+1 {{on_true type (tensor<2x4xi32>) does not match on_false type (tensor<2x3xi32>)}} + // expected-error@+1 {{incompatible operand types: 'tensor<2x4xi32>' and 'tensor<2x3xi32>'}} %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x4xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> return %0 : tensor<2x3xi32> } // ----- -// TODO(jpienaar): Re-enable post updating select function verify. 
func @select_bad_element_type_mismatch(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { - // should-be-error@+1 {{on_true type (tensor<2x3xf32>) does not match on_false type (tensor<2x3xi32>)}} + // expected-error@+1 {{incompatible operand types: 'tensor<2x3xf32>' and 'tensor<2x3xi32>'}} %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf32>, tensor<2x3xi32>) -> tensor<2x3xi32> return %0 : tensor<2x3xi32> } @@ -731,7 +754,7 @@ func @or_i1_type(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { // ----- func @or_invalid_f32_type(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // expected-error@+1 {{must be tensor of pred (AKA boolean or 1-bit integer) or 8/16/32/64-bit integer values, but got 'tensor<4xf32>'}} + // expected-error@+1 {{must be tensor of pred (AKA boolean or 1-bit integer) or 8/16/32/64-bit signless integer values, but got 'tensor<4xf32>'}} %0 = "xla_hlo.or"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir index aac4e613358..8af27bb586a 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir @@ -160,6 +160,16 @@ func @main(%arg0: tensor<1xf32>) -> tensor<1x10xf32> { // ----- +// CHECK: HloModule +func @main() -> !xla_hlo.token { + %0 = "xla_hlo.create_token"() : () -> !xla_hlo.token + return %0 : !xla_hlo.token +} + +// CHECK: ROOT [[TOKEN:%.*]] = token[] after-all() + +// ----- + // CHECK: HloModule func @main(%arg0: tensor<4xi32>) -> tensor<4xi32> { %0 = call @callee(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> diff --git a/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt index 38e58be8e64..01a24c06d2c 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt @@ -27,7 +27,7 @@ ENTRY %tfcompile.48 { // CHECK-NEXT: %cst = constant {name = "constant.8"} dense<1.000000e+00> : tensor %constant.8 = f32[] constant(1) - // CHECK-NEXT: %5 = "xla_hlo.broadcast_in_dim"(%cst) {name = "broadcast.9"} : (tensor) -> tensor<300x1x5xf32> + // CHECK-NEXT: %5 = "xla_hlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<[]> : tensor<0xi64>, name = "broadcast.9"} : (tensor) -> tensor<300x1x5xf32> %broadcast.9 = f32[300,1,5] broadcast(%constant.8), dimensions={} // CHECK-NEXT: %6 = xla_hlo.mul %4, %5 {name = "multiply.31"} : tensor<300x1x5xf32> @@ -36,7 +36,7 @@ ENTRY %tfcompile.48 { // CHECK-NEXT: %cst_0 = constant {name = "constant.32"} dense<0.000000e+00> : tensor %constant.32 = f32[] constant(0) - // CHECK-NEXT: %7 = "xla_hlo.broadcast_in_dim"(%cst_0) {name = "broadcast.33"} : (tensor) -> tensor<300x1x5xf32> + // CHECK-NEXT: %7 = "xla_hlo.broadcast_in_dim"(%cst_0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, name = "broadcast.33"} : (tensor) -> tensor<300x1x5xf32> %broadcast.33 = f32[300,1,5] broadcast(%constant.32), dimensions={} // CHECK-NEXT: %8 = "xla_hlo.compare"(%6, %7) {comparison_direction = "GT", name = "compare.34"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xi1> @@ -45,13 +45,13 @@ ENTRY %tfcompile.48 { // CHECK-NEXT: %cst_1 = constant {name = "constant.10"} 
dense<0.000000e+00> : tensor %constant.10 = f32[] constant(0) - // CHECK-NEXT: %9 = "xla_hlo.broadcast_in_dim"(%cst_1) {name = "broadcast.11"} : (tensor) -> tensor<300x1x5xf32> + // CHECK-NEXT: %9 = "xla_hlo.broadcast_in_dim"(%cst_1) {broadcast_dimensions = dense<[]> : tensor<0xi64>, name = "broadcast.11"} : (tensor) -> tensor<300x1x5xf32> %broadcast.11 = f32[300,1,5] broadcast(%constant.10), dimensions={} // CHECK-NEXT: %cst_2 = constant {name = "constant.40"} dense<0.000000e+00> : tensor %constant.40 = f32[] constant(0) - // CHECK-NEXT: %10 = "xla_hlo.broadcast_in_dim"(%cst_2) {name = "broadcast.41"} : (tensor) -> tensor<300x5xf32> + // CHECK-NEXT: %10 = "xla_hlo.broadcast_in_dim"(%cst_2) {broadcast_dimensions = dense<[]> : tensor<0xi64>, name = "broadcast.41"} : (tensor) -> tensor<300x5xf32> %broadcast.41 = f32[300,5] broadcast(%constant.40), dimensions={} // CHECK-NEXT: %11 = "xla_hlo.copy"(%arg1) {name = "copy.1"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32> diff --git a/tensorflow/compiler/mlir/xla/tests/unfuse_batch_norm.mlir b/tensorflow/compiler/mlir/xla/tests/unfuse_batch_norm.mlir index 1270e339d98..b5e1eaf104a 100644 --- a/tensorflow/compiler/mlir/xla/tests/unfuse_batch_norm.mlir +++ b/tensorflow/compiler/mlir/xla/tests/unfuse_batch_norm.mlir @@ -11,7 +11,7 @@ func @batchNormInference_2D_inner_features( %mean: tensor<256xf32>, %variance: tensor<256xf32>) -> (tensor<4x256xf32>) { // CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.001000e-05> : tensor - // CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[EPS]]) : (tensor) -> tensor<256xf32> + // CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[EPS]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor) -> tensor<256xf32> // CHECK-DAG: %[[VARIANCE_EPS:.+]] = xla_hlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32> // CHECK-DAG: %[[STDDEV:.+]] = "xla_hlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<256xf32>) -> tensor<256xf32> // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[STDDEV]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> @@ -92,3 +92,46 @@ func @batchNormInference_f16_overflow( tensor<256xf16>) -> tensor<4x256xf16> return %0 : tensor<4x256xf16> } + +// ----- +// CHECK-LABEL: @batchNormInference_dynamic_shape +// Validate that dynamic shapes are handled properly. 
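// For reference when reading the CHECK lines that follow (a summary, not part
// of the generated test): the expanded form computes, element-wise,
//   result = (x - mean) * scale / sqrt(variance + epsilon) + offset
// with each 1-D operand broadcast along feature_index = 1 and
// epsilon = 1.000000e-03 coming from the op's attribute.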
+// CHECK-SAME: %[[X:[^:[:space:]]+]] +// CHECK-SAME: %[[SCALE:[^:[:space:]]+]] +// CHECK-SAME: %[[OFFSET:[^:[:space:]]+]] +// CHECK-SAME: %[[MEAN:[^:[:space:]]+]] +// CHECK-SAME: %[[VARIANCE:[^:[:space:]]+]] +func @batchNormInference_dynamic_shape( + %x: tensor, %scale: tensor, %offset: tensor, + %mean: tensor, %variance: tensor) + -> tensor { + // CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e-03> : tensor + // CHECK-DAG: %[[DIM:.+]] = dim %[[VARIANCE]], 0 : tensor + // CHECK-DAG: %[[INDEX_CAST:.+]] = index_cast %[[DIM]] : index to i32 + // CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = "xla_hlo.scalars_to_dimension_tensor"(%[[INDEX_CAST]]) : (i32) -> tensor<1xi32> + // CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xi32>) -> tensor + // CHECK-DAG: %[[VARIANCE_EPS:.+]] = xla_hlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor + // CHECK-DAG: %[[STDDEV:.+]] = "xla_hlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor) -> tensor + // CHECK-DAG: %[[INPUT_DIM_0:.+]] = dim %[[X]], 0 : tensor + // CHECK-DAG: %[[INPUT_INDEX_CAST_0:.+]] = index_cast %[[INPUT_DIM_0]] : index to i32 + // CHECK-DAG: %[[INPUT_DIM_1:.+]] = dim %[[X]], 1 : tensor + // CHECK-DAG: %[[INPUT_INDEX_CAST_1:.+]] = index_cast %[[INPUT_DIM_1]] : index to i32 + // CHECK-DAG: %[[INPUT_DIM_2:.+]] = dim %[[X]], 2 : tensor + // CHECK-DAG: %[[INPUT_INDEX_CAST_2:.+]] = index_cast %[[INPUT_DIM_2]] : index to i32 + // CHECK-DAG: %[[INPUT_DIM_3:.+]] = dim %[[X]], 3 : tensor + // CHECK-DAG: %[[INPUT_INDEX_CAST_3:.+]] = index_cast %[[INPUT_DIM_3]] : index to i32 + // CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = "xla_hlo.scalars_to_dimension_tensor"(%[[INPUT_INDEX_CAST_0]], %[[INPUT_INDEX_CAST_1]], %[[INPUT_INDEX_CAST_2]], %[[INPUT_INDEX_CAST_3]]) : (i32, i32, i32, i32) -> tensor<4xi32> + // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xi32>) -> tensor + // CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xi32>) -> tensor + // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xi32>) -> tensor + // CHECK-DAG: %[[MEAN_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xi32>) -> tensor + // CHECK-DAG: %[[X_CENTER:.+]] = xla_hlo.sub %[[X]], %[[MEAN_BCAST]] : tensor + // CHECK-DAG: %[[X_SCALED:.+]] = xla_hlo.mul %[[X_CENTER]], %[[SCALE_BCAST]] : tensor + // CHECK-DAG: %[[X_NORMED:.+]] = xla_hlo.div %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor + // CHECK-DAG: %[[RESULT:.+]] = xla_hlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor + %0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) + {epsilon = 0.001 : f32, feature_index = 1 : i64} : + (tensor, tensor, tensor, tensor, + tensor) -> tensor + return %0 : tensor +} diff --git a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc index 29d399c68fa..cc6ca472c23 100644 --- a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc @@ -30,7 +30,7 @@ limitations 
under the License. #include "mlir/Transforms/DialectConversion.h" // TF:llvm-project #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" -#include "tensorflow/compiler/mlir/xla/transforms/hlo_shape_derivation.h" +#include "tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" #include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" @@ -117,7 +117,7 @@ Value InsertAllocAndDealloc(Location loc, Value result, return alloc; } -template +template class HloToLhloOpConverter : public ConversionPattern { public: explicit HloToLhloOpConverter(MLIRContext* context) @@ -138,23 +138,24 @@ class HloToLhloOpConverter : public ConversionPattern { buffer_args.push_back( InsertAllocAndDealloc(op->getLoc(), result.value(), &rewriter)); } else { - Value shape_value = ShapeDerivation::impl::deriveShapeFromOp( - op, result.index(), &rewriter); - if (!shape_value) { + SmallVector results_shape; + auto shape_type_op = dyn_cast(op); + if (!shape_type_op) return matchFailure(); + if (failed( + shape_type_op.reifyReturnTypeShapes(rewriter, results_shape))) return matchFailure(); - } buffer_args.push_back(InsertDynamicAllocAndDealloc( - op->getLoc(), result.value(), shape_value, &rewriter)); + op->getLoc(), result.value(), results_shape.front(), &rewriter)); } } - rewriter.create(op->getLoc(), llvm::None, buffer_args, - op->getAttrs()); + rewriter.create>(op->getLoc(), llvm::None, + buffer_args, op->getAttrs()); rewriter.replaceOp(op, ArrayRef(buffer_args).slice(operands.size())); return matchSuccess(); } }; -struct HloToLHloDynamicBroadcastInDimOpConverter +struct HloToLhloDynamicBroadcastInDimOpConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -178,7 +179,7 @@ struct HloToLHloDynamicBroadcastInDimOpConverter } }; -struct HloToLHloReduceOpConverter +struct HloToLhloReduceOpConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -272,14 +273,14 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern { // "xla_lhlo.fusion"() ({ // %0 = tensor_load %arg1 : memref<2x2xf32> // %1 = tensor_load %arg2 : memref<2x2xf32> -// %2 = "xla_hlo.add"(%0, %1) {name = "add"} : +// %2 = "xla_hlo.add"(%0, %1) : // (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // %3 = tensor_load %arg0 : memref<2x2xf32> -// %4 = "xla_hlo.mul"(%2, %3) {name = "multiply"} : +// %4 = "xla_hlo.mul"(%2, %3) : // (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // tensor_store %4, %arg3 : memref<2x2xf32> // "xla_lhlo.terminator"() : () -> () -// }) {name = "fusion"} : () -> () +// }) : () -> () // return // } // @@ -289,14 +290,14 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern { // %arg2: memref<2x2xf32>, // %arg3: memref<2x2xf32>) { // "xla_lhlo.fusion"() ( { -// %0 = alloc() {temp = true} : memref<2x2xf32> +// %0 = alloc() : memref<2x2xf32> // "xla_lhlo.add"(%arg1, %arg2, %0) : // (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () // "xla_lhlo.mul"(%0, %arg0, %arg3) : // (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () // dealloc %0 : memref<2x2xf32> // "xla_lhlo.terminator"() : () -> () -// }) {name = "fusion"} : () -> () +// }) : () -> () // return // } // } @@ -304,9 +305,9 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern { // FuncOp signature conversion example: // // func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { -// %0 = 
xla_hlo.max %arg0, %arg1 {name = "maximum.47"} : tensor<4xf32> -// %1 = xla_hlo.add %arg0, %0 {name = "maximum.47"} : tensor<4xf32> -// return %1 : tensor<4xf32> +// %0 = "xla_hlo.max"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> +// tensor<4xf32> %1 = "xla_hlo.add"(%arg0, %0) : (tensor<4xf32>, +// tensor<4xf32>) -> tensor<4xf32> return %1 : tensor<4xf32> // } // // Transformed function with an extra argument for the result. The types have @@ -315,11 +316,14 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern { // func @func_op(%arg0: memref<4xf32>, // %arg1: memref<4xf32>, // %arg2: memref<4xf32>) { -// %0 = alloc() {temp = true} : memref<4xf32> -// "xla_lhlo.max"(%arg0, %arg1, %0) {name = "maximum.47"} : +// %0 = alloc() : memref<4xf32> +// %1 = alloc() : memref<4xf32> +// "xla_lhlo.max"(%arg0, %arg1, %0) : // (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> () -// "xla_lhlo.add"(%arg0, %0, %arg2) {name = "maximum.47"} : +// "xla_lhlo.add"(%arg0, %0, %1) : // (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> () +// "xla_lhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> () +// dealloc %0 : memref<4xf32> // dealloc %1 : memref<4xf32> // "xla_lhlo.terminator"() : () -> () // } @@ -438,90 +442,47 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context, OwningRewritePatternList* patterns) { // clang-format off patterns->insert< - HloToLHloDynamicBroadcastInDimOpConverter, + HloToLhloDynamicBroadcastInDimOpConverter, HloToLhloFuncOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLHloReduceOpConverter, - StdToLhloReturnOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloReduceOpConverter, HloToLhloTensorLoadOpConverter, - HloToLhloTensorStoreOpConverter + HloToLhloTensorStoreOpConverter, + StdToLhloReturnOpConverter >(context); // clang-format on } -/// Removes Lhlo.CopyOp that copies from an allocated buffer to the block -/// argument. All uses of the buffer are replaced with the block argument. 
-struct RedundantCopiesRemoval : mlir::FunctionPass { - void runOnFunction() override { - llvm::SmallVector eraseList; - getFunction().walk([&](mlir::xla_lhlo::CopyOp copyOp) { - auto arguments = copyOp.getOperation()->getBlock()->getArguments(); - if (std::any_of(arguments.begin(), arguments.end(), - [&](mlir::BlockArgument arg) { - return copyOp.output() == arg; - }) && - std::none_of(arguments.begin(), arguments.end(), - [&](mlir::BlockArgument arg) { - return copyOp.operand() == arg; - })) { - mlir::Value operand = copyOp.operand(); - mlir::Value output = copyOp.output(); - copyOp.erase(); - for (auto op : operand.getUsers()) { - if (!mlir::isa(op)) { - op->replaceUsesOfWith(operand, output); - } - } - auto allocOp = operand.getDefiningOp(); - if (auto deallocOp = - mlir::dyn_cast(*allocOp->getUsers().begin())) { - eraseList.push_back(deallocOp); - eraseList.push_back(allocOp); - } - } - }); - for (auto op : eraseList) { - op->erase(); - } - }; -}; - std::unique_ptr> createLegalizeToLhloPass() { return absl::make_unique(); } -std::unique_ptr> createLhloCopyRemovalPass() { - return absl::make_unique(); -} - static PassRegistration legalize_pass( "hlo-legalize-to-lhlo", "Legalize from HLO dialect to LHLO dialect"); -static PassRegistration copies_removal_pass( - "lhlo-redundant-copies-removal", - "Legalize from HLO dialect to LHLO dialect"); - } // namespace xla_hlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/hlo_shape_derivation.h b/tensorflow/compiler/mlir/xla/transforms/hlo_shape_derivation.h deleted file mode 100644 index d2a1f47e540..00000000000 --- a/tensorflow/compiler/mlir/xla/transforms/hlo_shape_derivation.h +++ /dev/null @@ -1,130 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_HLO_SHAPE_DERIVATION_H_ -#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_HLO_SHAPE_DERIVATION_H_ - -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project -#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" - -namespace mlir { -namespace xla_hlo { - -// This file contains implementations for shape derivation functions that, -// given some operation and a result number, produce IR that computes the -// shape of the given result at runtime based on operands of the provided -// operation. -// These should be generated at some point based on annotations on the HLO -// using the new shape dialect. While this is still in the works, we hardcode -// the expected IR here to unblock progress. 
-// The implementation is based on templates to allow for using these derivation -// functions in templated code. - -namespace impl { - -struct UnknownShape { - // Default shape derivation function that simply fails with a runtime error. - static Value deriveShapeFromOp(Operation* op, int operand_position, - ConversionPatternRewriter* rewriter) { - op->emitOpError() - << "dynamic result shapes cannot be derived for this operation"; - return {}; - } -}; - -struct SameShapeAsFirstOperand { - // Shape derivation function that computes the shape of the result based on - // the first argument. For a 2-dimensional input tensor, this produces IR of - // the form - // - // %0 = dim %arg0, 0 : memref - // %1 = index_cast %0 : index to i64 - // %2 = dim %arg0, 1 : memref - // %3 = index_cast %2 : index to i64 - // %4 = "xla_hlo.scalars_to_dimension_tensor"(%1, %3) - // : (i64, i64) -> tensor<2xi64> - // - // and returns %4 as the shape value. - static Value deriveShapeFromOp(Operation* op, int result_postion, - ConversionPatternRewriter* rewriter) { - Value operand = op->getOperand(0); - ShapedType operand_type = operand.getType().dyn_cast(); - if (!operand_type) { - op->emitOpError() << "first operand has no shaped type"; - return {}; - } - auto loc = op->getLoc(); - SmallVector shape_values; - shape_values.reserve(operand_type.getRank()); - auto shape_scalar_type = rewriter->getIntegerType(64); - for (auto element : llvm::enumerate(operand_type.getShape())) { - if (element.value() == ShapedType::kDynamicSize) { - Value dim = rewriter->create(loc, operand, element.index()); - shape_values.push_back( - rewriter->create(loc, dim, shape_scalar_type)); - } else { - shape_values.push_back(rewriter->create( - loc, rewriter->getI64IntegerAttr(element.value()))); - } - } - return rewriter->create( - loc, RankedTensorType::get({operand_type.getRank()}, shape_scalar_type), - shape_values); - } -}; - -} // namespace impl - -// Default template to cover HLO operations whose shape derivation is unknown. -template -struct ShapeDerivation { - using impl = impl::UnknownShape; -}; - -// Element-wise operations that have the shape of their first operand. - -#define SAME_SHAPE_AS_FIRST_OPERAND(Op) \ - template <> \ - struct ShapeDerivation { \ - using impl = impl::SameShapeAsFirstOperand; \ - }; - -SAME_SHAPE_AS_FIRST_OPERAND(AbsOp) -SAME_SHAPE_AS_FIRST_OPERAND(AddOp) -SAME_SHAPE_AS_FIRST_OPERAND(AndOp) -SAME_SHAPE_AS_FIRST_OPERAND(CeilOp) -SAME_SHAPE_AS_FIRST_OPERAND(CosOp) -SAME_SHAPE_AS_FIRST_OPERAND(DivOp) -SAME_SHAPE_AS_FIRST_OPERAND(ExpOp) -SAME_SHAPE_AS_FIRST_OPERAND(MaxOp) -SAME_SHAPE_AS_FIRST_OPERAND(MinOp) -SAME_SHAPE_AS_FIRST_OPERAND(MulOp) -SAME_SHAPE_AS_FIRST_OPERAND(NegOp) -SAME_SHAPE_AS_FIRST_OPERAND(RemOp) -SAME_SHAPE_AS_FIRST_OPERAND(SubOp) -SAME_SHAPE_AS_FIRST_OPERAND(TanhOp) - -#undef SAME_SHAPE_AS_FIRST_OPERAND - -} // namespace xla_hlo -} // namespace mlir - -#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_HLO_SHAPE_DERIVATION_H_ diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index 8f955d6944a..7d4b17ef291 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -46,7 +46,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/xla/transforms/passes.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/kernel_shape_util.h" #include "tensorflow/core/kernels/conv_grad_shape_utils.h" #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" @@ -427,6 +427,42 @@ static DenseIntElementsAttr Get2DTransposePerm(BoolAttr transpose, Builder *b) { return GetI64ElementsAttr({0, 1}, b); } +//===----------------------------------------------------------------------===// +// MatrixBandPart op utilities. +//===----------------------------------------------------------------------===// + +// Gets the size of the dimension `dim_from_end` from the end of `input`. +// Requires that `input` is a tensor. +static int GetDimensionSizeFromEnd(Value input, int dim_from_end) { + // Note: the verifier enforces that `input` is a ranked tensor. + auto input_type = input.getType().cast(); + auto input_shape = input_type.getShape(); + int dim = (input_shape.size() - 1) - dim_from_end; + return input_shape[dim]; +} + +// Gets a 2D tensor type with shape {dim_0, dim_1}, where `dim_0` and `dim_1` +// have the same size as the last two dimensions of `input` (the second-to-last +// dimension and last dimension, respectively). The element type of the +// outputted RankedTensorType will match the element type of `input`. +// Requires that `input` is a tensor. +static RankedTensorType Get2DTensorType(Value input) { + // `dim_0` refers to the second-to-last dimension; `dim_1` refers to the last. + int dim_0 = GetDimensionSizeFromEnd(input, 1); + int dim_1 = GetDimensionSizeFromEnd(input, 0); + auto element_type = input.getType().cast().getElementType(); + return RankedTensorType::get({dim_0, dim_1}, element_type); +} + +// Creates a HLO ConvertOp, converting `input` to have the same element type as +// `elem_type_tensor`. Requires `elem_type_tensor` to be a tensor. +static Value CreateConvertOp(OpBuilder *builder, Location loc, Value input, + Value elem_type_tensor) { + auto element_type = + elem_type_tensor.getType().cast().getElementType(); + return builder->create(loc, input, element_type); +} + //===----------------------------------------------------------------------===// // Pad op utilities. //===----------------------------------------------------------------------===// @@ -1559,6 +1595,82 @@ class ConvertSizeOp : public OpRewritePattern { } }; +static void BroadcastBatchMatMulV2Operands(Value lhs, Value rhs, Location loc, + Value *out_lhs, Value *out_rhs, + PatternRewriter *rewriter) { + auto lhs_type = lhs.getType().cast(); + auto rhs_type = rhs.getType().cast(); + // The last two dimensions are the matrix row/col dimensions. Don't + // broadcast them. 
+ SmallVector result_batch_shape; + OpTrait::util::getBroadcastedShape(lhs_type.getShape().drop_back(2), + rhs_type.getShape().drop_back(2), + result_batch_shape); + auto handle_one_side = [rewriter, &result_batch_shape, loc]( + Value side, RankedTensorType type, + Value *out_side) { + ArrayRef matrix_dims = type.getShape().take_back(2); + auto result_shape = result_batch_shape; + result_shape.append(matrix_dims.begin(), matrix_dims.end()); + auto result_type = + RankedTensorType::get(result_shape, type.getElementType()); + auto shape = rewriter->create( + loc, GetI64ElementsAttr(result_shape, rewriter)); + *out_side = + rewriter->create(loc, result_type, side, shape); + }; + handle_one_side(lhs, lhs_type, out_lhs); + handle_one_side(rhs, rhs_type, out_rhs); +} + +class ConvertBatchMatMulV2Op : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(TF::BatchMatMulV2Op op, + PatternRewriter &rewriter) const override { + // TODO(silvasean): Handle adj_x/adj_y + // Should be able to just set the contracting_dimensions attribute + // appropriately. + // For complex types, need to do a complex conjugation. + if (op.adj_x() || op.adj_y()) return matchFailure(); + + Value lhs = op.x(); + Value rhs = op.y(); + auto lhs_type = lhs.getType().dyn_cast(); + auto rhs_type = rhs.getType().dyn_cast(); + if (!lhs_type || !rhs_type) return matchFailure(); + // TODO(silvasean): Support dynamic shapes. + if (!lhs_type.hasStaticShape() || !rhs_type.hasStaticShape()) { + return matchFailure(); + } + + // Broadcast both operands. + BroadcastBatchMatMulV2Operands(lhs, rhs, op.getLoc(), &lhs, &rhs, + &rewriter); + lhs_type = lhs.getType().cast(); + rhs_type = rhs.getType().cast(); + assert(lhs_type.getRank() == rhs_type.getRank()); + int64_t rank = lhs_type.getRank(); + auto batch_dimensions = GetI64ElementsAttr( + llvm::to_vector<4>(llvm::seq(0, rank - 2)), &rewriter); + auto lhs_contracting_dimensions = + GetI64ElementsAttr(llvm::makeArrayRef({rank - 1}), &rewriter); + auto rhs_contracting_dimensions = + GetI64ElementsAttr(llvm::makeArrayRef({rank - 2}), &rewriter); + auto dimension_numbers = DotDimensionNumbers::get( + /*lhs_batching_dimensions=*/batch_dimensions, + /*rhs_batching_dimensions=*/batch_dimensions, + /*lhs_contracting_dimensions=*/lhs_contracting_dimensions, + /*rhs_contracting_dimensions=*/rhs_contracting_dimensions, + rewriter.getContext()); + rewriter.replaceOpWithNewOp(op, op.getType(), lhs, rhs, + dimension_numbers, + /*precision_config=*/nullptr); + return matchSuccess(); + } +}; + // Converts the tf.Split op into a series of HLO slice ops when the tensor to be // split has fully static shape and the dimension to split is a constant. // @@ -1890,7 +2002,7 @@ class ConvertStridedSliceGradOp Value grad = op.dy(); Type element_type = grad.getType().cast().getElementType(); - // Perform reshape to undo any new/shrink axies done by strided slice. + // Perform reshape to undo any new/shrink axes done by strided slice. grad = rewriter.create( op.getLoc(), RankedTensorType::get(shape, element_type), grad); @@ -2892,22 +3004,22 @@ class ConvertOneHotOp : public OpRewritePattern { } }; -// Converts InfeedEnqueueTuple to XLA HLO after_all, infeed and +// Converts InfeedDequeueTuple to XLA HLO create_token, infeed and // get_tuple_element ops. // // All HLO infeed ops expect a HLO token type operand and produce a tuple // containing a token. This HLO token type is used to order multiple infeed // operations within a computation. 
The token type can come from other -// infeed/outfeed/send/recv ops or can be generated using an after_all op with -// no operands. Here we emit an after_all op to generate the token type operand -// of infeed. +// infeed/outfeed/send/recv ops or can be generated using create_token op with +// no operands. Here we emit a create_token op to generate the token type +// operand of infeed. // // For example the following IR: // %0:2 = "tf.InfeedDequeueTuple"() : () -> (tensor<3xi32>, tensor<4xf32>) // // would be lowered to // -// %token = "xla_hlo.after_all"() : () -> !xla_hlo.token +// %token = "xla_hlo.create_token"() : () -> !xla_hlo.token // %data_and_token = "xla_hlo.infeed"(%token) {infeed_config = ""} : // (!xla_hlo.token) -> tuple, tensor<4xf32>>, // !xla_hlo.token> @@ -2926,21 +3038,20 @@ class ConvertInfeedDequeueTupleOp for (auto idx_and_output : llvm::enumerate(op.outputs())) { result_types[idx_and_output.index()] = (idx_and_output.value().getType()); } - // Infeed takes a single token operand. Generate the token using after_all - // op to pass to the infeed op. - auto afterall = rewriter.create( - op.getLoc(), xla_hlo::TokenType::get(rewriter.getContext()), - ValueRange()); + // Infeed takes a single token operand. Generate the token using + // create_token op to pass to the infeed op. + auto token = rewriter.create( + op.getLoc(), xla_hlo::TokenType::get(rewriter.getContext())); // Emit infeed op. // The result type of infeed is a tuple(tuple(result types), token type). auto data_tuple_type = mlir::TupleType::get(result_types, rewriter.getContext()); auto data_and_token_type = mlir::TupleType::get( - {data_tuple_type, afterall.getType()}, rewriter.getContext()); + {data_tuple_type, token.getType()}, rewriter.getContext()); auto data_and_token = - rewriter.create(op.getLoc(), data_and_token_type, afterall, + rewriter.create(op.getLoc(), data_and_token_type, token, /*infeed_config=*/rewriter.getStringAttr("")); // The infeed instruction produces a tuple of the infeed data and a token @@ -2962,10 +3073,11 @@ class ConvertInfeedDequeueTupleOp } }; -// Converts tf.OutfeedEnqueueTuple to XLA HLO tuple, after_all and outfeed ops. +// Converts tf.OutfeedEnqueueTuple to XLA HLO tuple, create_token and outfeed +// ops. // // XLA HLO outfeed op expects a token, which we generate by emitting an -// after_all op. +// create_token op. 
// // For example the following IR: // "tf.OutfeedEnqueueTuple"(%val_1, %val_2) : (tensor<3xi32>, tensor<4xf32>) -> @@ -2975,7 +3087,7 @@ class ConvertInfeedDequeueTupleOp // // %tuple = "xla_hlo.tuple"(%val_1, %val_2) : (tensor<3xi32>, tensor<4xf32>) -> // tuple, tensor<4xf32>> -// %token = "xla_hlo.after_all"() : () -> !xla_hlo.token +// %token = "xla_hlo.create_token"() : () -> !xla_hlo.token // %outfeed_token = "xla_hlo.outfeed"(%tuple, %token) {outfeed_config = ""} : // (tuple, tensor<4xf32>>, !xla_hlo.token) -> !xla_hlo.token // @@ -2988,9 +3100,8 @@ class ConvertOutfeedEnqueueTupleOp PatternRewriter &rewriter) const override { auto token_type = xla_hlo::TokenType::get(rewriter.getContext()); auto tuple = rewriter.create(op.getLoc(), op.inputs()); - auto afterall = - rewriter.create(op.getLoc(), token_type, ValueRange()); - rewriter.create(op.getLoc(), token_type, tuple, afterall, + auto token = rewriter.create(op.getLoc(), token_type); + rewriter.create(op.getLoc(), token_type, tuple, token, /*outfeed_config=*/rewriter.getStringAttr("")); rewriter.eraseOp(op); return matchSuccess(); @@ -3520,10 +3631,13 @@ class ConvertXlaDynamicUpdateSliceOp PatternMatchResult matchAndRewrite(TF::XlaDynamicUpdateSliceOp op, PatternRewriter &rewriter) const override { auto indices_type = op.indices().getType().dyn_cast(); - if (!indices_type) return matchFailure(); + if (!indices_type || !indices_type.hasStaticShape() || + indices_type.getShape().size() != 1) + return matchFailure(); - SmallVector unpacked_indices_type( - 2, RankedTensorType::get({}, indices_type.getElementType())); + SmallVector unpacked_indices_type( + indices_type.getDimSize(0), + RankedTensorType::get({}, indices_type.getElementType())); auto unpacked_indices = rewriter.create( op.getLoc(), unpacked_indices_type, op.indices(), IntegerAttr::get(rewriter.getIntegerType(64), 0)); @@ -3533,6 +3647,80 @@ class ConvertXlaDynamicUpdateSliceOp } }; +/// Converts the Cumsum TensorFlow op to the HLO ReduceWindow op by setting +/// appropriate window dimensions, with 'add' as the reduction function. The +/// input tensor needs to have a static shape, and 'axis' must be const. The +/// TableGen pattern is not used for this rewrite because it involves regions. +class ConvertCumsumOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(TF::CumsumOp op, + PatternRewriter &rewriter) const override { + auto input = op.x(); + auto input_type = input.getType().dyn_cast(); + if (!input_type || !input_type.hasStaticShape()) { + return matchFailure(); + } + + // TODO(jennik): Add support for the optional 'exclusive' and 'reverse' + // arguments. + if (op.exclusive() || op.reverse()) { + return matchFailure(); + } + + // We can only match when the axis is a constant scalar. + DenseIntElementsAttr axis_attr; + if (!matchPattern(op.axis(), m_Constant(&axis_attr))) { + return matchFailure(); + } + + // Convert if we need to enlarge the element type's bitwidth to avoid + // precision loss. + Type input_element_type = input_type.getElementType(); + Type sum_element_type = GetSumAccumulationType(input_element_type); + input = rewriter.create(op.getLoc(), input, sum_element_type); + + ArrayRef input_shape = input_type.getShape(); + int64_t rank = input_shape.size(); + + // Get the dimension to apply the reduction on, and offset properly if it is + // negative. 
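  // To see why this yields a cumulative sum, take input_shape = [4] and
  // axis = 0 (illustrative values): the code below builds window_dims = [4],
  // window_strides = [1] and paddings = [[3, 0]], so the reduce_window output
  // at position i reduces the window [pad, pad, ..., x[0], ..., x[i]]; with
  // the zero init value used for padding, that is exactly x[0] + ... + x[i].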
+ int64_t axis = (*axis_attr.begin()).getSExtValue(); + if (axis < 0) { + axis += rank; + } + + SmallVector window_dims(rank, 1); + SmallVector window_strides(rank, 1); + window_dims[axis] = input_shape[axis]; + + SmallVector paddings(rank * 2, 0); + paddings[axis * 2] = input_shape[axis] - 1; + auto paddings_attr = DenseIntElementsAttr::get( + RankedTensorType::get({rank, 2}, rewriter.getIntegerType(64)), + paddings); + + Value init = + GetScalarConstOfType(sum_element_type, op.getLoc(), 0, &rewriter); + + auto reduce = rewriter.create( + op.getLoc(), input_type, input, init, + GetI64ElementsAttr(rewriter.getI64ArrayAttr(window_dims)), + GetI64ElementsAttr(rewriter.getI64ArrayAttr(window_strides)), + /*base_dilations=*/DenseIntElementsAttr(), + /*window_dilations=*/DenseIntElementsAttr(), paddings_attr); + BuildReduceBody(sum_element_type, &reduce.body(), &rewriter); + Value result = reduce.getResult(); + + // Convert back if we enlarged the element type's bitwidth. + result = + rewriter.create(op.getLoc(), result, input_element_type); + + rewriter.replaceOp(op, result); + return matchSuccess(); + } +}; + #include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc" LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) { @@ -3547,9 +3735,9 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) { // here for lowering to HLO. TF::PopulateLoweringTFPatterns(context, &patterns); patterns.insert< - ConvertAllOp, ConvertAnyOp, ConvertArgMaxOp, ConvertBF16FloorDivOp, - ConvertConv2D, ConvertConv2DBackpropFilterOp, - ConvertConv2DBackpropInputOp, ConvertEinsumOp, + ConvertAllOp, ConvertAnyOp, ConvertArgMaxOp, ConvertBatchMatMulV2Op, + ConvertBF16FloorDivOp, ConvertConv2D, ConvertConv2DBackpropFilterOp, + ConvertConv2DBackpropInputOp, ConvertCumsumOp, ConvertEinsumOp, ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op, ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp, ConvertLinSpaceOp, ConvertMaxOp, diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index 519ba9235f1..b9599201601 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -193,7 +193,7 @@ def : Pat<(TF_FloorModOp AnyStaticShapeTensor:$l, AnyStaticShapeTensor:$r), // input and result needs to ranked for computation of the broadcast dimensions. def : Pat<(TF_BroadcastToOp:$result AnyRankedTensor:$input, $shape), (HLO_BroadcastInDimOp $input, - (BinBroadcastDimensions $input, $result)), + (BinBroadcastDimensionsNonEmpty $input, $result)), [(AnyRankedTensor $result)]>; //===----------------------------------------------------------------------===// @@ -357,6 +357,89 @@ def SparseMatMulToMatMul : Pat<(TF_SparseMatMulOp $a, $b, $a_sparse, $b_sparse, (TF_MatMulOp $a, $b, $transpose_a, $transpose_b)>; +//===----------------------------------------------------------------------===// +// MatrixBandPart op pattern. 
+//===----------------------------------------------------------------------===// + +class getIntegerAttr: NativeCodeCall< + "$_builder.getI64IntegerAttr(" # x # ")">; + +class GetDimensionSizeFromEnd: NativeCodeCall< + "$_builder.getI64IntegerAttr(GetDimensionSizeFromEnd($0, " # dimFromEnd # "))" + >; + +// TODO(b/149615308): Enable IotaOp usage as a child operation in a pattern +// For now, this op needs to be created in C++ because the expected output type +// cannot be inferred. +class createIotaOp: NativeCodeCall< + "$_builder.create($0.getOwner()->getLoc(), " + "Get2DTensorType($1), $_builder.getI64IntegerAttr(" # dim # "))">; + +// This op needs to be created in C++ because the generated Convert Op has no +// way to specify shape information as an input. In the MatrixBandPart op +// lowering, ConvertOp is not a root operation and the appropriate types cannot +// be inferred, so we construct it manually. +def createConvertOp: NativeCodeCall< + "CreateConvertOp(&($_builder), $0.getOwner()->getLoc(), $1, $2)">; + +// Performs a substitution of MatrixBandPartOp for XLA HLO ops. Psuedocode is +// shown below, given a tensor `input` with k dimensions [I, J, K, ..., M, N] +// and two integers, `num_lower` and `num_upper`: +// +// iota_m = { M x N matrix with 0,1,...M along the M dimension } +// iota_n = { M x N matrix with 0,1,...N along the N dimension } +// num_lower_or_m = (num_lower < 0) ? m : num_lower +// num_upper_or_n = (num_upper < 0) ? n : num_upper +// offset = iota_m - iota_n +// indicator = (-num_lower_or_m < offset) & (offset < num_upper_or_n) +// zero_matrix = { [I, J, K,...M, N] zero matrix } +// return (indicator ? input : zero_matrix) +// +// TODO(b/149961547): Support dynamic shaped `input` in MatrixBandPartOp. +def : Pattern<(TF_MatrixBandPartOp:$op AnyRankedTensor:$input, $num_lower, $num_upper), + [(HLO_ConstOp:$m_dim (GetDimensionSizeFromEnd<"0"> $input)), + (HLO_ConstOp:$n_dim (GetDimensionSizeFromEnd<"1"> $input)), + (HLO_SelectOp:$num_lower_or_m + (HLO_CompareOp + $num_lower, (HLO_ConstOp:$zero (ConstantSplat<"0"> $num_lower)), + (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT + ), + $m_dim, + $num_lower + ), + (HLO_SelectOp:$num_upper_or_n + (HLO_CompareOp + $num_upper, $zero, + (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT + ), + $n_dim, + $num_upper + ), + (HLO_SelectOp + (HLO_AndOp + (HLO_CompareOp + (HLO_NegOp + (createConvertOp $op, $num_lower_or_m, $input) + ), + (HLO_SubOp:$offset + (createIotaOp<"1"> $op, $input), (createIotaOp<"0"> $op, $input), + (NullDenseIntElementsAttr) + ), + (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LE + ), + (HLO_CompareOp + $offset, + (createConvertOp + $op, $num_upper_or_n, $input + ), + (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LE + ), + (BinBroadcastDimensions $offset, $input) + ), + $input, + (HLO_ConstOp (ConstantSplat<"0"> $input)) + )]>; + //===----------------------------------------------------------------------===// // Nullary op patterns. //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc new file mode 100644 index 00000000000..962bf97c44d --- /dev/null +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc @@ -0,0 +1,389 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/Optional.h" +#include "mlir/IR/Diagnostics.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project +#include "mlir/IR/Value.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h.inc" +#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" +#include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_expression.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/node_properties.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/stream_executor/lib/statusor.h" +#include "tensorflow/stream_executor/stream_executor.h" + +namespace mlir { +namespace xla_hlo { +namespace { + +template +using InlinedVector = tensorflow::gtl::InlinedVector; // non-absl ok + +static bool IsOpWhitelisted(Operation* op) { + // White-listed TensorFlow ops are known to have well behaved tf2xla kernels + // building valid MLIR using MlirHloBuilder. + // TODO(hinsu): Drop explicit whitelist when MLIR based bridge is enabled for + // all tf2xla kernels. 
+ return isa(op) || isa(op) || + isa(op) || isa(op); +} + +static llvm::Optional GetJitDevice( + const std::string& device_type, const Location& loc) { + if (device_type == "XLA_CPU") return absl::string_view("XLA_CPU_JIT"); + if (device_type == "TPU") return absl::string_view("XLA_TPU_JIT"); + // TODO(hinsu): Support GPU device along with a test for it. + + emitError(loc) << "unsupported device for legalization with tf2xla kernels: " + << device_type; + return llvm::None; +} + +static std::unique_ptr CreateDeviceMgr( + const std::string& device_type, const Location& loc) { + auto jit_device_or = GetJitDevice(device_type, loc); + if (!jit_device_or) return nullptr; + + auto* factory = tensorflow::DeviceFactory::GetFactory(device_type); + if (!factory) { + emitError(loc) << "failed to create DeviceFactory for device: " + << device_type; + return nullptr; + } + std::vector> devices; + auto status = factory->CreateDevices( + tensorflow::SessionOptions(), + /*name_prefix=*/"/job:localhost/replica:0/task:0", &devices); + if (!status.ok()) { + emitError(loc) << status.ToString(); + return nullptr; + } + + auto device = absl::make_unique( + tensorflow::SessionOptions(), tensorflow::DeviceType(*jit_device_or)); + return absl::make_unique(std::move(device)); +} + +class FuncLegalizer { + public: + static LogicalResult Legalize(FuncOp func, const std::string& device_type) { + FuncLegalizer legalizer(func, device_type); + if (failed(legalizer.PrepareParams())) return failure(); + return legalizer.Legalize(); + } + + private: + FuncLegalizer(FuncOp func, const std::string& device_type) + : func_(func), device_type_(device_type), hlo_builder_(func) {} + + ~FuncLegalizer() { context_->Unref(); } + + // Prepares OpKernelContext params common to all the ops. + // Emits an error on failure. + LogicalResult PrepareParams(); + + // Tries to legalize supported TensorFlow ops. + // Emits an error on failure. + LogicalResult Legalize(); + + // Tries to legalize the specified TensorFlow op, if supported. + // + // Emits an error and returns failure if an error is encountered during + // conversion. Note that success return value doesn't mean successful + // legalization. + LogicalResult LegalizeOp(Operation* op); + + FuncOp func_; + std::string device_type_; + + ::xla::MlirHloBuilder hlo_builder_; + tensorflow::OpOrArgLocNameMapper name_mapper_; + + tensorflow::XlaContext* context_; // Ref-counted. + + std::unique_ptr device_mgr_; + tensorflow::Device* device_; // Owned by device_mgr_; + std::unique_ptr step_container_; + std::unique_ptr flib_def_; + std::unique_ptr pflr_; + tensorflow::OpKernelContext::Params params_; +}; + +LogicalResult FuncLegalizer::PrepareParams() { + // XlaCompiler within the context is only used by the functional ops to + // compile functions. We are not handling those at the moment so XlaCompiler + // is not required. + context_ = new tensorflow::XlaContext(/*compiler=*/nullptr, &hlo_builder_); + context_->Ref(); + + mlir::Location loc = func_.getLoc(); + device_mgr_ = CreateDeviceMgr(device_type_, loc); + if (!device_mgr_) return failure(); + + // Type of params_.device is DeviceBase* so store it as Device* to access + // derived class method. + device_ = device_mgr_->ListDevices().front(); + params_.device = device_; + params_.resource_manager = device_->resource_manager(); + + // Resources are cleared at the time of device manager destruction so pass + // no-op cleanup function. 
+ auto cleanup = [](const std::string& name) {}; + // Use step_id zero as we only have a single context concurrently and + // concurrently running each of the MLIR functions create a new device. + step_container_ = absl::make_unique( + /*step_id=*/0, cleanup); + tensorflow::Status status = step_container_->Create( + device_->resource_manager(), + tensorflow::XlaContext::kXlaContextResourceName, context_); + if (!status.ok()) { + emitError(loc) << "failed to create XlaContext resource: " + << status.ToString(); + return failure(); + } + params_.step_container = step_container_.get(); + + tensorflow::StatusOr version_or = + tensorflow::GetTfGraphProducerVersion( + func_.getParentOfType()); + if (!version_or.ok()) { + emitError(loc) << version_or.status().ToString(); + return failure(); + } + + flib_def_ = absl::make_unique( + tensorflow::OpRegistry::Global(), tensorflow::FunctionDefLibrary()); + pflr_ = absl::make_unique( + device_mgr_.get(), tensorflow::Env::Default(), /*config=*/nullptr, + version_or.ValueOrDie(), flib_def_.get(), tensorflow::OptimizerOptions()); + params_.function_library = pflr_->GetFLR(device_->name()); + return success(); +} + +LogicalResult FuncLegalizer::Legalize() { + // TensorFlow functions don't use CFGs. + if (func_.getBlocks().size() > 1) { + emitError(func_.getLoc()) << "requires at most one block in a TF function"; + return failure(); + } + if (func_.getBlocks().empty()) return success(); + Block& block = func_.getBlocks().front(); + + std::vector ops; + ops.reserve(block.getOperations().size()); + for (Operation& op : block.getOperations()) { + ops.push_back(&op); + } + + for (Operation* op : ops) { + if (failed(LegalizeOp(op))) return failure(); + } + return success(); +} + +LogicalResult FuncLegalizer::LegalizeOp(Operation* op) { + if (!IsOpWhitelisted(op)) return success(); + + // Only static shaped operands are supported in XLA builders for now. + for (Type ty : op->getOperandTypes()) { + auto ranked_ty = ty.cast(); + if (!ranked_ty || !ranked_ty.hasStaticShape()) { + op->emitRemark() << "lowering requires static shaped operands"; + return success(); + } + } + + auto nodedef_or = tensorflow::ConvertTFDialectOpToNodeDef( + op, name_mapper_.GetUniqueName(op), /*ignore_unregistered_attrs=*/true); + if (!nodedef_or.ok()) { + op->emitRemark() << "failed to convert op to NodeDef: " + << nodedef_or.status().ToString(); + return success(); + } + + std::shared_ptr props; + tensorflow::Status status = tensorflow::NodeProperties::CreateFromNodeDef( + *nodedef_or.ValueOrDie(), + params_.function_library->GetFunctionLibraryDefinition(), &props); + if (!status.ok()) { + op->emitRemark() << "failed to create NodeProperties: " + << status.ToString(); + return success(); + } + tensorflow::OpKernel* op_kernel_raw; + status = params_.function_library->CreateKernel(props, &op_kernel_raw); + if (!status.ok()) { + op->emitRemark() << "failed to create tf2xla kernel: " << status.ToString(); + return success(); + } + // Transfer ownership of the kernel to a local smart pointer. + auto op_kernel = absl::WrapUnique(op_kernel_raw); + + // TensorValue in inputs are backed by tensors which in turn depend on + // expressions. So, pre-allocate them to the required size. + InlinedVector expressions; + InlinedVector tensors; + InlinedVector inputs; + expressions.reserve(op->getNumOperands()); + tensors.reserve(op->getNumOperands()); + inputs.reserve(op->getNumOperands()); + + // Prepare the list of Tensor inputs for the kernel. 
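  // In outline, each operand below goes MLIR Value -> ::xla::XlaOp (via
  // MlirHloBuilder) -> XlaExpression -> placeholder Tensor, so the tf2xla
  // kernel receives ordinary TensorValue inputs while the actual data stays
  // symbolic inside the builder.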
+ for (Value operand : op->getOperands()) { + // Skip this op if XLA doesn't support this operand type. + auto xla_op_or = hlo_builder_.MakeXlaOp(operand); + if (!xla_op_or.ok()) { + op->emitRemark() << "skipping legalization due to " + << xla_op_or.status().ToString(); + return success(); + } + ::xla::XlaOp xla_op = xla_op_or.ValueOrDie(); + + tensorflow::DataType dtype; + status = tensorflow::ConvertToDataType(operand.getType(), &dtype); + if (!status.ok()) { + op->emitRemark() << "skipping legalization due to " << status.ToString(); + return success(); + } + + auto expression = tensorflow::XlaExpression::XlaOp(xla_op, dtype); + expressions.push_back(expression); + + if (!tensorflow::DataTypeCanUseMemcpy(dtype)) { + op->emitRemark() << "skipping legalization due to unsupported type " + << operand.getType(); + return success(); + } + + auto shape_or = expression.GetShape(); + if (!shape_or.ok()) { + op->emitRemark() << "failed to get shape for expression. " + << expression.HumanString(); + return success(); + } + + tensors.emplace_back( + device_->GetAllocator(tensorflow::AllocatorAttributes()), dtype, + shape_or.ValueOrDie()); + tensorflow::Tensor& tensor = tensors.back(); + tensorflow::XlaOpKernelContext::AssignExpressionToTensor(expression, + &tensor); + inputs.emplace_back(&tensor); + } + + params_.inputs = &inputs; + params_.op_kernel = op_kernel.get(); + llvm::SmallVector output_attr( + op->getNumResults()); + params_.output_attr_array = output_attr.data(); + + hlo_builder_.setInsertionPoint(op); + hlo_builder_.SetLocation(op->getLoc()); + + // Execute the kernel. + tensorflow::OpKernelContext op_context(¶ms_, op->getNumResults()); + device_->Compute(params_.op_kernel, &op_context); + if (!op_context.status().ok()) { + op->emitRemark() << "compilation to HLO failed: " + << op_context.status().ToString(); + return success(); + } + + // Replace uses of old results using the corresponding value after the + // lowering. + for (int i = 0, e = op->getNumResults(); i < e; i++) { + tensorflow::Tensor* output = op_context.mutable_output(i); + const tensorflow::XlaExpression* expr = + tensorflow::XlaOpKernelContext::CastExpressionFromTensor(*output); + if (expr->kind() != tensorflow::XlaExpression::Kind::kXlaOp) + return op->emitError( + "expects XlaExpression of kind kXlaOp in compiled output"); + auto value = hlo_builder_.GetValue(expr->handle()); + op->getResult(i).replaceAllUsesWith(value); + } + + op->erase(); + return success(); +} + +class LegalizeTF : public FunctionPass { + public: + LegalizeTF() = default; + + LegalizeTF(const LegalizeTF&) {} + + void runOnFunction() override { + if (failed(FuncLegalizer::Legalize(getFunction(), device_type_))) + signalPassFailure(); + } + + private: + // TODO(hinsu): Support finer grained device type assignment instead of a + // global device type for all TensorFlow ops. + Option device_type_{ + *this, "device-type", + llvm::cl::desc("XLA device type for execution of TensorFlow ops. 
" + "Supports XLA_CPU and TPU for now.")}; +}; + +static PassRegistration pass( + "xla-legalize-tf-with-tf2xla", + "Legalize from TensorFlow to the HLO dialect using tf2xla kernels"); + +} // end namespace + +} // end namespace xla_hlo +} // end namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc index 1c0f3d8f242..aeaceeb27d5 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc @@ -87,20 +87,20 @@ class CompareFConvert : public OpRewritePattern { return matchFailure(); auto comparison_direction = op.comparison_direction(); - CmpFPredicate compare_predicate = - llvm::StringSwitch(comparison_direction) + auto compare_predicate = + llvm::StringSwitch>(comparison_direction) .Case("EQ", CmpFPredicate::OEQ) .Case("NE", CmpFPredicate::UNE) .Case("LT", CmpFPredicate::OLT) .Case("LE", CmpFPredicate::OLE) .Case("GT", CmpFPredicate::OGT) .Case("GE", CmpFPredicate::OGE) - .Default(CmpFPredicate::NumPredicates); + .Default(llvm::None); - if (compare_predicate == CmpFPredicate::NumPredicates) - return matchFailure(); + if (!compare_predicate.hasValue()) return matchFailure(); - rewriter.replaceOpWithNewOp(op, compare_predicate, lhs, rhs); + rewriter.replaceOpWithNewOp(op, compare_predicate.getValue(), lhs, + rhs); return matchSuccess(); } }; diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_copy_removal.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_copy_removal.cc new file mode 100644 index 00000000000..86125126390 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_copy_removal.cc @@ -0,0 +1,105 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements a pass to remove redundant LHLO copy operations. + +#include "absl/memory/memory.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" +#include "tensorflow/compiler/mlir/xla/transforms/passes.h" + +namespace mlir { +namespace xla_lhlo { +namespace { + +// Removes LHLO copy operations that copy from allocated buffers to block +// arguments. All uses of each buffer are replaced with the corresponding block +// argument and the buffer is freed. Note that this pass only works in regions +// with a single block. +struct LhloCopyRemoval : mlir::OperationPass { + void runOnOperation() override { + llvm::SmallVector eraseList; + auto operation = getOperation(); + operation->walk([&](mlir::xla_lhlo::CopyOp copyOp) { + // If this region contains more than one block, then ignore this copy + // operation. 
+ if (copyOp.getParentRegion()->getBlocks().size() > 1) { + return; + } + + mlir::Value fromOperand = copyOp.operand(); + mlir::Value toOperand = copyOp.output(); + + // If the fromOperand value is a block argument or the toOperand + // value is not a block argument, then ignore this copy operation. + if (!fromOperand.getDefiningOp() || toOperand.getDefiningOp()) { + return; + } + + // The copy operation removal is illegal if there is at least a single use + // of toOperand value that lies between the first use of fromOperand value + // and the copy operation. + auto fromOperandUsers = fromOperand.getUsers(); + auto firstUser = *fromOperandUsers.begin(); + for (auto op : fromOperandUsers) { + if (op->isBeforeInBlock(firstUser)) firstUser = op; + } + for (auto op : toOperand.getUsers()) { + if (op->isBeforeInBlock(copyOp) && firstUser->isBeforeInBlock(op)) { + return; + } + } + + // TODO(DFKI): Use live variable analysis to solve aliasing issues among + // block arguments. + + // Remove the associated alloc operation. + auto allocOp = fromOperand.getDefiningOp(); + eraseList.push_back(allocOp); + + // Iterate over all uses of the fromOperand to find the associated + // deallocOp (if any). + for (auto op : fromOperandUsers) { + if (isa(op)) { + eraseList.push_back(op); + break; + } + } + + // Replace all uses of the fromOperand with the toOperand. This rewires + // all references pointing to the original alloc operation to the new + // target operation in order to safely remove the copy op. + fromOperand.replaceAllUsesWith(toOperand); + copyOp.erase(); + }); + for (auto op : eraseList) { + op->erase(); + } + }; +}; + +} // namespace + +std::unique_ptr createLhloCopyRemovalPass() { + return absl::make_unique(); +} + +static PassRegistration copy_removal_pass( + "lhlo-copy-removal", "Removes redundant LHLO copy operations"); + +} // namespace xla_lhlo +} // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc index 8f34034d6d3..a27a27b3760 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc @@ -63,8 +63,7 @@ class LhloFuseLinalg : public FunctionPass { SmallVector tile_sizes(tile_sizes_.begin(), tile_sizes_.end()); if (tile_sizes.empty()) { - tile_sizes = - SmallVector(generic_op.getNumInputsAndOutputs(), 1); + tile_sizes = SmallVector(generic_op.getNumLoops(), 1); } auto op = cast(generic_op.getOperation()); for (const Value result : op.getOutputBuffers()) { diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc index 2c550465302..32053950fed 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc @@ -31,11 +31,11 @@ namespace mlir { namespace xla_lhlo { namespace { -template -struct BinaryOpConverter : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +template +struct BinaryOpConverter : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(LhloOp op, + PatternMatchResult matchAndRewrite(LhloOpTy op, PatternRewriter& rewriter) const override { const auto& lhs = op.lhs(); const auto& rhs = op.rhs(); @@ -56,8 +56,8 @@ struct BinaryOpConverter : public OpRewritePattern { } auto l = rewriter.create(loc, lhs, induction_vars); auto r = rewriter.create(loc, 
rhs, induction_vars); - Value opResult = MapXlaOpToStdScalarOp( - llvm::cast(op), element_type, {l, r}, &rewriter); + Value opResult = xla_lhlo::XlaOpToStdScalarOp::map( + op, element_type, {l, r}, &rewriter); if (opResult == nullptr) { return this->matchFailure(); } diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc new file mode 100644 index 00000000000..8ef08e4f9f3 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc @@ -0,0 +1,244 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/memory/memory.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // TF:llvm-project +#include "mlir/Dialect/LoopOps/LoopOps.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" + +namespace mlir { +namespace xla_lhlo { +namespace { + +// Converts `xla_lhlo.ReduceOp` into two loop::ParallelOp and a loop::ReduceOp. +// The outper `ParallelOp` refers to the parallel loops if there are +// any. The inner `ParalleOp` refers to the reduction loops and `ReduceOp` +// contains the reduction operator. +// +// Example: +// +// "xla_lhlo.reduce"(%buffer, %init_buf, %result) ( { +// ^bb0(%lhs: memref, %rhs: memref, %res: memref): +// +// } ) {dimensions = dense<[1]> : tensor<1xi64>} +// : (memref<100x10x5xf32>, memref, memref<100x5xf32>) -> () +// +// is converted into: +// +// %init = load %init_buf[] : memref +// loop.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) { +// %result = loop.parallel (%j) = (%c0) to (%c10) step (%c1) init (%init) { +// %elem_to_reduce = load %buffer[%i, %j, %k] : memref<100x10x5xf32> +// loop.reduce(%elem_to_reduce) { +// ^bb0(%elem: f32, %acc: f32): // no predecessors +// elem_buf = alloc() : memref +// store %elem, elem_buf[] : memref +// acc_buf = alloc() : memref +// store %acc, acc_buf[] : memref +// +// %acc_result = load acc_buf[] : memref +// loop.reduce.return %acc_result : f32 +// } : f32 +// loop.yield +// } : f32 +// loop.yield +// } +class ReduceOpConverter : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + PatternMatchResult matchAndRewrite( + xla_lhlo::ReduceOp xla_reduce_op, ArrayRef args, + ConversionPatternRewriter& rewriter) const final { + // TODO(b/137624192) Implement variadic reduce. 
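// Worked example (shapes taken from the comment above): for an operand of type
// memref<100x10x5xf32> and dimensions = dense<[1]>, the reducing dimension set
// is {1}, so the outer loop.parallel iterates (%i, %k) over 100 x 5, the inner
// loop.parallel iterates (%j) over 10 with the loaded %init value as its init
// operand, and loop.reduce combines the reduced elements.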
+ if (xla_reduce_op.out().size() != 1) return matchFailure(); + + loop::ReduceOp reduce_op = + CreateParallelLoopsWithReduceOp(xla_reduce_op, args, &rewriter); + ConvertReductionOperator(xla_reduce_op, + &reduce_op.reductionOperator().front(), &rewriter); + rewriter.replaceOp(xla_reduce_op, llvm::None); + return matchSuccess(); + } + + private: + // Creates nested `loop.parallel` ops with `loop.reduce`. The outer ParallelOp + // refers to the parallel dimensions of `xla_reduce_op` if any and the inner + // ParallelOp refers to the reduction dimensions. The loop.reduce op is + // returned. + // + // If the reduction argument is a memref<100x10x5xf32> and the + // reduction is performed along dimension 1 then this method will generate + // + // %init = load %init_buf[] : memref + // loop.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) { + // %result = loop.parallel (%j) = (%c0) to (%c10) step (%c1) init (%init) { + // %elem_to_reduce = load %buffer[%i, %j, %k] : memref<100x10x5xf32> + // loop.reduce(%elem_to_reduce) { + // + // } : f32 + // loop.yield + // } : f32 + // loop.yield + // } + loop::ReduceOp CreateParallelLoopsWithReduceOp( + xla_lhlo::ReduceOp xla_reduce_op, ArrayRef args, + ConversionPatternRewriter* rewriter) const { + auto loc = xla_reduce_op.getLoc(); + DenseSet reducing_dims; + for (auto rdim : xla_reduce_op.dimensions().getIntValues()) { + reducing_dims.insert(rdim.getSExtValue()); + } + + Value operand = *xla_reduce_op.operands().begin(); + Value out = *xla_reduce_op.out().begin(); + SmallVector parallel_lower, parallel_upper, parallel_step; + SmallVector reduce_lower, reduce_upper, reduce_step; + auto operand_shape = operand.getType().cast().getShape(); + Type index_type = rewriter->getIndexType(); + for (auto dim : llvm::enumerate(operand_shape)) { + const bool is_reducing_dim = reducing_dims.count(dim.index()); + + Value ub = + dim.value() == ShapedType::kDynamicSize + ? rewriter->create(loc, operand, dim.index()).getResult() + : rewriter->create( + loc, index_type, + rewriter->getIntegerAttr(index_type, dim.value())); + Value lb = rewriter->create( + loc, index_type, rewriter->getIntegerAttr(index_type, 0)); + Value step = rewriter->create( + loc, index_type, rewriter->getIntegerAttr(index_type, 1)); + (is_reducing_dim ? reduce_lower : parallel_lower).push_back(lb); + (is_reducing_dim ? reduce_upper : parallel_upper).push_back(ub); + (is_reducing_dim ? reduce_step : parallel_step).push_back(step); + } + // Load initial value from memref. + SmallVector init_value = { + rewriter->create(loc, *xla_reduce_op.init_values().begin())}; + // Outer ParallelOp is not needed if it is a reduction across all dims. + loop::ParallelOp outer; + if (!parallel_lower.empty()) { + outer = rewriter->create(loc, parallel_lower, + parallel_upper, parallel_step); + rewriter->setInsertionPointToStart(outer.getBody()); + } + loop::ParallelOp inner = rewriter->create( + loc, reduce_lower, reduce_upper, reduce_step, init_value); + Value reduction_result = *inner.getResults().begin(); + + SmallVector out_indices; + if (outer != nullptr) { + out_indices.reserve(outer.getNumLoops()); + for (auto& iv : outer.getInductionVars()) { + out_indices.push_back(iv); + } + } else { + out_indices.push_back(rewriter->create( + loc, index_type, rewriter->getIntegerAttr(index_type, 0))); + } + + rewriter->create(loc, reduction_result, out, out_indices); + + // Load the element to reduce. + SmallVector indices; + indices.reserve(operand_shape.size()); + Block::args_iterator outer_ivs_it = + outer ? 
outer.getInductionVars().begin() : nullptr; + Block::args_iterator inner_ivs_it = inner.getInductionVars().begin(); + for (unsigned i = 0, e = operand_shape.size(); i < e; ++i) { + indices.push_back(reducing_dims.count(i) ? *inner_ivs_it++ + : *outer_ivs_it++); + } + + rewriter->setInsertionPointToStart(inner.getBody()); + Value elem = rewriter->create( + loc, *xla_reduce_op.operands().begin(), indices); + return rewriter->create(loc, elem); + } + + // Converts `xla_lhlo.reduce` reduction operator into `loop.reduce` op by + // doing buffer allocation for scalar arguments and the result of + // `loop.reduce` to make it compatible with LHLO ops. + void ConvertReductionOperator(xla_lhlo::ReduceOp xla_reduce_op, + Block* loop_reduce_op_body, + ConversionPatternRewriter* rewriter) const { + rewriter->setInsertionPointToStart(loop_reduce_op_body); + + // Allocate buffers to hold arguments of reduction operator block to stay + // compatible with the LHLO dialect ops in the reduction body. + auto loc = xla_reduce_op.getLoc(); + Value elem_arg = xla_reduce_op.body().front().getArgument(0); + Value elem_buf = + rewriter->create(loc, elem_arg.getType().cast()); + rewriter->create(loc, loop_reduce_op_body->getArgument(0), + elem_buf); + Value acc_arg = xla_reduce_op.body().front().getArgument(1); + Value acc_buf = + rewriter->create(loc, acc_arg.getType().cast()); + rewriter->create(loc, loop_reduce_op_body->getArgument(1), + acc_buf); + + // Clone the ops from `xla_lhlo.reduce` into reduction operator block. + BlockAndValueMapping mapping; + mapping.map(xla_reduce_op.body().front().getArguments(), + ValueRange{elem_buf, acc_buf, acc_buf}); + for (auto& nested : xla_reduce_op.body().front().without_terminator()) { + auto clone = rewriter->clone(nested, mapping); + mapping.map(nested.getResults(), clone->getResults()); + } + Value acc_result = rewriter->create(loc, acc_buf); + rewriter->create(loc, acc_result); + } +}; + +struct LhloLegalizeToParallelLoops + : public FunctionPass { + void runOnFunction() override { + auto func = getFunction(); + + OwningRewritePatternList patterns; + patterns.insert(func.getContext()); + + ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addIllegalOp(); + + if (failed(applyPartialConversion(func, target, patterns, nullptr))) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr> createLegalizeLhloToParallelLoopsPass() { + return absl::make_unique(); +} + +static PassRegistration legalize_lhlo_pass( + "lhlo-legalize-to-parallel-loops", + "Legalize from LHLO dialect to parallel loops."); + +} // namespace xla_lhlo +} // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h b/tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h new file mode 100644 index 00000000000..9d04e82430d --- /dev/null +++ b/tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h @@ -0,0 +1,72 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_ + +#include + +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" +#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" + +namespace mlir { +namespace xla_hlo { + +template +struct HloToLhloOpImpl { + using Type = std::false_type; +}; +template +using HloToLhloOp = typename HloToLhloOpImpl::Type; + +#define MAP_HLO_TO_LHLO(OpName) \ + template <> \ + struct HloToLhloOpImpl { \ + using Type = xla_lhlo::OpName; \ + } + +MAP_HLO_TO_LHLO(AbsOp); +MAP_HLO_TO_LHLO(AddOp); +MAP_HLO_TO_LHLO(AndOp); +MAP_HLO_TO_LHLO(BroadcastInDimOp); +MAP_HLO_TO_LHLO(CeilOp); +MAP_HLO_TO_LHLO(ConstOp); +MAP_HLO_TO_LHLO(CompareOp); +MAP_HLO_TO_LHLO(ConvertOp); +MAP_HLO_TO_LHLO(CopyOp); +MAP_HLO_TO_LHLO(CosOp); +MAP_HLO_TO_LHLO(DivOp); +MAP_HLO_TO_LHLO(ExpOp); +MAP_HLO_TO_LHLO(IotaOp); +MAP_HLO_TO_LHLO(LogOp); +MAP_HLO_TO_LHLO(MaxOp); +MAP_HLO_TO_LHLO(MinOp); +MAP_HLO_TO_LHLO(MulOp); +MAP_HLO_TO_LHLO(NegOp); +MAP_HLO_TO_LHLO(ReduceOp); +MAP_HLO_TO_LHLO(RemOp); +MAP_HLO_TO_LHLO(RsqrtOp); +MAP_HLO_TO_LHLO(SelectOp); +MAP_HLO_TO_LHLO(SignOp); +MAP_HLO_TO_LHLO(SqrtOp); +MAP_HLO_TO_LHLO(SubOp); +MAP_HLO_TO_LHLO(TanhOp); + +#undef MAP_HLO_TO_LHLO + +} // namespace xla_hlo +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_ diff --git a/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h b/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h index 6554942954e..40add223156 100644 --- a/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h +++ b/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h @@ -21,81 +21,63 @@ limitations under the License. #include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" +#include "tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h" namespace mlir { namespace xla_lhlo { +namespace impl { -template -struct ScalarOp; +// A struct to map LhloBinaryOpTy type to the corresponding floating-point and +// integer scalar operation types. 
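// Illustrative sketch (not part of the patch): the dialect map from
// map_hlo_to_lhlo_op.h above lets a single LHLO-keyed specialization serve both
// dialects. Assuming <type_traits> is available, the mapping behaves as:
static_assert(std::is_same<xla_hlo::HloToLhloOp<xla_hlo::AddOp>,
                           xla_lhlo::AddOp>::value,
              "HLO AddOp resolves to its LHLO counterpart");
// The scalar map declared below then supplies the std-dialect op used inside
// the generated loop nest, e.g. AddFOp for floating-point element types and
// AddIOp for signless integers.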
+template +struct LhloToScalarOp; template <> -struct ScalarOp { +struct LhloToScalarOp { using FOp = ::mlir::AddFOp; using IOp = ::mlir::AddIOp; }; template <> -struct ScalarOp { - using FOp = ::mlir::AddFOp; - using IOp = ::mlir::AddIOp; -}; -template <> -struct ScalarOp { +struct LhloToScalarOp { using FOp = ::mlir::CmpFOp; using IOp = ::mlir::CmpIOp; }; template <> -struct ScalarOp { - using FOp = ::mlir::CmpFOp; - using IOp = ::mlir::CmpIOp; -}; -template <> -struct ScalarOp { +struct LhloToScalarOp { using FOp = ::mlir::DivFOp; using IOp = ::mlir::SignedDivIOp; }; template <> -struct ScalarOp { - using FOp = ::mlir::DivFOp; - using IOp = ::mlir::SignedDivIOp; -}; -template <> -struct ScalarOp { +struct LhloToScalarOp { using FOp = ::mlir::MulFOp; using IOp = ::mlir::MulIOp; }; template <> -struct ScalarOp { - using FOp = ::mlir::MulFOp; - using IOp = ::mlir::MulIOp; -}; -template <> -struct ScalarOp { +struct LhloToScalarOp { using FOp = ::mlir::RemFOp; using IOp = ::mlir::SignedRemIOp; }; template <> -struct ScalarOp { - using FOp = ::mlir::RemFOp; - using IOp = ::mlir::SignedRemIOp; -}; -template <> -struct ScalarOp { - using FOp = ::mlir::SubFOp; - using IOp = ::mlir::SubIOp; -}; -template <> -struct ScalarOp { +struct LhloToScalarOp { using FOp = ::mlir::SubFOp; using IOp = ::mlir::SubIOp; }; -template -using ScalarFOp = typename ScalarOp::FOp; -template -using ScalarIOp = typename ScalarOp::IOp; +template +struct ScalarOp { + using FOp = typename LhloToScalarOp::FOp; + using IOp = typename LhloToScalarOp::IOp; +}; + +// Alias for the map from LHLO binary op type to STD floating-point op type. +template +using ScalarFOp = typename ScalarOp::FOp; +// Alias for the map from LHLO binary op type to STD integer op type. +template +using ScalarIOp = typename ScalarOp::IOp; template -struct MapXlaOpToStdScalarOpImpl { +struct MapLhloOpToStdScalarOpImpl { Value operator()(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { return nullptr; @@ -103,7 +85,7 @@ struct MapXlaOpToStdScalarOpImpl { }; template -struct MapXlaOpToStdScalarOpImpl { +struct MapLhloOpToStdScalarOpImpl { Value operator()(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { return b->template create(loc, result_types, args, mlir::None); @@ -111,7 +93,7 @@ struct MapXlaOpToStdScalarOpImpl { }; template -struct MapXlaOpToStdScalarOpImpl { +struct MapLhloOpToStdScalarOpImpl { Value operator()(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { Type element_type = args.front().getType(); @@ -119,52 +101,34 @@ struct MapXlaOpToStdScalarOpImpl { return b->template create(loc, result_types, args, mlir::None); } - return MapXlaOpToStdScalarOpImpl{}(loc, result_types, args, b); + return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, b); } }; -template -inline Value MapXlaOpToStdScalarOp(XlaOp xla_op, ArrayRef result_types, - ArrayRef args, OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl, FloatType, - ScalarFOp>{}(xla_op.getLoc(), - result_types, args, b); -} - -// TODO(ravishankarm): Find a way to reduce code-bloat in HLO and LHLO -// specialization. 
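// Note on the deletions that follow (template parameters shown here are
// assumptions): the former per-dialect pairs of specializations such as
//
//   MapXlaOpToStdScalarOp<xla_lhlo::AbsOp>(...)  /  <xla_hlo::AbsOp>(...)
//
// collapse into a single MapLhloOpToStdScalarOp<xla_lhlo::AbsOp>(loc, ...)
// that takes only a Location rather than the op itself, so the HLO variants
// can reuse the same implementation through the HloToLhloOp trait.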
-template <> -inline Value MapXlaOpToStdScalarOp(xla_lhlo::AbsOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl{}( - xla_op.getLoc(), result_types, args, b); -} -template <> -inline Value MapXlaOpToStdScalarOp(xla_hlo::AbsOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl{}( - xla_op.getLoc(), result_types, args, b); +// Inserts the computation that corresponds to the body of the loop for lowered +// LHLO unary/binary op. Returns the value for the result. +template +inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, + ArrayRef args, OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl, FloatType, + ScalarFOp>{}(loc, result_types, + args, b); } template <> -inline Value MapXlaOpToStdScalarOp(xla_lhlo::AndOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl{}( - xla_op.getLoc(), result_types, args, b); +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); } + template <> -inline Value MapXlaOpToStdScalarOp(xla_hlo::AndOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl{}( - xla_op.getLoc(), result_types, args, b); +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); } template @@ -176,14 +140,14 @@ inline Optional getCmpPredicate( template <> inline Optional getCmpPredicate( StringRef xla_comparison_direction) { - return llvm::StringSwitch(xla_comparison_direction) + return llvm::StringSwitch>(xla_comparison_direction) .Case("EQ", CmpFPredicate::OEQ) .Case("NE", CmpFPredicate::ONE) .Case("GE", CmpFPredicate::OGE) .Case("GT", CmpFPredicate::OGT) .Case("LE", CmpFPredicate::OLE) .Case("LT", CmpFPredicate::OLT) - .Default(CmpFPredicate::NumPredicates); + .Default(llvm::None); } template <> @@ -200,7 +164,8 @@ inline Optional getCmpPredicate( } template -inline Value MapXlaCompareOpToStdScalarOp(XLACompareOpTy xla_op, +inline Value MapXlaCompareOpToStdScalarOp(Location loc, + StringRef comparison_direction, ArrayRef result_types, ArrayRef args, OpBuilder* b) { const auto& lhs = args[0]; @@ -208,101 +173,60 @@ inline Value MapXlaCompareOpToStdScalarOp(XLACompareOpTy xla_op, Type element_type = lhs.getType(); if (element_type.isSignlessInteger()) { Optional predicate = - getCmpPredicate(xla_op.comparison_direction()); + getCmpPredicate(comparison_direction); assert(predicate.hasValue() && "expected valid comparison direction"); - return b->create>(xla_op.getLoc(), - predicate.getValue(), lhs, rhs); + return b->create>(loc, predicate.getValue(), lhs, + rhs); } if (element_type.isa()) { Optional predicate = - getCmpPredicate(xla_op.comparison_direction()); + getCmpPredicate(comparison_direction); assert(predicate.hasValue() && "expected valid comparison direction"); - return b->create>(xla_op.getLoc(), - predicate.getValue(), lhs, rhs); + return b->create>(loc, predicate.getValue(), lhs, + rhs); } return nullptr; } -template <> -inline Value MapXlaOpToStdScalarOp( - xla_lhlo::CompareOp xla_op, ArrayRef result_types, - ArrayRef args, OpBuilder* b) { - return MapXlaCompareOpToStdScalarOp(xla_op, result_types, - args, b); -} -template <> -inline Value MapXlaOpToStdScalarOp( - xla_hlo::CompareOp xla_op, ArrayRef 
result_types, - ArrayRef args, OpBuilder* b) { - return MapXlaCompareOpToStdScalarOp(xla_op, result_types, - args, b); -} template <> -inline Value MapXlaOpToStdScalarOp( - xla_lhlo::CopyOp xla_op, ArrayRef result_types, ArrayRef args, +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { return args.front(); } -template <> -inline Value MapXlaOpToStdScalarOp(xla_hlo::CopyOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return args.front(); -} template <> -inline Value MapXlaOpToStdScalarOp(xla_lhlo::ExpOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl{}( - xla_op.getLoc(), result_types, args, b); -} -template <> -inline Value MapXlaOpToStdScalarOp(xla_hlo::ExpOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl{}( - xla_op.getLoc(), result_types, args, b); -} - -template <> -inline Value MapXlaOpToStdScalarOp( - xla_lhlo::CeilOp xla_op, ArrayRef result_types, ArrayRef args, +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl{}( - xla_op.getLoc(), result_types, args, b); -} -template <> -inline Value MapXlaOpToStdScalarOp(xla_hlo::CeilOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl{}( - xla_op.getLoc(), result_types, args, b); + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); } template <> -inline Value MapXlaOpToStdScalarOp( - xla_lhlo::ConvertOp xla_op, ArrayRef result_types, - ArrayRef args, OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { Type sourceType = args.front().getType(); Type targetType = result_types.front(); if (mlir::SIToFPOp::areCastCompatible(sourceType, targetType)) { - return b->create(xla_op.getLoc(), result_types, args, - mlir::None); + return b->create(loc, result_types, args, mlir::None); } else if (sourceType.isa() && targetType.isa()) { FloatType src = sourceType.cast(); FloatType res = targetType.cast(); if (src.getWidth() > res.getWidth()) { - return b->create(xla_op.getLoc(), result_types, args, - mlir::None); + return b->create(loc, result_types, args, mlir::None); } else if (src.getWidth() < res.getWidth()) { - return b->create(xla_op.getLoc(), result_types, args, - mlir::None); + return b->create(loc, result_types, args, mlir::None); } // No conversion is needed for the same width floats return args.front(); @@ -311,10 +235,9 @@ inline Value MapXlaOpToStdScalarOp( IntegerType src = sourceType.cast(); IntegerType res = targetType.cast(); if (src.getWidth() > res.getWidth()) { - return b->create(xla_op.getLoc(), result_types, args, - mlir::None); + return b->create(loc, result_types, args, mlir::None); } else if (src.getWidth() < res.getWidth()) { - return b->create(xla_op.getLoc(), result_types, args, + return b->create(loc, result_types, args, mlir::None); } // No conversion is needed for the same width integers @@ -322,35 +245,25 @@ inline Value MapXlaOpToStdScalarOp( } // TODO(dfki-ehna): Add other primitive type conversions // if (mlir::FpToSiOp::areCastCompatible(sourceType, targetType)) { - // return b.create(xla_op.getLoc(), 
result_types, + // return b.create(loc, result_types, // args,mlir::None); // } - return nullptr; } template <> -inline Value MapXlaOpToStdScalarOp(xla_lhlo::CosOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl{}( - xla_op.getLoc(), result_types, args, b); -} -template <> -inline Value MapXlaOpToStdScalarOp(xla_hlo::CosOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl{}( - xla_op.getLoc(), result_types, args, b); +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); } /// Implements the conversion of XLA op to scalar op (to use within region of a /// linalg.generic op) for compare-select style operations like min/max. template -struct MapXlaCompareSelectOpToStdScalarOp { - Value operator()(Location loc, StringRef comparison_direction, +struct XlaCompareSelectOpToStdScalarOp { + static Value map(Location loc, StringRef comparison_direction, ArrayRef result_types, ArrayRef args, OpBuilder* b) { return nullptr; @@ -361,9 +274,9 @@ struct MapXlaCompareSelectOpToStdScalarOp { /// dialect with a given predicate based on the element type of the operand. template -struct MapXlaCompareSelectOpToStdScalarOp { - Value operator()(Location loc, StringRef comparison_direction, +struct XlaCompareSelectOpToStdScalarOp { + static Value map(Location loc, StringRef comparison_direction, ArrayRef result_types, ArrayRef args, OpBuilder* b) { Type element_type = args.front().getType(); @@ -374,117 +287,142 @@ struct MapXlaCompareSelectOpToStdScalarOpcreate<::mlir::SelectOp>(loc, cmp, args[0], args[1]); } - return MapXlaCompareSelectOpToStdScalarOp{}( + return XlaCompareSelectOpToStdScalarOp::map( loc, comparison_direction, result_types, args, b); } }; template <> -inline Value MapXlaOpToStdScalarOp(xla_lhlo::MaxOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return MapXlaCompareSelectOpToStdScalarOp< - IntegerType, ScalarIOp, CmpIPredicate, FloatType, - ScalarFOp, CmpFPredicate>{}(xla_op.getLoc(), "GT", - result_types, args, b); -} -template <> -inline Value MapXlaOpToStdScalarOp(xla_hlo::MaxOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return MapXlaCompareSelectOpToStdScalarOp< - IntegerType, ScalarIOp, CmpIPredicate, FloatType, - ScalarFOp, CmpFPredicate>{}(xla_op.getLoc(), "GT", - result_types, args, b); -} - -template <> -inline Value MapXlaOpToStdScalarOp(xla_lhlo::MinOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return MapXlaCompareSelectOpToStdScalarOp< - IntegerType, ScalarIOp, CmpIPredicate, FloatType, - ScalarFOp, CmpFPredicate>{}(xla_op.getLoc(), "LT", - result_types, args, b); -} -template <> -inline Value MapXlaOpToStdScalarOp(xla_hlo::MinOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return MapXlaCompareSelectOpToStdScalarOp< - IntegerType, ScalarIOp, CmpIPredicate, FloatType, - ScalarFOp, CmpFPredicate>{}(xla_op.getLoc(), "LT", - result_types, args, b); -} - -template <> -inline Value MapXlaOpToStdScalarOp(xla_lhlo::NegOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl{}( - xla_op.getLoc(), result_types, args, b); -} -template <> -inline Value MapXlaOpToStdScalarOp(xla_hlo::NegOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl{}( - xla_op.getLoc(), 
result_types, args, b); -} - -template <> -inline Value MapXlaOpToStdScalarOp( - xla_lhlo::SelectOp xla_op, ArrayRef result_types, - ArrayRef args, OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl<::mlir::SelectOp>{}(xla_op.getLoc(), - result_types, args, b); -} -template <> -inline Value MapXlaOpToStdScalarOp( - xla_hlo::SelectOp xla_op, ArrayRef result_types, ArrayRef args, +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl<::mlir::SelectOp>{}(xla_op.getLoc(), - result_types, args, b); + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); } template <> -inline Value MapXlaOpToStdScalarOp( - xla_lhlo::SignOp xla_op, ArrayRef result_types, ArrayRef args, +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return XlaCompareSelectOpToStdScalarOp< + IntegerType, ScalarIOp, CmpIPredicate, FloatType, + ScalarFOp, CmpFPredicate>::map(loc, "GT", + result_types, args, + b); +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return XlaCompareSelectOpToStdScalarOp< + IntegerType, ScalarIOp, CmpIPredicate, FloatType, + ScalarFOp, CmpFPredicate>::map(loc, "LT", + result_types, args, + b); +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl<::mlir::SelectOp>{}(loc, result_types, args, + b); +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { Type element_type = args.front().getType(); if (element_type.isa()) { FloatType float_type = element_type.cast(); APFloat const_value = float_type.isF32() ? APFloat(1.0f) : APFloat(1.0); - Value one = b->create(xla_op.getLoc(), const_value, - float_type); - return b->create<::mlir::CopySignOp>(xla_op.getLoc(), result_types, one, - args[0]); + Value one = b->create(loc, const_value, float_type); + return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]); } return nullptr; } template <> -inline Value MapXlaOpToStdScalarOp( - xla_lhlo::TanhOp xla_op, ArrayRef result_types, ArrayRef args, +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl{}( - xla_op.getLoc(), result_types, args, b); + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); } + template <> -inline Value MapXlaOpToStdScalarOp(xla_hlo::TanhOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl{}( - xla_op.getLoc(), result_types, args, b); +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); } +} // namespace impl + +struct XlaOpToStdScalarOp { + // Implementation for LHLO ops except xla_lhlo::CompareOp. 
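// How callers are expected to use this entry point (sketch with assumed
// variable names; it mirrors the linalg/affine lowering call sites updated
// later in this patch):
//
//   Value scalar = xla_lhlo::XlaOpToStdScalarOp::map(
//       op, body_result_types, body_args, &rewriter);
//
// The overloads below are disambiguated at compile time: LHLO ops resolve
// directly, while HLO ops first go through their LHLO counterpart. Note the
// differing `unsigned i = 0` / `int i = 0` dummy parameters, which keep the two
// enable_if'd templates from colliding as redefinitions.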
+ template ::value && + std::is_same, + std::false_type>::value>> + static Value map(XlaOpTy op, ArrayRef result_types, + ArrayRef args, OpBuilder* b, unsigned i = 0) { + return impl::MapLhloOpToStdScalarOp(op.getLoc(), result_types, + args, b); + } + + // Implementation for HLO ops except xla_hlo::CompareOp. + template , + typename = std::enable_if_t< + !std::is_same::value && + !std::is_same::value>> + static Value map(XlaOpTy op, ArrayRef result_types, + ArrayRef args, OpBuilder* b, int i = 0) { + return impl::MapLhloOpToStdScalarOp(op.getLoc(), result_types, + args, b); + } + + // Implementation for xla_lhlo::CompareOp. + template ::value>> + static Value map(xla_lhlo::CompareOp op, ArrayRef result_types, + ArrayRef args, OpBuilder* b) { + auto comparison_direction = op.comparison_direction(); + return impl::MapXlaCompareOpToStdScalarOp( + op.getLoc(), comparison_direction, result_types, args, b); + } + + // Implementation for xla_hlo::CompareOp. + template ::value>> + static Value map(xla_hlo::CompareOp op, ArrayRef result_types, + ArrayRef args, OpBuilder* b) { + auto comparison_direction = op.comparison_direction(); + return impl::MapXlaCompareOpToStdScalarOp( + op.getLoc(), comparison_direction, result_types, args, b); + } +}; + } // namespace xla_lhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/passes.h b/tensorflow/compiler/mlir/xla/transforms/passes.h index 8c0ed08fb66..b1afd543c2e 100644 --- a/tensorflow/compiler/mlir/xla/transforms/passes.h +++ b/tensorflow/compiler/mlir/xla/transforms/passes.h @@ -29,6 +29,7 @@ class ModuleOp; class Operation; template class OpPassBase; +class Pass; namespace xla_hlo { @@ -59,11 +60,6 @@ std::unique_ptr> createLegalizeToLhloPass(); // Lowers from HLO dialect to Linalg dialect. std::unique_ptr> createLegalizeHloToLinalgPass(); -// Removes unnecessary LHLO copies which copy from the allocated buffers to the -// block arguments. These copies have been created by replacing TensorStoreOp -// with LHLO.CopyOp in HLO to LHLO lowering. -std::unique_ptr> createLhloCopyRemovalPass(); - } // namespace xla_hlo namespace xla_lhlo { @@ -89,6 +85,15 @@ std::unique_ptr> createLegalizeToGpuPass(); std::unique_ptr> createLhloFuseLinalg( bool use_parallel_loops = false, ArrayRef tile_sizes = {}); +// Removes unnecessary LHLO copies which copy from the allocated buffers to the +// block arguments. The block arguments are used instead of all uses of these +// buffers. The buffers are freed. This pass only works in regions that contain +// a single block. +std::unique_ptr createLhloCopyRemovalPass(); + +// Lowers from LHLO dialect to parallel loops. +std::unique_ptr> createLegalizeLhloToParallelLoopsPass(); + } // namespace xla_lhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm.cc b/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm.cc index 6447c5d6c3f..071cc575656 100644 --- a/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm.cc +++ b/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/MLIRContext.h" // TF:llvm-project @@ -28,20 +30,47 @@ namespace xla_hlo { namespace { -// Broadcasts the 1D value tensor to rank. -Value broadcastToFeatureDim(Location loc, Type result_type, Value value_1d, +// Broadcasts the 1D value tensor 'value_1d' to the shape of 'result_type'. If +// 'shape_value' is initialized, creates a dynamic broadcast, otherwise creates +// a static broadcast. +Value BroadcastToFeatureDim(Location loc, RankedTensorType result_type, + Value value_1d, Value shape_value, int64_t feature_dim, - ConversionPatternRewriter& rewriter) { + ConversionPatternRewriter& rewriter) { // NOLINT Builder b(rewriter.getContext()); auto dims_type = RankedTensorType::get({1}, b.getIntegerType(64)); auto dims = DenseIntElementsAttr::get(dims_type, {feature_dim}); + if (shape_value) { + return rewriter.createOrFold( + loc, result_type, value_1d, shape_value, dims); + } + assert(result_type.hasStaticShape()); return rewriter.create(loc, result_type, value_1d, dims); } +// Calculate the shape value of operand, assuming it is a dynamic shape with +// static rank. +Value CalculateShapeValue(Location loc, Value operand, + ConversionPatternRewriter& rewriter) { // NOLINT + RankedTensorType result_type = operand.getType().dyn_cast(); + llvm::SmallVector shape_values; + int64_t rank = result_type.getRank(); + shape_values.reserve(rank); + for (int64_t i = 0; i < rank; ++i) { + auto index_value = rewriter.create(loc, operand, i); + shape_values.push_back(rewriter.create( + loc, index_value, rewriter.getIntegerType(32))); + } + Type shape_element_type = shape_values.front().getType(); + return rewriter.create( + loc, RankedTensorType::get({rank}, shape_element_type), shape_values); +} + Value MaterializeEpsilon(Operation* op, FloatAttr epsilon_attr, - FloatType fp_type, Type broadcast_to_type, - ConversionPatternRewriter& rewriter) { + FloatType fp_type, Value variance, + RankedTensorType broadcast_to_type, + ConversionPatternRewriter& rewriter) { // NOLINT Builder b(rewriter.getContext()); if (epsilon_attr.getType() != fp_type) { // Need to convert. @@ -66,9 +95,16 @@ Value MaterializeEpsilon(Operation* op, FloatAttr epsilon_attr, DenseElementsAttr::get(scalar_type, {epsilon_attr.cast()}); Value epsilon = rewriter.create(op->getLoc(), epsilon_tensor_attr); - epsilon = rewriter.create( - op->getLoc(), broadcast_to_type, epsilon, /*broadcast_dims=*/nullptr); - return epsilon; + auto dims_type = RankedTensorType::get({0}, b.getIntegerType(64)); + auto dims = DenseIntElementsAttr::get(dims_type, SmallVector{}); + if (broadcast_to_type.hasStaticShape()) { + return rewriter.create( + op->getLoc(), broadcast_to_type, epsilon, /*broadcast_dims=*/dims); + } + Value shape_value = CalculateShapeValue(op->getLoc(), variance, rewriter); + return rewriter.createOrFold( + op->getLoc(), broadcast_to_type, epsilon, shape_value, + /*broadcast_dims=*/dims); } class UnfuseBatchNormInferencePattern @@ -84,9 +120,10 @@ class UnfuseBatchNormInferencePattern // Enforce type invariants. // Note that we deduce the actual element type from the variance, // which should not be subject to quantization at a higher level. 
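// For reference, the unfused inference computation materialized below is, in
// scalar form (per the comments further down in this pattern):
//
//   output = scale * (input - mean) / sqrt(variance + epsilon) + offset
//
// where scale, offset, mean and variance are 1-D per-feature values broadcast
// along feature_dim, either statically or, for dynamic shapes, via the shape
// value produced by CalculateShapeValue.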
- auto input_type = operands.operand().getType(); - auto variance_type = operands.variance().getType().dyn_cast(); - if (!variance_type) { + auto input_type = operands.operand().getType().dyn_cast(); + auto variance_type = + operands.variance().getType().dyn_cast(); + if (!input_type || !variance_type) { return matchFailure(); } auto fp_type = variance_type.getElementType().dyn_cast(); @@ -97,8 +134,9 @@ class UnfuseBatchNormInferencePattern // Add epsilon to the variance and sqrt to get stddev: // stddev = sqrt(variance + epsilon) - auto epsilon = MaterializeEpsilon(bn_op.getOperation(), bn_op.epsilonAttr(), - fp_type, variance_type, rewriter); + auto epsilon = + MaterializeEpsilon(bn_op.getOperation(), bn_op.epsilonAttr(), fp_type, + operands.variance(), variance_type, rewriter); if (!epsilon) { return matchFailure(); } @@ -108,14 +146,22 @@ class UnfuseBatchNormInferencePattern stddev = rewriter.create(bn_op.getLoc(), stddev); // Broadcast all terms. - auto broadcast_scale = broadcastToFeatureDim( - bn_op.getLoc(), input_type, operands.scale(), feature_dim, rewriter); - auto broadcast_offset = broadcastToFeatureDim( - bn_op.getLoc(), input_type, operands.offset(), feature_dim, rewriter); - auto broadcast_mean = broadcastToFeatureDim( - bn_op.getLoc(), input_type, operands.mean(), feature_dim, rewriter); - auto broadcast_stddev = broadcastToFeatureDim( - bn_op.getLoc(), input_type, stddev, feature_dim, rewriter); + Value shape_value; + if (!input_type.hasStaticShape()) { + shape_value = + CalculateShapeValue(bn_op.getLoc(), operands.operand(), rewriter); + } + auto broadcast_scale = + BroadcastToFeatureDim(bn_op.getLoc(), input_type, operands.scale(), + shape_value, feature_dim, rewriter); + auto broadcast_offset = + BroadcastToFeatureDim(bn_op.getLoc(), input_type, operands.offset(), + shape_value, feature_dim, rewriter); + auto broadcast_mean = + BroadcastToFeatureDim(bn_op.getLoc(), input_type, operands.mean(), + shape_value, feature_dim, rewriter); + auto broadcast_stddev = BroadcastToFeatureDim( + bn_op.getLoc(), input_type, stddev, shape_value, feature_dim, rewriter); // Compute: // scale * (input - mean) / stddev + offset diff --git a/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm_pass.cc b/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm_pass.cc index 039d6ed45e2..ccec4d73b6e 100644 --- a/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm_pass.cc +++ b/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm_pass.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/MLIRContext.h" // TF:llvm-project #include "mlir/IR/Operation.h" // TF:llvm-project #include "mlir/IR/PatternMatch.h" // TF:llvm-project @@ -33,6 +34,7 @@ struct TestUnfuseBatchNormPass : public FunctionPass { // Consider the xla_hlo dialect legal for tests. 
conversionTarget.addLegalDialect(); + conversionTarget.addLegalDialect(); conversionTarget.addIllegalOp(); PopulateUnfuseBatchNormPatterns(&getContext(), &conversionPatterns); diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc index 7f7060fef64..0daec32fbab 100644 --- a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc +++ b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc @@ -149,8 +149,8 @@ class PointwiseToLinalgConverter : public OpConversionPattern { rewriter.setInsertionPointToEnd(block); // TODO(ravishankarm) : For now use the method in xla_lhlo namespace. That // method needs to be moved out of there. - Value opResult = xla_lhlo::MapXlaOpToStdScalarOp( - llvm::cast(op), bodyResultTypes, bodyArgs, &rewriter); + Value opResult = xla_lhlo::XlaOpToStdScalarOp::map( + op, bodyResultTypes, bodyArgs, &rewriter); if (!opResult) { return ConversionPattern::matchFailure(); } @@ -180,9 +180,9 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern { auto lhs = rewriter.create(loc, lhlo_op.lhs()); auto rhs = rewriter.create(loc, lhlo_op.rhs()); // TODO(ravishankarm) : Move this method out of xla_lhlo namespace. - Value opResult = xla_lhlo::MapXlaOpToStdScalarOp( - llvm::cast(lhlo_op), argType.getElementType(), - llvm::ArrayRef{lhs, rhs}, &rewriter); + Value opResult = xla_lhlo::XlaOpToStdScalarOp::map( + lhlo_op, argType.getElementType(), llvm::ArrayRef{lhs, rhs}, + &rewriter); rewriter.create(loc, opResult, lhlo_op.out()); rewriter.eraseOp(lhlo_op); return ConversionPattern::matchSuccess(); @@ -208,9 +208,6 @@ class DataMovementOpConverter : public OpConversionPattern { auto resultType = getXLAOpResultType(op); if (!verifyXLAOpBufferOrTensorSemantics(op)) return ConversionPattern::matchFailure(); - // TODO(b/150203558) Enable once tiling/fusion works in this case. - if (isLHLO && (operandType.getRank() == 0)) - return ConversionPattern::matchFailure(); ArrayAttr indexingMapsAttr = static_cast(*this).getIndexingMapsAttr(op, &rewriter); if (!indexingMapsAttr) return ConversionPattern::matchFailure(); @@ -253,14 +250,13 @@ class BroadcastInDimConverter auto operandShape = operandType.getShape(); SmallVector dimExprs; + AffineMap inputMap = AffineMap::get(b->getContext()); { dimExprs.reserve(nloops); if (broadcastOp.broadcast_dimensions()) { for (const auto& broadcastDim : - enumerate(broadcastOp.broadcast_dimensions() - .getValue() - .getIntValues())) { + enumerate(broadcastOp.broadcast_dimensions().getIntValues())) { int size = broadcastDim.value().getSExtValue(); // TODO(pifon): Add support for args with dynamic shapes for the case // when a dimension of size 1 is broadcasted into dim of size N. @@ -272,58 +268,13 @@ class BroadcastInDimConverter } if (dimExprs.empty()) { // The input is a scalar, i.e. this is a scalar broadcast op. - dimExprs.push_back(b->getAffineConstantExpr(0)); + inputMap = AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()); + } else { + inputMap = AffineMap::get(nloops, /*symbolCount=*/0, dimExprs); } } return b->getAffineMapArrayAttr( - {AffineMap::get(nloops, /*symbolCount=*/0, dimExprs), - b->getMultiDimIdentityMap(nloops)}); - } -}; - -// Special case for scalar broadcast in lhlo. -// TODO(b/150203558) Remove once the bug is fixed. 
-class ScalarBroadcastInDimConverter - : public OpConversionPattern { - public: - using OpConversionPattern::OpConversionPattern; - - PatternMatchResult matchAndRewrite( - xla_lhlo::BroadcastInDimOp broadcastOp, ArrayRef args, - ConversionPatternRewriter& rewriter) const final { - auto operandMemrefType = - broadcastOp.operand().getType().dyn_cast(); - // Only support scalar operands. - if (operandMemrefType.getRank() != 0) return matchFailure(); - auto resultMemrefType = - broadcastOp.output().getType().dyn_cast(); - if (!operandMemrefType || !resultMemrefType) return matchFailure(); - auto broadcastDims = broadcastOp.broadcast_dimensions(); - if (!broadcastDims.hasValue()) return matchFailure(); - - unsigned nloops = resultMemrefType.getRank(); - SmallVector indexingMaps{ - AffineMapAttr::get(rewriter.getMultiDimIdentityMap(nloops))}; - auto loc = broadcastOp.getLoc(); - auto linalgOp = rewriter.create( - loc, ArrayRef{}, broadcastOp.output(), - rewriter.getI64IntegerAttr(0), // args_in - rewriter.getI64IntegerAttr(1), // args_out - rewriter.getArrayAttr(indexingMaps), - GetNParallelLoopsAttrs(nloops, &rewriter), - /*doc=*/nullptr, /*fun=*/nullptr, /*library_call=*/nullptr); - - // Add a block to the region. - auto* region = &linalgOp.region(); - auto* block = rewriter.createBlock(region, region->end()); - block->addArguments(resultMemrefType.getElementType()); - - rewriter.setInsertionPointToEnd(block); - auto scalar = - rewriter.create(loc, broadcastOp.operand(), llvm::None); - rewriter.create(loc, scalar.getResult()); - rewriter.eraseOp(broadcastOp); - return matchSuccess(); + {inputMap, b->getMultiDimIdentityMap(nloops)}); } }; @@ -537,21 +488,24 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + // TODO(ataei): Remove this pattern, CopyOp is folded away. 
PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, ReshapeAddRemoveDimConverter, - ScalarBroadcastInDimConverter, ScalarPointwiseToStandardConverter, SliceConverter >(context); @@ -632,12 +586,15 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter>(context); } diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index f3ee4e38f31..77cd3dc074c 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -338,6 +338,21 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "searchsorted_op_test", + size = "small", + timeout = "moderate", + srcs = ["searchsorted_op_test.py"], + python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], + deps = [ + ":xla_test", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "svd_op_test", size = "medium", diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index f42d51dbb3a..8543e8ea2be 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import itertools +import os import numpy as np @@ -1600,4 +1601,8 @@ class BinaryOpsTest(xla_test.XLATestCase): if __name__ == "__main__": + # TODO(b/130689556): XLA CPU does not honor inf/nan which causes problems + os.environ[ + "XLA_FLAGS"] = "--xla_cpu_enable_fast_math=false " + os.environ.get( + "XLA_FLAGS", "") googletest.main() diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index a2da9815b18..31eb514b14c 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -514,6 +514,13 @@ class ResizeNearestNeighborTest(xla_test.XLATestCase): [7, 7, 7, 8, 8, 8, 8, 8, 8, 9, 9, 9]], dtype=np.float32)) + def testBFloat16(self): + img = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], + dtype=dtypes.bfloat16.as_numpy_dtype) + self._assertForwardOpMatchesExpected(img, [4, 4], expected=np.array( + [[1, 2, 2, 3], [4, 5, 5, 6], [4, 5, 5, 6], [7, 8, 8, 9]], + dtype=np.float32)) + def testAlignCorners3x3To12x12_uint8(self): # TODO(b/72099414): enable the test for TPU when the issue is fixed. if (self.device not in ["XLA_GPU", "XLA_CPU"]): @@ -590,12 +597,14 @@ class ResizeBilinearTest(parameterized.TestCase, xla_test.XLATestCase): ("256x256To299x299", 256, 256, 299, 299), ("512x512To299x299", 512, 512, 299, 299), ("224x224To224x224", 224, 224, 224, 224), + ("224x224To224x224-bfloat", 224, 224, 224, 224, + dtypes.bfloat16.as_numpy_dtype), # This test is disabled because it is very slow. 
It is slow because # 383 is prime, 383 and 2047 are coprime, and 2048 is large. # ("Disabled_384x72To2048x384", 384, 72, 2048, 384), ) - def test(self, src_y, src_x, dst_y, dst_x): + def test(self, src_y, src_x, dst_y, dst_x, dtype=np.float32): if test.is_built_with_rocm(): self.skipTest("Disabled on ROCm, because it runs out of memory") @@ -613,7 +622,7 @@ class ResizeBilinearTest(parameterized.TestCase, xla_test.XLATestCase): ] self._assertForwardOpMatchesExpected( - np.array(input_data, dtype=np.float32), [dst_y, dst_x], + np.array(input_data, dtype=dtype), [dst_y, dst_x], expected=np.array(result, dtype=np.float32), large_tolerance=True) diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index dfa5bc106ed..8bad4da0524 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -54,6 +54,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/framework/kernel_shape_util.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tests/searchsorted_op_test.py b/tensorflow/compiler/tests/searchsorted_op_test.py new file mode 100644 index 00000000000..d77bd0902d3 --- /dev/null +++ b/tensorflow/compiler/tests/searchsorted_op_test.py @@ -0,0 +1,75 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test for XLA implementation of tf.searchsorted.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class SearchSorteddOpTest(xla_test.XLATestCase): + + def test1D(self): + # Test against NumPy implementation (which is 1D only). 
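    # Illustrative reference values (not part of the test): with a sorted
    # sequence [1, 3, 5] and query value 3, np.searchsorted returns 1 for
    # side='left' and 2 for side='right'; the XLA kernel under test is expected
    # to match np.searchsorted on the random inputs generated below.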
+ np.random.seed(1) + for side in ['left', 'right']: + for dtype in [np.float32, np.int32]: + values = np.random.uniform( + low=-1000, high=1000, size=(10,)).astype(dtype) + unsorted = np.random.uniform( + low=-1000, high=1000, size=(20,)).astype(dtype) + + sorted_sequence = np.sort(unsorted) + np_ans = np.searchsorted(sorted_sequence, values, side=side) + + with self.session() as session: + with self.test_scope(): + tf_ans = array_ops.searchsorted(sorted_sequence, values, side=side) + tf_out = session.run(tf_ans) + self.assertAllEqual(np_ans, tf_out) + + def _test2DExample(self, dtype, side, sorted_sequence, values, correct_ans): + + with self.session() as session: + with self.test_scope(): + tf_ans = array_ops.searchsorted(sorted_sequence, values, side=side) + tf_out = session.run(tf_ans) + self.assertAllEqual(correct_ans, tf_out) + + def testLowerBound2DExample(self): + # 2D TensorFlow documentation example. + for dtype in self.float_types | self.int_types: + sorted_sequence = np.array([[0, 3, 9, 9, 10], [1, 2, 3, 4, 5]], dtype) + values = np.array([[2, 4, 9], [0, 2, 6]], dtype) + correct_ans = np.array([[1, 2, 2], [0, 1, 5]], dtype) + self._test2DExample(dtype, 'left', sorted_sequence, values, correct_ans) + + def testUpperBound2DExample(self): + # 2D TensorFlow documentation example. + for dtype in self.float_types | self.int_types: + sorted_sequence = np.array([[0, 3, 9, 9, 10], [1, 2, 3, 4, 5]], dtype) + values = np.array([[2, 4, 9], [0, 2, 6]], dtype) + correct_ans = np.array([[1, 2, 4], [0, 2, 5]], dtype) + self._test2DExample(dtype, 'right', sorted_sequence, values, correct_ans) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/compiler/tests/while_test.py b/tensorflow/compiler/tests/while_test.py index 420dc04bec3..f1f8b6c353c 100644 --- a/tensorflow/compiler/tests/while_test.py +++ b/tensorflow/compiler/tests/while_test.py @@ -240,6 +240,22 @@ class WhileTest(xla_test.XLATestCase): self.assertAllEqual(r, np.array([(x + 3) * 2 for x in nums])) xla_context.Exit() + @test_util.enable_control_flow_v2 + def testMapBackPropFalse(self): + if is_compile_on_demand(): + self.skipTest("list_ops are not supported in cpu_ondemand") + with self.session(), self.test_scope(): + xla_context = control_flow_ops.XLAControlFlowContext() + xla_context.Enter() + nums = [1, 2, 3, 4, 5, 6] + elems = constant_op.constant(nums, name="data") + r = map_fn.map_fn( + lambda x: math_ops.multiply(math_ops.add(x, 3), 2), + elems, + back_prop=False) + self.assertAllEqual(r, np.array([(x + 3) * 2 for x in nums])) + xla_context.Exit() + def is_compile_on_demand(): return ("TF_XLA_FLAGS" in os.environ and diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index b26b509b067..371a5804008 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -13,9 +13,15 @@ load( "tf_gen_op_wrapper_py", "tf_gpu_kernel_library", ) + +# buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") + +# buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") -load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "pybind_extension") load( "//tensorflow/core/platform:build_config.bzl", "tf_additional_all_protos", @@ -235,8 +241,8 @@ tf_custom_op_py_library( name = "trt_ops_loader", srcs_version = "PY2AND3", deps = [ + ":_pywrap_py_utils", ":trt_ops", - ":wrap_py_utils", 
"//tensorflow/python:errors", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform", @@ -547,12 +553,16 @@ cc_library( ]), ) -tf_py_wrap_cc( - name = "wrap_py_utils", - srcs = ["utils/py_utils.i"], - copts = tf_copts(), +pybind_extension( + name = "_pywrap_py_utils", + srcs = ["utils/py_utils_wrapper.cc"], + link_in_framework = True, + module_name = "_pywrap_py_utils", deps = [ ":py_utils", - "//third_party/python_runtime:headers", + "//tensorflow/core/platform:env", + "//tensorflow/core/platform:logging", + "//tensorflow/core/platform:status", + "@pybind11", ], ) diff --git a/tensorflow/compiler/tf2tensorrt/convert/logger_registry.cc b/tensorflow/compiler/tf2tensorrt/convert/logger_registry.cc index 50d0ae8c000..82e68cbb28d 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/logger_registry.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/logger_registry.cc @@ -47,7 +47,7 @@ class LoggerRegistryImpl : public LoggerRegistry { private: mutable mutex mu_; mutable std::unordered_map> - registry_ GUARDED_BY(mu_); + registry_ TF_GUARDED_BY(mu_); }; LoggerRegistry* GetLoggerRegistry() { diff --git a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc index 7995163ed44..d9d8a4461a3 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc @@ -44,6 +44,7 @@ Status TRTOptimizationPass::Init( if (config == nullptr) { return Status::OK(); } + VLOG(1) << "config = " << config->DebugString(); const auto params = config->parameter_map(); if (params.count("minimum_segment_size")) { minimum_segment_size_ = params.at("minimum_segment_size").i(); diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc index 05e6575dc1c..a0524f4a90e 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc @@ -123,7 +123,15 @@ class TRTEngineOp : public AsyncOpKernel { // input and 2) The index of the IExecutionContext compatible with the input. StatusOr> GetEngine( const std::vector& input_concrete_shapes, - OpKernelContext* ctx, TRTEngineCacheResource* cache_res); + OpKernelContext* ctx, TRTEngineCacheResource* cache_resource); + + // Builds and returns a cuda engine for the input shapes. If building the + // engine fails, enters a dummy entry into the cache_resource cache so we + // don't continually try to build the same failing engine. + StatusOr> BuildEngine( + const std::vector& input_concrete_shapes, int batch_size, + bool use_calibration, TRTInt8Calibrator* calibrator, + TRTEngineCacheResource* cache_resource); // Verify that the input shapes are consistent and can be handled by this op. Status VerifyInputShapes(const std::vector& shapes); @@ -881,6 +889,40 @@ Status TRTEngineOp::GetEngineCacheResource(OpKernelContext* ctx, }}); } +StatusOr> TRTEngineOp::BuildEngine( + const std::vector& input_concrete_shapes, int batch_size, + bool use_calibration, TRTInt8Calibrator* calibrator, + TRTEngineCacheResource* cache_resource) { + VLOG(1) << "Building a new TensorRT engine for " << name() + << " with input shapes: " + << TensorShapeUtils::ShapeListString(input_concrete_shapes); + + // Use concrete shapes for implicit batch mode and partial shapes for + // explicit batch mode. + const std::vector& conversion_input_shapes = + use_implicit_batch_ + ? 
std::vector(input_concrete_shapes.begin(), + input_concrete_shapes.end()) + : input_partial_shapes_; + TrtUniquePtrType engine; + auto status = convert::ConvertGraphDefToEngine( + segment_graph_def_, precision_mode_, batch_size, workspace_size_, + conversion_input_shapes, &logger, cache_resource->allocator_.get(), + calibrator, &engine, use_calibration, use_implicit_batch_, nullptr, + &cache_resource->profiles_); + if (!status.ok()) { + LOG(WARNING) << "Engine creation for " << name() << " failed. " + << "The native segment will be used instead. " + << "Reason: " << status; + // Store an empty engine in the cache for these input shapes so we don't try + // to build the same failing engine again. + cache_resource->cache_.emplace(input_concrete_shapes, + absl::make_unique()); + return status; + } + return engine; +} + StatusOr> TRTEngineOp::GetEngine( const std::vector& input_concrete_shapes, OpKernelContext* ctx, TRTEngineCacheResource* cache_res) { @@ -918,7 +960,32 @@ StatusOr> TRTEngineOp::GetEngine( infer->deserializeCudaEngine(serialized_segment_.c_str(), serialized_segment_.size(), nullptr)); if (!static_engine) { - return std::pair(&empty_context, 0); + if (!allow_build_at_runtime_) { + // Store an empty engine in the cache so we don't try to load the same + // failing engine again. + cache.emplace(input_concrete_shapes, + absl::make_unique()); + return std::pair(&empty_context, 0); + } + if (segment_graph_def_.node().empty()) { + FunctionLibraryRuntime* lib = ctx->function_library(); + auto status = ConstructFunctionHandle(lib, ctx->device()->name()); + if (status.ok()) { + status = + FunctionDefToGraphDef(func_handle_, lib, &segment_graph_def_); + } + if (!status.ok()) { + LOG(WARNING) << "Getting segment graph for " << name() << " failed. " + << "Reason: " << status; + } + } + auto result = BuildEngine(input_concrete_shapes, batch_size, + /*use_calibration=*/false, + /*calibrator=*/nullptr, cache_res); + if (!result.ok()) { + return std::pair(&empty_context, 0); + } + static_engine = std::move(result.ValueOrDie()); } auto raw_static_engine = static_engine.get(); const auto max_batch_size = raw_static_engine->getMaxBatchSize(); @@ -977,36 +1044,16 @@ StatusOr> TRTEngineOp::GetEngine( cache.emplace(input_concrete_shapes, absl::make_unique()); return std::pair(&empty_context, 0); } - TrtUniquePtrType engine; - bool convert_successfully = false; - LOG(INFO) << "Building a new TensorRT engine for " << name() - << " with input shapes: " - << TensorShapeUtils::ShapeListString(input_concrete_shapes); - - // Use concrete shapes for implicit batch mode and partial shapes for - // explicit batch mode. - const std::vector& conversion_input_shapes = - use_implicit_batch_ - ? std::vector(input_concrete_shapes.begin(), - input_concrete_shapes.end()) - : input_partial_shapes_; // Up to this point, calibrator_ can never be empty, since otherwise it // means calibration_mode_ is true and this path won't get executed. - auto status = convert::ConvertGraphDefToEngine( - segment_graph_def_, precision_mode_, batch_size, workspace_size_, - conversion_input_shapes, &logger, allocator, calibrator_.get(), &engine, - use_calibration_, use_implicit_batch_, &convert_successfully, - &cache_res->profiles_); - if (!status.ok()) { - LOG(WARNING) << "Engine creation for " << name() << " failed. " - << "The native segment will be used instead. " - << "Reason: " << status; - // Store an empty engine in the cache for these input shapes so we don't - // try to build the same failing engine again. 
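The new BuildEngine helper above also captures a simple negative-caching idea: when TensorRT cannot build an engine for a given set of input shapes, an empty EngineContext is stored under those shapes so later calls skip the rebuild and fall back to the native segment. A minimal Python sketch of that pattern, with hypothetical names (get_or_build_engine, build_fn) rather than the actual TensorRT code:

def get_or_build_engine(cache, input_shapes, build_fn):
    # cache maps shape tuples to a built engine, or to None for a known failure.
    key = tuple(input_shapes)
    if key not in cache:
        try:
            cache[key] = build_fn(input_shapes)
        except RuntimeError:
            # Negative entry: don't retry this shape, use the native segment.
            cache[key] = None
    return cache[key]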
- cache.emplace(input_concrete_shapes, absl::make_unique()); + auto result = BuildEngine(input_concrete_shapes, batch_size, + use_calibration_, calibrator_.get(), cache_res); + if (!result.ok()) { return std::pair(&empty_context, 0); } + TrtUniquePtrType engine = + std::move(result.ValueOrDie()); std::vector> exec_context; TF_RETURN_IF_ERROR(cache_res->profiles_.CreateExecutionContexts( engine.get(), exec_context)); diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc index 8ef72ba44d5..2c5821df6ac 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc @@ -71,7 +71,7 @@ class CreateTRTResourceHandle : public OpKernel { string resource_name_; Tensor handle_; mutex mutex_; - bool initialized_ GUARDED_BY(mutex_) = false; + bool initialized_ TF_GUARDED_BY(mutex_) = false; TF_DISALLOW_COPY_AND_ASSIGN(CreateTRTResourceHandle); }; diff --git a/tensorflow/compiler/tf2tensorrt/utils/py_utils.i b/tensorflow/compiler/tf2tensorrt/utils/py_utils.i deleted file mode 100644 index d6e8eac5836..00000000000 --- a/tensorflow/compiler/tf2tensorrt/utils/py_utils.i +++ /dev/null @@ -1,86 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -/* Wrap trt_conversion */ -%{ -#define SWIG_FILE_WITH_INIT -%} - -%{ -struct version_struct{ - int vmajor; - int vminor; - int vpatch; -}; - -PyObject* version_helper(version_struct* in) { - PyObject *tuple(nullptr); - tuple = Py_BuildValue("(iii)", in->vmajor, in->vminor, in->vpatch); - if (!tuple) { - if (!PyErr_Occurred()) { - PyErr_SetString(PyExc_TypeError, - "Tuple creation from version structure failed!"); - } - return NULL; - } - return tuple; -} - -%} - -%typemap(out) version_struct { - PyObject *tuple = version_helper(&$1); - if (!tuple) SWIG_fail; - $result = tuple; -} - -%{ -#include "tensorflow/compiler/tf2tensorrt/utils/py_utils.h" -%} - -%ignore ""; -%rename("%s") get_linked_tensorrt_version; -%rename("%s") get_loaded_tensorrt_version; -%rename("%s") is_tensorrt_enabled; - -%{ - -version_struct get_linked_tensorrt_version() { - // Return the version at the link time. - version_struct s; - tensorflow::tensorrt::GetLinkedTensorRTVersion( - &s.vmajor, &s.vminor, &s.vpatch); - return s; -} - -version_struct get_loaded_tensorrt_version() { - // Return the version from the loaded library. 
- version_struct s; - tensorflow::tensorrt::GetLoadedTensorRTVersion( - &s.vmajor, &s.vminor, &s.vpatch); - return s; -} - -bool is_tensorrt_enabled() { - return tensorflow::tensorrt::IsGoogleTensorRTEnabled(); -} - -%} - -version_struct get_linked_tensorrt_version(); -version_struct get_loaded_tensorrt_version(); -bool is_tensorrt_enabled(); - -%rename("%s") ""; diff --git a/tensorflow/compiler/tf2tensorrt/utils/py_utils_wrapper.cc b/tensorflow/compiler/tf2tensorrt/utils/py_utils_wrapper.cc new file mode 100644 index 00000000000..0d7819931b1 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/utils/py_utils_wrapper.cc @@ -0,0 +1,43 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "include/pybind11/pybind11.h" +#include "tensorflow/compiler/tf2tensorrt/utils/py_utils.h" + +std::tuple get_linked_tensorrt_version() { + int major, minor, patch; + tensorflow::tensorrt::GetLinkedTensorRTVersion(&major, &minor, &patch); + return std::tuple{major, minor, patch}; +} + +std::tuple get_loaded_tensorrt_version() { + int major, minor, patch; + tensorflow::tensorrt::GetLoadedTensorRTVersion(&major, &minor, &patch); + return std::tuple{major, minor, patch}; +} + +PYBIND11_MODULE(_pywrap_py_utils, m) { + m.doc() = "_pywrap_py_utils: Various TensorRT utilities"; + m.def("get_linked_tensorrt_version", get_linked_tensorrt_version, + "Return the compile time TensorRT library version as the tuple " + "(Major, Minor, Patch)."); + m.def("get_loaded_tensorrt_version", get_loaded_tensorrt_version, + "Return the runtime time TensorRT library version as the tuple " + "(Major, Minor, Patch)."); + m.def("is_tensorrt_enabled", tensorflow::tensorrt::IsGoogleTensorRTEnabled, + "Returns True if TensorRT is enabled."); +} diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h index 97995fa186a..8e345254f75 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h @@ -136,7 +136,7 @@ struct EngineContext { TrtUniquePtrType cuda_engine; Status GetExecutionContext(int idx, nvinfer1::IExecutionContext** exec_ctx) - EXCLUSIVE_LOCKS_REQUIRED(mu) { + TF_EXCLUSIVE_LOCKS_REQUIRED(mu) { if (idx >= execution_context.size()) { return errors::Internal("Requested engine context with index ", idx, ", but only ", execution_context.size(), @@ -152,7 +152,7 @@ struct EngineContext { // for inference at a time therefore we need a mutex. More details at // https://docs.nvidia.com/deeplearning/sdk/tensorrt-best-practices/index.html#thread-safety std::vector> execution_context - GUARDED_BY(mu); + TF_GUARDED_BY(mu); }; // Contains the context required to build the calibration data. 
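A rough usage sketch for the pybind11 module defined in py_utils_wrapper.cc above. The import path below is an assumption for illustration; the extension may live elsewhere in the installed package:

# Assumed import location of the new extension module.
from tensorflow.compiler.tf2tensorrt import _pywrap_py_utils as trt_utils

if trt_utils.is_tensorrt_enabled():
    linked = trt_utils.get_linked_tensorrt_version()  # (major, minor, patch) at build time
    loaded = trt_utils.get_loaded_tensorrt_version()  # version of the library found at runtime
    if linked[0] != loaded[0]:
        print('TensorRT major version mismatch:', linked, 'vs', loaded)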
@@ -174,8 +174,8 @@ class CalibrationContext { private: mutex mu_; - bool terminated_ GUARDED_BY(mu_) = false; - std::string calibration_table_ GUARDED_BY(mu_); + bool terminated_ TF_GUARDED_BY(mu_) = false; + std::string calibration_table_ TF_GUARDED_BY(mu_); }; ABSL_CONST_INIT extern const absl::string_view kTfTrtContainerName; diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 30bd1eff8eb..a6f88df7e40 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -5,6 +5,7 @@ load( ) load( "//tensorflow/core/platform:build_config.bzl", + "tf_additional_tensor_coding_deps", "tf_proto_library", "tf_proto_library_cc", ) @@ -37,6 +38,7 @@ package_group( "//learning/brain/tools/tf_replay/...", "//tensorflow/...", "//tensorflow_models/...", + "//third_party/mlperf/submissions/training/v0_7/models/...", ], ) @@ -176,8 +178,8 @@ cc_library( "//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -236,18 +238,20 @@ cc_library( features = ["fully_static_link"], linkstatic = 1, visibility = [":friends"], - # Note, we specifically remove MKL dependencies so the standalone does - # not require the MKL binary blob. + # Note, we specifically removed MKL and multithreaded dependencies so the + # standalone does not require the MKL binary blob or threading libraries. + # + # TODO(ebrevdo): Remove tf_additional_tensor_coding_deps in favor of + # absl/strings:cord when we update absl to a newer version. deps = [ - "//tensorflow/core/framework:numeric_types", - "//third_party/eigen3", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:dynamic_annotations", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/synchronization", - ], + "//third_party/eigen3", + "//tensorflow/core/framework:numeric_types", + ] + tf_additional_tensor_coding_deps(), alwayslink = 1, ) @@ -692,6 +696,7 @@ cc_library( srcs = ["mlir_bridge_pass.cc"], hdrs = ["mlir_bridge_pass.h"], deps = [ + "//tensorflow/compiler/mlir:mlir_graph_optimization_pass", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", "//tensorflow/compiler/mlir/tensorflow:device_util", @@ -711,6 +716,7 @@ cc_library( ], deps = [ ":mlir_bridge_pass", + "//tensorflow/compiler/mlir:mlir_graph_optimization_pass_registration", "//tensorflow/core:core_cpu", ], alwayslink = 1, diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index f0aebc9b543..eadd05fcee0 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -54,6 +54,7 @@ namespace tensorflow { namespace { Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, const std::vector& expressions, + const NameAttrList& func, std::vector* args) { auto client = ctx->compiler()->client(); std::vector arg_must_be_compile_time_constant(expressions.size()); @@ -78,9 +79,10 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, TF_ASSIGN_OR_RETURN(absl::optional value, expressions[i]->ResolveConstant(client)); if (!value.has_value()) { - return errors::InvalidArgument( - "Argument to function must be a compile-time constant, but " - "unable to resolve argument value to a constant."); + return
errors::InvalidArgument(absl::StrCat( + "Argument ", i, " to function '", func.name(), + "' must be a compile-time constant, but ", + "unable to resolve argument value to a constant.")); } arg.kind = XlaCompiler::Argument::kConstant; arg.constant_value = *value; @@ -249,8 +251,8 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, auto graph = compiler->GetGraph(fbody); - TF_RETURN_IF_ERROR( - PrepareArguments(&xla_op_context, graph.get(), expressions, &arguments)); + TF_RETURN_IF_ERROR(PrepareArguments(&xla_op_context, graph.get(), expressions, + func, &arguments)); bool add_token_input_output = func.attr().find(kXlaTokenInputNodesAttrName) != func.attr().end(); diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 8571c503299..5f1c2f28ba4 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -55,6 +55,7 @@ tf_kernel_library( "index_ops.cc", "l2loss_op.cc", "listdiff_op.cc", + "lower_upper_bound_ops.cc", "lrn_ops.cc", "matmul_op.cc", "matrix_band_part_op.cc", @@ -149,6 +150,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:comparison_util", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc index 9f0ec65bb71..b60a13972a7 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -29,10 +29,10 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/kernel_shape_util.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/ops_util.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc index ba11b12fa2a..63e3f185421 100644 --- a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/framework/kernel_shape_util.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/util/tensor_format.h" diff --git a/tensorflow/compiler/tf2xla/kernels/identity_op.cc b/tensorflow/compiler/tf2xla/kernels/identity_op.cc index 8b27e8e85a3..38d8056d3e5 100644 --- a/tensorflow/compiler/tf2xla/kernels/identity_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/identity_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -25,10 +26,15 @@ class IdentityOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { for (int i = 0; i < ctx->num_inputs(); ++i) { - // Forwards using the underlying op_kernel_context so both tensor and - // resource values are forwarded correctly. - ctx->op_kernel_context()->set_output(i, - ctx->op_kernel_context()->input(i)); + if (IsTensorListInput(ctx, i)) { + ctx->SetTensorListOutput(i, ctx->Input(i)); + } else { + DCHECK(ctx->input_type(i) != DT_VARIANT); + // Forwards using the underlying op_kernel_context so both tensor and + // resource values are forwarded correctly. + ctx->op_kernel_context()->set_output( + i, ctx->op_kernel_context()->input(i)); + } } } @@ -48,7 +54,7 @@ REGISTER_XLA_OP(Name("IdentityN") IdentityOp); REGISTER_XLA_OP(Name("PlaceholderWithDefault"), IdentityOp); REGISTER_XLA_OP(Name("PreventGradient"), IdentityOp); -REGISTER_XLA_OP(Name("StopGradient"), IdentityOp); +REGISTER_XLA_OP(Name("StopGradient").AllowVariantTypes(), IdentityOp); REGISTER_XLA_OP(Name("Snapshot"), IdentityOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/lower_upper_bound_ops.cc b/tensorflow/compiler/tf2xla/kernels/lower_upper_bound_ops.cc new file mode 100644 index 00000000000..0eacf8812f1 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/lower_upper_bound_ops.cc @@ -0,0 +1,116 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/comparison_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" + +namespace tensorflow { +namespace { + +// Builds a LowerBound or UpperBound op, the distinction lying in +// comparison_direction: GT => LowerBoundOp, GE => UpperBoundOp. +// Note that this is an O(MN) algorithm: all entries in each sorted_inputs row +// are considered, and their sorted nature is not fully exploited. 
+void BuildLowerUpperBoundOp(XlaOpKernelContext* ctx, DataType out_dtype, + xla::ComparisonDirection comparison_direction) { + const TensorShape sorted_inputs_shape = ctx->InputShape("sorted_inputs"); + const TensorShape values_shape = ctx->InputShape("values"); + const xla::XlaOp sorted_inputs = ctx->Input("sorted_inputs"); + const xla::XlaOp values = ctx->Input("values"); + + // We are assuming both inputs are 2D, which they will be given the current + // implementation of tf.searchsorted. + OP_REQUIRES(ctx, sorted_inputs_shape.dims() == 2, + errors::FailedPrecondition("sorted_inputs must be 2D")); + OP_REQUIRES(ctx, values_shape.dims() == 2, + errors::FailedPrecondition("values must be 2D")); + + // Add a new inner dimension to values, to allow broadcasting along the inner + // dimension of sorted_sequence. + auto new_values_shape = values_shape; + new_values_shape.InsertDim(/* d */ 2, /* size */ 1); + auto values_reshaped = xla::Reshape(values, new_values_shape.dim_sizes()); + + // Add a new penultimate dimension to sorted_inputs, to allow broadcasting of + // sorted_sequence entries for each value. + auto new_sorted_inputs_shape = sorted_inputs_shape; + new_sorted_inputs_shape.InsertDim(/* d */ 1, /* size */ 1); + auto sorted_inputs_reshaped = + xla::Reshape(sorted_inputs, new_sorted_inputs_shape.dim_sizes()); + + // We are relying on broadcasting to compare each value against each entry in + // the associated sorted_inputs row. + // The reshapes above leave the tensors with equal rank of 3, so broadcast + // dimensions are not explicitly specified. + auto comparison = xla::Compare(values_reshaped, sorted_inputs_reshaped, {}, + comparison_direction); + + const DataType accumulation_type = XlaHelpers::SumAccumulationType(out_dtype); + + // Convert boolean comparison results to integers so we can sum them. + auto comparison_int = + XlaHelpers::ConvertElementType(comparison, accumulation_type); + + // Sum the comparison results over the inner dimension to find the index for + // each value. + xla::XlaBuilder* builder = ctx->builder(); + auto reduced = + xla::Reduce(comparison_int, XlaHelpers::Zero(builder, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), {2}); + + ctx->SetOutput(0, reduced); +} + +class LowerBoundOp : public XlaOpKernel { + public: + explicit LowerBoundOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + BuildLowerUpperBoundOp(ctx, out_dtype_, xla::ComparisonDirection::kGt); + } + + private: + DataType out_dtype_; +}; + +REGISTER_XLA_OP(Name("LowerBound"), LowerBoundOp); + +class UpperBoundOp : public XlaOpKernel { + public: + explicit UpperBoundOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + BuildLowerUpperBoundOp(ctx, out_dtype_, xla::ComparisonDirection::kGe); + } + + private: + DataType out_dtype_; +}; + +REGISTER_XLA_OP(Name("UpperBound"), UpperBoundOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc index 7ac4cb8fb06..6d0d569724f 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc @@ -17,125 +17,32 @@ limitations under the License. 
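The comment on BuildLowerUpperBoundOp above describes the O(MN) scheme: broadcast-compare every value against every entry of its sorted_inputs row, then sum the boolean results over the inner dimension. A NumPy sketch of the same computation, for illustration of the semantics only (not the XLA lowering):

import numpy as np

def lower_upper_bound(sorted_inputs, values, upper=False):
    # sorted_inputs: [batch, n], values: [batch, m]; result: [batch, m].
    v = values[:, :, np.newaxis]          # [batch, m, 1]
    s = sorted_inputs[:, np.newaxis, :]   # [batch, 1, n]
    cmp = (v >= s) if upper else (v > s)  # GE -> UpperBound, GT -> LowerBound
    return cmp.sum(axis=2)                # index = count of entries compared true

sorted_inputs = np.array([[0, 3, 9, 9, 10], [1, 2, 3, 4, 5]])
values = np.array([[2, 4, 9], [0, 2, 6]])
print(lower_upper_bound(sorted_inputs, values))              # [[1 2 2], [0 1 5]]
print(lower_upper_bound(sorted_inputs, values, upper=True))  # [[1 2 4], [0 2 5]]

These are the same expected values as the 2D documentation examples in searchsorted_op_test.py above.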
#include -#include "absl/container/flat_hash_set.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/raw_os_ostream.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/public/session_options.h" namespace tensorflow { -// Dumps the MLIR module to disk. -// This require the TF_DUMP_GRAPH_PREFIX to be set to a path that exist (or can -// be created). -static void DumpModule(mlir::ModuleOp module, llvm::StringRef file_prefix) { - std::string prefix = GetDumpDirFromEnvVar(); - if (prefix.empty()) { - return; - } - - auto* env = tensorflow::Env::Default(); - auto status = env->RecursivelyCreateDir(prefix); - if (!status.ok()) { - LOG(WARNING) << "cannot create directory '" + prefix + - "': " + status.error_message(); - return; - } - prefix += "/" + file_prefix.str(); - if (!tensorflow::Env::Default()->CreateUniqueFileName(&prefix, ".mlir")) { - LOG(WARNING) << "cannot create unique filename, won't dump MLIR module."; - return; - } - - std::unique_ptr file_writer; - status = env->NewWritableFile(prefix, &file_writer); - if (!status.ok()) { - LOG(WARNING) << "cannot open file '" + prefix + - "': " + status.error_message(); - return; - } - - // Print the module to a string before writing to the file. - std::string txt_module; - { - llvm::raw_string_ostream os(txt_module); - module.print(os); - } - - status = file_writer->Append(txt_module); - if (!status.ok()) { - LOG(WARNING) << "error writing to file '" + prefix + - "': " + status.error_message(); - return; - } - (void)file_writer->Close(); - VLOG(1) << "Dumped MLIR module to " << prefix; -} - // This runs the first phase of the "bridge", transforming the graph in a form // that can be executed with delegation of some computations to an accelerator. // This builds on the model of XLA where a subset of the graph is encapsulated // and attached to a "compile" operation, whose result is fed to an "execute" // operation. The kernel for these operations is responsible to lower the // encapsulated graph to a particular device. 
-Status MlirBridgePass::Run(const DeviceSet& device_set, - const ConfigProto& config_proto, - std::unique_ptr* graph, - FunctionLibraryDefinition* flib_def, - std::vector* control_ret_node_names, - bool* control_rets_updated) { +Status MlirBridgePass::Run(const ConfigProto& config_proto, + mlir::ModuleOp module) { if (!config_proto.experimental().enable_mlir_bridge()) { VLOG(1) << "Skipping MLIR Bridge Pass, session flag not enabled"; return Status::OK(); } VLOG(1) << "Running MLIR Bridge Pass"; - - GraphDebugInfo debug_info; - mlir::MLIRContext context; - GraphImportConfig import_config; - import_config.graph_as_function = true; - import_config.control_outputs = *control_ret_node_names; - TF_ASSIGN_OR_RETURN(auto module_ref, - ConvertGraphToMlir(**graph, debug_info, *flib_def, - import_config, &context)); - - AddDevicesToOp(*module_ref, &device_set); - - if (VLOG_IS_ON(1)) DumpModule(*module_ref, "mlir_bridge_before_"); - - // Run the bridge now TF_RETURN_IF_ERROR( - mlir::TFTPU::TPUBridge(*module_ref, /*enable_logging=*/VLOG_IS_ON(1))); - - if (VLOG_IS_ON(1)) DumpModule(*module_ref, "mlir_bridge_after_"); - - GraphExportConfig export_config; - export_config.graph_as_function = true; - absl::flat_hash_set control_ret_nodes; - TF_RETURN_WITH_CONTEXT_IF_ERROR( - ConvertMlirToGraph(*module_ref, export_config, graph, flib_def, - &control_ret_nodes), - "Error converting MLIR module back to graph"); - - control_ret_node_names->clear(); - control_ret_node_names->reserve(control_ret_nodes.size()); - for (const auto* node : control_ret_nodes) - control_ret_node_names->push_back(node->name()); - - *control_rets_updated = true; + mlir::TFTPU::TPUBridge(module, /*enable_logging=*/VLOG_IS_ON(1))); return Status::OK(); } - -Status MlirBridgeV1CompatPass::Run( - const GraphOptimizationPassOptions& options) { +Status MlirBridgeV1CompatPass::Run(const GraphOptimizationPassOptions& options, + mlir::ModuleOp module) { // Skip function graphs as MlirBridgePass will be used instead. if (options.is_function_graph) return Status::OK(); @@ -145,31 +52,8 @@ Status MlirBridgeV1CompatPass::Run( } VLOG(1) << "Running MLIR Bridge V1 Compat Pass"; - - GraphDebugInfo debug_info; - mlir::MLIRContext context; - GraphImportConfig import_config; - import_config.upgrade_legacy = true; - TF_ASSIGN_OR_RETURN( - auto module_ref, - ConvertGraphToMlir(**options.graph, debug_info, *options.flib_def, - import_config, &context)); - - AddDevicesToOp(*module_ref, options.device_set); - - if (VLOG_IS_ON(1)) DumpModule(*module_ref, "mlir_bridge_v1_compat_before_"); - - // Run the bridge now - TF_RETURN_IF_ERROR(mlir::TFTPU::TPUBridgeV1Compat( - *module_ref, /*enable_logging=*/VLOG_IS_ON(1))); - - if (VLOG_IS_ON(1)) DumpModule(*module_ref, "mlir_bridge_v1_compat_after_"); - - GraphExportConfig export_config; - TF_RETURN_WITH_CONTEXT_IF_ERROR( - ConvertMlirToGraph(*module_ref, export_config, options.graph, - options.flib_def), - "Error converting MLIR module back to graph"); + TF_RETURN_IF_ERROR( + mlir::TFTPU::TPUBridgeV1Compat(module, /*enable_logging=*/VLOG_IS_ON(1))); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h index e7f3fee79ca..b7f8ef203f7 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h @@ -16,28 +16,42 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_ #define TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_ -#include "tensorflow/core/common_runtime/function_optimization_registry.h" -#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "llvm/ADT/StringRef.h" +#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h" namespace tensorflow { // This pass uses MLIR to implement all the conversion steps to target XLA from // a TensorFlow Function Graph. It is meant to expose a very limited set of // functionalities during the bring-up of MLIR-based bridge. -class MlirBridgePass : public FunctionOptimizationPass { +class MlirBridgePass : public MlirOptimizationPass { public: - Status Run(const DeviceSet& device_set, const ConfigProto& config_proto, - std::unique_ptr* graph, FunctionLibraryDefinition* flib_def, - std::vector* control_ret_node_names, - bool* control_rets_updated) override; + llvm::StringRef name() const override { return "bridge"; } + + bool IsEnabled(const ConfigProto& config_proto) const override { + return config_proto.experimental().enable_mlir_bridge(); + } + + // This should be used as a thin mapper around mlir::ModulePass::runOnModule + // API integrated with the Tensorflow runtime. + Status Run(const ConfigProto& config_proto, mlir::ModuleOp module) override; }; // This pass uses MLIR to implement all the conversion steps to target XLA from // a TensorFlow V1 Graph. It is meant to expose a very limited set of // functionalities during the bring-up of MLIR-based bridge. -class MlirBridgeV1CompatPass : public GraphOptimizationPass { +class MlirBridgeV1CompatPass : public MlirV1CompatOptimizationPass { public: - Status Run(const GraphOptimizationPassOptions& options) override; + llvm::StringRef name() const override { return "bridge"; } + + bool IsEnabled(const ConfigProto& config_proto) const override { + return config_proto.experimental().enable_mlir_bridge(); + } + + // This should be used as a thin mapper around mlir::ModulePass::runOnModule + // API integrated with the Tensorflow runtime. + Status Run(const GraphOptimizationPassOptions& options, + mlir::ModuleOp module) override; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass_registration.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass_registration.cc index ac6e54d4e76..21791ff4427 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass_registration.cc +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass_registration.cc @@ -16,15 +16,18 @@ limitations under the License. 
#include #include "tensorflow/compiler/tf2xla/mlir_bridge_pass.h" -#include "tensorflow/core/common_runtime/function_optimization_registry.h" -#include "tensorflow/core/common_runtime/optimization_registry.h" namespace tensorflow { +namespace { +constexpr int kMlirBridgePriority = 10; +} -static function_optimization_registration::FunctionOptimizationPassRegistration - register_mlir_bridge_pass(std::make_unique()); +static mlir_pass_registration::MlirOptimizationPassRegistration + register_mlir_bridge_pass(kMlirBridgePriority, + std::make_unique()); -REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 0, - MlirBridgeV1CompatPass); +static mlir_pass_registration::MlirV1CompatOptimizationPassRegistration + register_v1_compat_mlir_bridge_pass( + kMlirBridgePriority, std::make_unique()); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc index 3c02f9dd2e2..9303e2e9330 100644 --- a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc +++ b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc @@ -86,9 +86,10 @@ Status ConvertOutputInfo(const tf2xla::Config& config, } // namespace -Status ConvertGraphDefToXlaViaMlir(GraphDef graph_def, - const tf2xla::Config& config, - xla::XlaComputation* computation) { +Status ConvertGraphDefToXlaViaMlir( + GraphDef graph_def, const tf2xla::Config& config, + xla::XlaComputation* computation, absl::string_view debug_info_filename, + absl::string_view debug_info_path_begin_marker) { // AddPlaceholdersForFeeds prepares for PruneGraphDefInto and serves two // purposes: (1) It creates a placeholder node for each feed, so that // PruneGraphDefInfo can prune away the node containing the feed. (2) It @@ -115,7 +116,24 @@ Status ConvertGraphDefToXlaViaMlir(GraphDef graph_def, TF_RETURN_IF_ERROR(ConvertOutputInfo(config, &specs)); GraphDebugInfo debug_info; + if (!debug_info_filename.empty()) { + TF_RETURN_IF_ERROR(LoadProtoFromFile(debug_info_filename, &debug_info)); + + if (!debug_info_path_begin_marker.empty()) { + for (size_t i = 0, e = debug_info.files_size(); i < e; ++i) { + std::string* file_name = debug_info.mutable_files(i); + size_t location = + file_name->rfind(std::string(debug_info_path_begin_marker)); + if (location != -1) { + *file_name = file_name->substr(location + + debug_info_path_begin_marker.length()); + } + } + } + } + mlir::MLIRContext context; + TF_ASSIGN_OR_RETURN( mlir::OwningModuleRef module, ConvertGraphdefToMlir(pruned_graph_def, debug_info, specs, &context)); diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index 9ced6e682fc..bcdfd1c6a8e 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/versions.pb.h" @@ -42,6 +43,7 @@ limitations under the License. 
#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/dump_graph.h" @@ -128,19 +130,31 @@ Status ConvertGraphToXla(std::unique_ptr graph, return Status::OK(); } -void ConvertVarHandlesToAotVarHandles(GraphDef* graph_def) { - for (auto& node : *graph_def->mutable_node()) { +Status ConvertVarHandlesToAotVarHandles(GraphDef* graph_def) { + auto update_var_handle_op_node = [](NodeDef& node) -> Status { if (node.op() == "VarHandleOp") { node.set_op(tfcompile::kXlaAotOnlyVarHandleOp); + const auto& it = node.attr().find("allowed_devices"); + if (it != node.attr().end()) { + if (!it->second.list().s().empty()) { + // TODO(b/149512838): Support non-empty allowed devices. + return errors::InvalidArgument( + "VarHandleOp with non-empty allowed devices is not supported."); + } + node.mutable_attr()->erase("allowed_devices"); + } } + return Status::OK(); + }; + for (auto& node : *graph_def->mutable_node()) { + TF_RETURN_IF_ERROR(update_var_handle_op_node(node)); } for (auto& fn : *graph_def->mutable_library()->mutable_function()) { for (auto& node : *fn.mutable_node_def()) { - if (node.op() == "VarHandleOp") { - node.set_op(tfcompile::kXlaAotOnlyVarHandleOp); - } + TF_RETURN_IF_ERROR(update_var_handle_op_node(node)); } } + return Status::OK(); } } // namespace @@ -149,7 +163,7 @@ Status ConvertGraphDefToXla(GraphDef graph_def, const tf2xla::Config& config, xla::Client* client, xla::XlaComputation* computation) { std::unique_ptr graph; - ConvertVarHandlesToAotVarHandles(&graph_def); + TF_RETURN_IF_ERROR(ConvertVarHandlesToAotVarHandles(&graph_def)); TF_RETURN_IF_ERROR(InitGraph(graph_def, config, &graph)); TF_RETURN_IF_ERROR( ConvertGraphToXla(std::move(graph), config, client, computation)); diff --git a/tensorflow/compiler/tf2xla/tf2xla.h b/tensorflow/compiler/tf2xla/tf2xla.h index a1c8806bba5..587d3c2febf 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.h +++ b/tensorflow/compiler/tf2xla/tf2xla.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_TF2XLA_H_ #define TENSORFLOW_COMPILER_TF2XLA_TF2XLA_H_ +#include "absl/strings/string_view.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" @@ -34,10 +35,16 @@ Status ConvertGraphDefToXla(GraphDef graph_def, const tf2xla::Config& config, xla::Client* client, xla::XlaComputation* computation); -// Similar to ConvertGraphDefToXla, but uses MLIR. -Status ConvertGraphDefToXlaViaMlir(GraphDef graph_def, - const tf2xla::Config& config, - xla::XlaComputation* computation); +// Similar to ConvertGraphDefToXla, but uses MLIR and handle debug information. +// +// debug_info_filename: the file for the debug information proto. +// debug_info_path_begin_marker: if not empty, file pathes in the debug +// information are trimmed from the beginning to the first appearance of the +// marker. 
+Status ConvertGraphDefToXlaViaMlir( + GraphDef graph_def, const tf2xla::Config& config, + xla::XlaComputation* computation, absl::string_view debug_info_filename, + absl::string_view debug_info_path_begin_marker); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc index 5420cf3e04f..3870a673e4e 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc @@ -28,7 +28,9 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, buffer_infos_(static_data.buffer_infos_), arg_index_table_(static_data.arg_index_table_), num_args_(static_data.num_args_), + num_variables_(static_data.num_variables_), arg_names_(static_data.arg_names_), + variable_names_(static_data.variable_names_), result_names_(static_data.result_names_), program_shape_(static_data.program_shape_), hlo_profile_printer_data_(static_data.hlo_profile_printer_data_) { @@ -63,6 +65,8 @@ XlaCompiledCpuFunction::~XlaCompiledCpuFunction() { namespace { +constexpr int kNotFound = -1; + // Linear search through `names` looking for a match with `name`. Returns -1 if // the name isn't found, or is empty. // @@ -72,7 +76,6 @@ int LookupNameIndex(const string& name, const char** names) { // for AOT try the setting the tfcompile --gen_name_to_index flag. assert(names != nullptr); - constexpr int kNotFound = -1; if (name.empty()) { return kNotFound; } @@ -90,6 +93,14 @@ int XlaCompiledCpuFunction::LookupArgIndex(const string& name) const { return LookupNameIndex(name, arg_names_); } +int XlaCompiledCpuFunction::LookupVariableIndex(const string& name) const { + int index = LookupNameIndex(name, variable_names_); + if (index == kNotFound) { + return kNotFound; + } + return num_args_ - num_variables_ + index; +} + int XlaCompiledCpuFunction::LookupResultIndex(const string& name) const { return LookupNameIndex(name, result_names_); } diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h index 5e452b50e71..04d9086ce4c 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -76,12 +76,16 @@ class XlaCompiledCpuFunction { // There are num_args entry parameters. int64 num_args_ = 0; + // There are num_variables variables. + int64 num_variables_ = 0; + // The 0-based index of the result tuple, in the temp buffers. size_t result_index_ = 0; // [Optional] Arrays of arg and result names. These are arrays of C-style // strings, where the array is terminated by nullptr. const char** arg_names_ = nullptr; + const char** variable_names_ = nullptr; const char** result_names_ = nullptr; // [Optional] Arg and result shapes. @@ -150,6 +154,8 @@ class XlaCompiledCpuFunction { int num_args() const { return num_args_; } + int num_variables() const { return num_variables_; } + // Returns the size of entry parameter `idx`. // // There is a static version of this method on tfcompile generated subclasses @@ -212,10 +218,11 @@ class XlaCompiledCpuFunction { // ------------------------------ // Methods for extracting optional metadata. - // Returns true iff data is available for the Lookup{Arg,Result}Index methods. - // E.g. the data might not be compiled into the binary for AOT. + // Returns true iff data is available for the Lookup{Arg,Variable,Result}Index + // methods. E.g. the data might not be compiled into the binary for AOT. 
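LookupVariableIndex above relies on the convention that variables are passed as the trailing entry parameters, so a hit at position index within variable_names_ maps to argument slot num_args_ - num_variables_ + index. A small Python sketch of that arithmetic, mirroring the SumVariable test further down (two arguments, the second being the variable "myvar"):

def lookup_variable_index(name, variable_names, num_args):
    # Variables are the trailing arguments, in the same order as variable_names.
    if name not in variable_names:
        return -1  # kNotFound
    return num_args - len(variable_names) + variable_names.index(name)

print(lookup_variable_index('myvar', ['myvar'], num_args=2))  # 1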
bool HasNameIndices() const { - return arg_names_ != nullptr && result_names_ != nullptr; + return arg_names_ != nullptr && variable_names_ != nullptr && + result_names_ != nullptr; } // Returns the 0-based index for the argument with the given `name`. @@ -226,6 +233,14 @@ class XlaCompiledCpuFunction { // Recommended usage is to capture this in a variable for re-use. int LookupArgIndex(const string& name) const; + // Returns the 0-based index for the variable with the given `name`. + // Returns -1 if the name wasn't found, or data isn't available. + // + // The index remains constant for every instance of XlaCompiledCpuFunction + // generated from the same static data, and might not be cheap to determine. + // Recommended usage is to capture this in a variable for re-use. + int LookupVariableIndex(const string& name) const; + // Returns the 0-based index for the result with the given `name`. // Returns -1 if the name wasn't found, or data isn't available. // @@ -280,6 +295,11 @@ class XlaCompiledCpuFunction { static_data->num_args_ = num_args; } + static void set_static_data_num_variables(StaticData* static_data, + int64 num_variables) { + static_data->num_variables_ = num_variables; + } + static void set_static_data_result_index(StaticData* static_data, size_t result_index) { static_data->result_index_ = result_index; @@ -290,6 +310,11 @@ class XlaCompiledCpuFunction { static_data->arg_names_ = arg_names; } + static void set_static_data_variable_names(StaticData* static_data, + const char** variable_names) { + static_data->variable_names_ = variable_names; + } + static void set_static_data_result_names(StaticData* static_data, const char** result_names) { static_data->result_names_ = result_names; @@ -334,6 +359,9 @@ class XlaCompiledCpuFunction { // The number of incoming arguments. const int32 num_args_; + // The number of incoming variables. + const int32 num_variables_; + // Backing memory for buffer_table_ and args_, the latter depending on // AllocMode. void* alloc_buffer_table_ = nullptr; @@ -346,6 +374,7 @@ class XlaCompiledCpuFunction { // Optional metadata. const char** arg_names_ = nullptr; + const char** variable_names_ = nullptr; const char** result_names_ = nullptr; const xla::ProgramShapeProto* program_shape_ = nullptr; const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 3ea62882dcb..c30b1c0e17d 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -44,7 +44,6 @@ limitations under the License. #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/core/errors.h" @@ -1174,51 +1173,6 @@ Status XlaCompiler::BuildArguments( return Status::OK(); } -Status XlaCompiler::CompileSingleOp( - const XlaCompiler::CompileOptions& options, const NodeDef& node_def, - absl::Span args, - absl::Span result_types, CompilationResult* result) { - // TODO(b/74182462): We implement this by creating a new dummy Graph including - // _Arg nodes, and let CompileGraph walk it. This could be optimized. - std::unique_ptr graph(new Graph(OpRegistry::Global())); - - Status status; - // First create the actual node we care about computing. 
- Node* main_node = graph->AddNode(node_def, &status); - TF_RETURN_IF_ERROR(status); - - // Create dummy _Arg nodes. Link these to `node` and also via a control - // dependency edge to the _SOURCE node. - for (int64 i = 0; i < args.size(); ++i) { - Node* node; - string arg_name = absl::StrCat("_arg", i); - Status status = - NodeBuilder(arg_name, FunctionLibraryDefinition::kArgOp) - .ControlInput(graph->source_node()) - .Attr("T", args[i].kind == Argument::kResource ? DT_RESOURCE - : args[i].type) - .Attr("index", i) - .Finalize(graph.get(), &node); - TF_RETURN_IF_ERROR(status); - graph->AddEdge(node, 0, main_node, i); - } - - // Similarly with return values, create dummy _Retval nodes fed by `node`. - for (int64 i = 0; i < result_types.size(); ++i) { - Node* node; - string retval_name = absl::StrCat("_retval", i); - Status status = NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp) - .Input(main_node, i) - .Attr("T", result_types[i]) - .Attr("index", i) - .Finalize(graph.get(), &node); - TF_RETURN_IF_ERROR(status); - } - FixupSourceAndSinkEdges(graph.get()); - - return CompileGraph(options, node_def.name(), std::move(graph), args, result); -} - namespace { // Check that the ops of all non-functional nodes have been registered. diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 5ec5866632b..6a56136a9f6 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -375,14 +375,6 @@ class XlaCompiler { std::unique_ptr graph, absl::Span args, CompilationResult* result); - // Compiles a single Op, given by `node_def`, into an - // xla::XlaComputation. Similar to CompileFunction but takes a single Op as - // input. - Status CompileSingleOp(const CompileOptions& options, const NodeDef& node_def, - absl::Span args, - absl::Span result_types, - CompilationResult* result); - // Returns the shape of the XLA parameter for an argument 'arg'. // See the class comment for more details about the argument passing // convention. diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc index 0392cc7d345..0deaa1ea8fb 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -49,9 +49,9 @@ xla::StatusOr ComputeResultIndex( return result_slice.index(); } -// Collect names from `entries`, where T is one of tf2xla::{Feed,Fetch}. We hold -// the actual strings in nonempty_names, and hold arrays of pointers in -// name_ptrs, terminated by a nullptr entry. +// Collect names from `entries`, where T is one of +// tf2xla::{Feed,Fetch,Variable}. We hold the actual strings in nonempty_names, +// and hold arrays of pointers in name_ptrs, terminated by a nullptr entry. template void CollectNames(const T& entries, std::vector* nonempty_names, std::vector* name_ptrs) { @@ -154,14 +154,28 @@ XlaJitCompiledCpuFunction::Compile( &jit->static_data_, jit->arg_index_table_.data()); XlaCompiledCpuFunction::set_static_data_num_args( &jit->static_data_, jit->arg_index_table_.size()); + XlaCompiledCpuFunction::set_static_data_num_variables(&jit->static_data_, + config.variable_size()); XlaCompiledCpuFunction::set_static_data_result_index(&jit->static_data_, result_index); // Optional metadata is collected and set below. 
CollectNames(config.feed(), &jit->nonempty_arg_names_, &jit->arg_names_); + + auto variable_copy = config.variable(); + for (auto& var : variable_copy) { + if (var.name().empty()) { + var.set_name(var.node_name()); + } + } + CollectNames(variable_copy, &jit->nonempty_variable_names_, + &jit->variable_names_); + CollectNames(config.fetch(), &jit->nonempty_result_names_, &jit->result_names_); XlaCompiledCpuFunction::set_static_data_arg_names(&jit->static_data_, jit->arg_names_.data()); + XlaCompiledCpuFunction::set_static_data_variable_names( + &jit->static_data_, jit->variable_names_.data()); XlaCompiledCpuFunction::set_static_data_result_names( &jit->static_data_, jit->result_names_.data()); XlaCompiledCpuFunction::set_static_data_program_shape( diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h index 11fc4571189..107968b184d 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h @@ -77,8 +77,10 @@ class XlaJitCompiledCpuFunction { // nonempty_*_names_, and hold arrays of pointers in *_names_ for the static // data to refer to. std::vector nonempty_arg_names_; + std::vector nonempty_variable_names_; std::vector nonempty_result_names_; std::vector arg_names_; + std::vector variable_names_; std::vector result_names_; // The backing data for the program shape. The proto form of program shape is diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc index f5d6b5231ac..880cb5939b6 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc @@ -210,6 +210,9 @@ TEST(XlaJitCompiledCpuFunction, Sum) { EXPECT_EQ(function.LookupResultIndex("x_name"), -1); EXPECT_EQ(function.LookupResultIndex("y_name"), -1); + EXPECT_EQ(0, function.num_variables()); + EXPECT_EQ(function.LookupVariableIndex("x"), -1); + // Check program shape. using xla::ShapeUtil; const xla::Shape s32 = ShapeUtil::MakeShape(xla::S32, {}); @@ -252,6 +255,14 @@ TEST(XlaJitCompiledCpuFunction, SumVariable) { EXPECT_EQ(*static_cast(function.result_data(0)), 100); EXPECT_EQ(*static_cast(function.result_data(1)), 420); + // Check name to index lookups. + EXPECT_TRUE(function.HasNameIndices()); + + EXPECT_EQ(2, function.num_args()); + + EXPECT_EQ(1, function.num_variables()); + EXPECT_EQ(function.LookupVariableIndex("myvar"), 1); + // Check program shape. using xla::ShapeUtil; const xla::Shape s32 = ShapeUtil::MakeShape(xla::S32, {}); diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index a1941cc5fdf..a1c45a4bf30 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -50,7 +50,8 @@ XlaCompiler* XlaOpKernelContext::compiler() const { } // Retrieves an XlaExpression that was allocated by a previous Op. -static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) { +const XlaExpression* XlaOpKernelContext::CastExpressionFromTensor( + const Tensor& tensor) { const XlaExpression* expression = reinterpret_cast(tensor.tensor_data().data()); CHECK(expression->kind() != XlaExpression::Kind::kInvalid) @@ -59,8 +60,8 @@ static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) { } // Assigns an XlaExpression to a tensor on an XLA compilation device. 
-static void AssignExpressionToTensor(Tensor* tensor, - const XlaExpression& value) { +void XlaOpKernelContext::AssignExpressionToTensor(const XlaExpression& value, + Tensor* tensor) { const XlaExpression* expression = reinterpret_cast(tensor->tensor_data().data()); CHECK(expression->kind() == XlaExpression::Kind::kInvalid) @@ -396,7 +397,8 @@ namespace { Status ReadVariableInputTensor(const Tensor& tensor, DataType type, const XlaOpKernelContext* ctx, TensorShape* shape, xla::XlaOp* value) { - const XlaExpression* expression = CastExpressionFromTensor(tensor); + const XlaExpression* expression = + XlaOpKernelContext::CastExpressionFromTensor(tensor); XlaResource* variable = expression->resource(); TF_RET_CHECK(variable != nullptr); TF_RET_CHECK(variable->kind() == XlaResource::kVariable); @@ -486,7 +488,8 @@ void XlaOpKernelContext::SetOutputExpression(int index, TF_ASSIGN_OR_RETURN(TensorShape shape, expression.GetShape()); TF_RETURN_IF_ERROR(context_->allocate_output(index, shape, &output)); } - AssignExpressionToTensor(context_->mutable_output(index), expression); + XlaOpKernelContext::AssignExpressionToTensor( + expression, context_->mutable_output(index)); return Status::OK(); }(); if (!status.ok()) { @@ -536,7 +539,8 @@ namespace { Status AssignVariableTensor(const Tensor& tensor, DataType type, const XlaOpKernelContext* ctx, xla::XlaOp handle, xla::XlaBuilder* builder) { - const XlaExpression* expression = CastExpressionFromTensor(tensor); + const XlaExpression* expression = + XlaOpKernelContext::CastExpressionFromTensor(tensor); XlaResource* variable = expression->resource(); TF_RET_CHECK(variable != nullptr); TF_RET_CHECK(variable->kind() == XlaResource::kVariable); diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 27b198f8bee..d72dd3972d3 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -278,6 +278,13 @@ class XlaOpKernelContext { // separate specialization of the computation for each DataType. const xla::XlaComputation* GetOrCreateMul(const DataType type); + // Assigns an XlaExpression to a tensor on an XLA compilation device. + static void AssignExpressionToTensor(const XlaExpression& value, + Tensor* tensor); + + // Retrieves an XlaExpression that was assigned to the specified tensor. + static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor); + private: // Returns the tensor of input `name`. const Tensor& GetInputTensorByName(absl::string_view name); diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index c6f6ffb2853..7839ae95dc0 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -229,11 +229,11 @@ class XlaOpRegistry { }; // Map from compilation device names to a description of the backend. - std::unordered_map backends_ GUARDED_BY(mutex_); + std::unordered_map backends_ TF_GUARDED_BY(mutex_); // Map from Tensorflow device names to the corresponding JIT device metadata. std::unordered_map compilation_devices_ - GUARDED_BY(mutex_); + TF_GUARDED_BY(mutex_); // A description of a Tensorflow operator that can be compiled to XLA. struct OpRegistration { @@ -292,7 +292,7 @@ class XlaOpRegistry { // Registrations present under the same key must satisfy IsCompatible above, // and this is checked during registration. 
std::unordered_map>> ops_ - GUARDED_BY(mutex_); + TF_GUARDED_BY(mutex_); // Have we already registered the JIT kernels on the JIT devices? bool jit_kernels_registered_ = false; @@ -301,7 +301,7 @@ class XlaOpRegistry { // registrations created by RegisterCompilationKernels() and // RegisterDeviceKernels(). std::vector> - kernel_registrars_ GUARDED_BY(mutex_); + kernel_registrars_ TF_GUARDED_BY(mutex_); }; // REGISTER_XLA_OP() registers an XLA OpKernel by name, for example: diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index d6d154b2506..a2993058321 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -17,6 +17,7 @@ package_group( "//tensorflow/compiler/...", "//tensorflow/python/tpu/...", "//third_party/py/jax/...", + "//third_party/tf_runtime/tools/tf_kernel_gen/...", ], ) diff --git a/tensorflow/compiler/xla/client/client_library.h b/tensorflow/compiler/xla/client/client_library.h index 33d1de370de..4211b9a8b1c 100644 --- a/tensorflow/compiler/xla/client/client_library.h +++ b/tensorflow/compiler/xla/client/client_library.h @@ -134,10 +134,10 @@ class ClientLibrary { tensorflow::mutex service_mutex_; // Guards the singleton creation state. std::unordered_map> - local_instances_ GUARDED_BY(service_mutex_); + local_instances_ TF_GUARDED_BY(service_mutex_); std::unordered_map> - compile_only_instances_ GUARDED_BY(service_mutex_); + compile_only_instances_ TF_GUARDED_BY(service_mutex_); TF_DISALLOW_COPY_AND_ASSIGN(ClientLibrary); }; diff --git a/tensorflow/compiler/xla/client/lib/prng.cc b/tensorflow/compiler/xla/client/lib/prng.cc index 11a8e27af11..17fb4c3c369 100644 --- a/tensorflow/compiler/xla/client/lib/prng.cc +++ b/tensorflow/compiler/xla/client/lib/prng.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/prng.h" #include +#include #include "absl/base/casts.h" #include "tensorflow/compiler/xla/client/lib/constants.h" @@ -134,14 +135,26 @@ XlaOp Uint32sToUint64(std::array u32s) { ConstantR0WithType(builder, U64, 32)); } -// Given the initial state and the request number of random numbers to be +// Given the initial state and the request shape of random numbers to be // generated, returns the input for the random number generator and a new state. std::pair GetThreeFryInputsAndUpdatedState( - XlaOp initial_state, const int64 size) { + XlaOp initial_state, const Shape& shape) { XlaBuilder* builder = initial_state.builder(); - XlaOp input_u64 = Iota(builder, U64, size); - input_u64 = input_u64 + initial_state; - XlaOp new_state = initial_state + ConstantR0(builder, size); + auto u64_shape = ShapeUtil::MakeShape(U64, shape.dimensions()); + // initial_state is an R1, so reshape it to a scalar. + auto input_u64 = Broadcast(Reshape(initial_state, {}), shape.dimensions()); + int64 trailing_dims_product = 1; + for (int64 i = shape.rank() - 1; i >= 0; --i) { + if (shape.dimensions(i) < 2) { + continue; + } + input_u64 = + input_u64 + (Iota(builder, u64_shape, i) * + ConstantR0(builder, trailing_dims_product)); + trailing_dims_product *= shape.dimensions(i); + } + XlaOp new_state = + initial_state + ConstantR0(builder, ShapeUtil::ElementsIn(shape)); return std::make_pair(Uint64ToUint32s(input_u64), new_state); } @@ -149,11 +162,46 @@ std::pair GetThreeFryInputsAndUpdatedState( // implementation. Returns the random bits and the new state. 
RngOutput ThreeFryRngBit32(XlaOp key, XlaOp initial_state, const Shape& shape) { XlaBuilder* builder = key.builder(); + // Try to split the shape on a dimension > 1 into two halves, each + // representing a U32 value. + std::vector half_shape_dims; + std::vector padded_full_shape_dims; + int64 split_dim = -1; + for (int64 i = 0; i < shape.rank(); ++i) { + if (shape.dimensions(i) > 1 && split_dim < 0) { + half_shape_dims.push_back(CeilOfRatio(shape.dimensions(i), 2)); + // Create a new trivial dim for the later concat, which is more friendly + // to sharding propagation. + half_shape_dims.push_back(1); + split_dim = i; + padded_full_shape_dims.push_back(half_shape_dims[i] * 2); + } else { + half_shape_dims.push_back(shape.dimensions(i)); + padded_full_shape_dims.push_back(shape.dimensions(i)); + } + } + auto half_shape = ShapeUtil::MakeShape(shape.element_type(), half_shape_dims); + if (split_dim >= 0) { + std::pair inputs_state = + GetThreeFryInputsAndUpdatedState(initial_state, half_shape); + ThreeFry2x32State inputs = inputs_state.first; + ThreeFry2x32State outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key)); + XlaOp result = ConcatInDim(builder, outputs, split_dim + 1); + result = Reshape(result, padded_full_shape_dims); + if (shape.dimensions(split_dim) % 2 != 0) { + result = Slice(result, std::vector(shape.rank(), 0), + shape.dimensions(), std::vector(shape.rank(), 1)); + } + return {result, inputs_state.second}; + } + // Use an R1 shape if the previous attempt failed. const int64 size = ShapeUtil::ElementsIn(shape); const int64 half_size = CeilOfRatio(size, 2); const bool size_is_odd = (half_size * 2 != size); std::pair inputs_state = - GetThreeFryInputsAndUpdatedState(initial_state, half_size); + GetThreeFryInputsAndUpdatedState( + initial_state, + ShapeUtil::MakeShape(shape.element_type(), {half_size})); ThreeFry2x32State inputs = inputs_state.first; ThreeFry2x32State outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key)); if (size_is_odd) { @@ -167,14 +215,12 @@ RngOutput ThreeFryRngBit32(XlaOp key, XlaOp initial_state, const Shape& shape) { // Generates random 64bits with the given shape using the Three Fry // implementation. Returns the random bits and the new state. RngOutput ThreeFryRngBit64(XlaOp key, XlaOp initial_state, const Shape& shape) { - const int64 size = ShapeUtil::ElementsIn(shape); std::pair inputs_state = - GetThreeFryInputsAndUpdatedState(initial_state, size); + GetThreeFryInputsAndUpdatedState(initial_state, shape); ThreeFry2x32State inputs = inputs_state.first; ThreeFry2x32State outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key)); XlaOp result = Uint32sToUint64(outputs); - return {Reshape(result, AsInt64Slice(shape.dimensions())), - inputs_state.second}; + return {result, inputs_state.second}; } // The key of the Philox random number generator. diff --git a/tensorflow/compiler/xla/client/lib/slicing.cc b/tensorflow/compiler/xla/client/lib/slicing.cc index 7d8f433bac8..1ea713467f8 100644 --- a/tensorflow/compiler/xla/client/lib/slicing.cc +++ b/tensorflow/compiler/xla/client/lib/slicing.cc @@ -15,7 +15,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/slicing.h" +#include #include +#include #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" @@ -24,6 +26,18 @@ limitations under the License. 
namespace xla { +XlaOp DynamicStridedSlice(XlaOp input, absl::Span base_indices, + absl::Span window_sizes, + absl::Span strides) { + XlaOp sliced_input = DynamicSlice(input, base_indices, window_sizes); + if (std::any_of(strides.begin(), strides.end(), + [](int64 stride) { return stride != 1; })) { + sliced_input = Slice(sliced_input, std::vector(window_sizes.size()), + window_sizes, strides); + } + return sliced_input; +} + XlaOp SliceInMinorDims(XlaOp x, absl::Span start, absl::Span end) { XlaBuilder* builder = x.builder(); diff --git a/tensorflow/compiler/xla/client/lib/slicing.h b/tensorflow/compiler/xla/client/lib/slicing.h index cf83d63cec2..e6b72890b7d 100644 --- a/tensorflow/compiler/xla/client/lib/slicing.h +++ b/tensorflow/compiler/xla/client/lib/slicing.h @@ -22,6 +22,13 @@ limitations under the License. namespace xla { +// Slices input starting from the base_indices and within the window_sizes, +// using the supplied strides. This is the equivalent of the Python slicing op +// [base_indices : base_indices+window_sizes : stride]. +XlaOp DynamicStridedSlice(XlaOp input, absl::Span base_indices, + absl::Span window_sizes, + absl::Span strides); + // Updates a slice of 'x', i.e., // x[start[0], ..., start[n]] = update XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span start); diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index df070d97ff7..afe115deda8 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -346,6 +346,23 @@ StatusOr>> LocalClient::Compile( VLOG(3) << "Set device ordinal to default value of: " << updated_options.device_ordinal(); } + if (options.has_device_assignment()) { + if (options.device_assignment().replica_count() != options.num_replicas()) { + return InvalidArgument( + "Mismatched number of replicas for device " + "assignment and computation (%d vs %d).\n%s", + options.device_assignment().replica_count(), options.num_replicas(), + options.device_assignment().ToString()); + } + if (options.device_assignment().computation_count() != + options.num_partitions()) { + return InvalidArgument( + "Mismatched number of partitions for device " + "assignment and computation (%d vs %d).\n%s", + options.device_assignment().computation_count(), + options.num_partitions(), options.device_assignment().ToString()); + } + } TF_ASSIGN_OR_RETURN(std::vector> executables, local_service_->CompileExecutables( computation, argument_layouts, updated_options)); diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index da9ac3553ad..888db7536e4 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -489,8 +489,9 @@ StatusOr XlaBuilder::AddBroadcastSequence(const Shape& output_shape, } // Eliminate the size one dimensions. - TF_ASSIGN_OR_RETURN(XlaOp reshaped_operand, - ReshapeInternal(reshaped_shape, operand)); + TF_ASSIGN_OR_RETURN( + XlaOp reshaped_operand, + ReshapeInternal(reshaped_shape, operand, /*inferred_dimension=*/-1)); // Broadcast 'reshape' up to the larger size. 
return InDimBroadcast(broadcast_shape, reshaped_operand, broadcast_dimensions); @@ -498,12 +499,10 @@ StatusOr XlaBuilder::AddBroadcastSequence(const Shape& output_shape, XlaOp XlaBuilder::UnaryOp(HloOpcode unop, XlaOp operand) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN( Shape shape, ShapeInference::InferUnaryOpShape(unop, *operand_shape)); - *instr.mutable_shape() = shape.ToProto(); - return AddInstruction(std::move(instr), unop, {operand}); + return AddOpWithShape(unop, shape, {operand}); }); } @@ -511,31 +510,17 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions, absl::optional direction) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs)); TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs)); TF_ASSIGN_OR_RETURN( Shape shape, ShapeInference::InferBinaryOpShape( binop, *lhs_shape, *rhs_shape, broadcast_dimensions)); - *instr.mutable_shape() = shape.ToProto(); - if (binop == HloOpcode::kCompare) { - if (!direction.has_value()) { - return InvalidArgument( - "kCompare expects a ComparisonDirection, but none provided."); - } - instr.set_comparison_direction(ComparisonDirectionToString(*direction)); - } else if (direction.has_value()) { - return InvalidArgument( - "A comparison direction is provided for a non-compare opcode: %s.", - HloOpcodeString(binop)); - } const int64 lhs_rank = lhs_shape->rank(); const int64 rhs_rank = rhs_shape->rank(); XlaOp updated_lhs = lhs; XlaOp updated_rhs = rhs; - if (!broadcast_dimensions.empty() && lhs_rank != rhs_rank) { const bool should_broadcast_lhs = lhs_rank < rhs_rank; XlaOp from = should_broadcast_lhs ? 
lhs : rhs; @@ -576,13 +561,35 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs, AddBroadcastSequence(shape, updated_rhs)); } - return AddInstruction(std::move(instr), binop, {updated_lhs, updated_rhs}); + return BinaryOpNoBroadcast(binop, shape, updated_lhs, updated_rhs, + direction); + }); +} + +XlaOp XlaBuilder::BinaryOpNoBroadcast( + HloOpcode binop, const Shape& shape, XlaOp lhs, XlaOp rhs, + absl::optional direction) { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + if (binop == HloOpcode::kCompare) { + if (!direction.has_value()) { + return InvalidArgument( + "kCompare expects a ComparisonDirection, but none provided."); + } + instr.set_comparison_direction(ComparisonDirectionToString(*direction)); + } else if (direction.has_value()) { + return InvalidArgument( + "A comparison direction is provided for a non-compare opcode: %s.", + HloOpcodeString(binop)); + } + + return AddInstruction(std::move(instr), binop, {lhs, rhs}); }); } XlaOp XlaBuilder::TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; XlaOp updated_lhs = lhs; XlaOp updated_rhs = rhs; XlaOp updated_ehs = ehs; @@ -635,8 +642,8 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs) { "%s Input scalar shapes may have been changed to non-scalar shapes.", status_or_shape.status().error_message()); } - *instr.mutable_shape() = status_or_shape.ConsumeValueOrDie().ToProto(); - return AddInstruction(std::move(instr), triop, + + return AddOpWithShape(triop, status_or_shape.ValueOrDie(), {updated_lhs, updated_rhs, updated_ehs}); }); } @@ -749,8 +756,9 @@ XlaOp XlaBuilder::BroadcastInDim( TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); // Output shape, in the case of degenerate broadcast, the out_dim_size is // not necessarily the same as the dimension sizes of the output shape. 
- auto output_shape = - ShapeUtil::MakeShape(operand_shape->element_type(), out_dim_size); + TF_ASSIGN_OR_RETURN(auto output_shape, + ShapeUtil::MakeValidatedShape( + operand_shape->element_type(), out_dim_size)); if (operand_shape->rank() != broadcast_dimensions.size()) { return InvalidArgument( "Size of broadcast_dimensions has to match operand's rank; operand " @@ -1616,12 +1624,10 @@ XlaOp XlaBuilder::Sort(absl::Span operands, XlaOp XlaBuilder::ConvertElementType(XlaOp operand, PrimitiveType new_element_type) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape( *operand_shape, new_element_type)); - *instr.mutable_shape() = shape.ToProto(); - return AddInstruction(std::move(instr), HloOpcode::kConvert, {operand}); + return AddOpWithShape(HloOpcode::kConvert, shape, {operand}); }); } @@ -2805,6 +2811,13 @@ StatusOr XlaBuilder::AddInstruction(HloInstructionProto&& instr, return op; } +StatusOr XlaBuilder::AddOpWithShape(HloOpcode opcode, const Shape& shape, + absl::Span operands) { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + return AddInstruction(std::move(instr), opcode, operands); +} + void XlaBuilder::AddCalledComputation(const XlaComputation& computation, HloInstructionProto* instr) { absl::flat_hash_map remapped_ids; diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index dc5c83e0bfb..9d03141715f 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -95,6 +95,7 @@ class XlaOp { int64 handle() const { return handle_; } friend class XlaBuilder; + friend class MlirHloBuilder; // < 0 means "invalid handle". int64 handle_; @@ -139,7 +140,7 @@ class XlaBuilder { XlaBuilder(const XlaBuilder&) = delete; XlaBuilder& operator=(const XlaBuilder&) = delete; - ~XlaBuilder(); + virtual ~XlaBuilder(); // Returns the computation name. const string& name() const { return name_; } @@ -277,7 +278,7 @@ class XlaBuilder { StatusOr GetShape(XlaOp op) const; // Returns the shape of the given op. - StatusOr GetShapePtr(XlaOp op) const; + virtual StatusOr GetShapePtr(XlaOp op) const; // Returns the (inferred) result for the current computation's shape. This // assumes the root instruction is the last added instruction. @@ -645,7 +646,7 @@ class XlaBuilder { StatusOr LookUpMutableInstructionByHandle(int64 handle); // Internal helper method that does the building for an arbitrary unary op. - XlaOp UnaryOp(HloOpcode unop, XlaOp operand); + virtual XlaOp UnaryOp(HloOpcode unop, XlaOp operand); // Internal helper method that does the building for an arbitrary binary op. // broadcast_dimensions specifies which dimensions to use for broadcasting @@ -655,14 +656,21 @@ class XlaBuilder { absl::Span broadcast_dimensions, absl::optional direction = absl::nullopt); + // Internal helper method that does the building for an arbitrary binary op + // with same ranked operands that doesn't broadcast. + virtual XlaOp BinaryOpNoBroadcast( + HloOpcode binop, const Shape& shape, XlaOp lhs, XlaOp rhs, + absl::optional direction); + // Internal helper method that does the building for an arbitrary ternary op. 
XlaOp TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs); XlaOp RngOp(RandomDistribution distribution, absl::Span parameters, const Shape& shape); - StatusOr InDimBroadcast(const Shape& shape, XlaOp operand, - absl::Span broadcast_dimensions); + virtual StatusOr InDimBroadcast( + const Shape& shape, XlaOp operand, + absl::Span broadcast_dimensions); // Internal helper method that creates a sequence of instructions that // performs an explicit broadcast of the operand to the target shape. @@ -671,8 +679,8 @@ class XlaBuilder { // Internal helper method for creating a Reshape op with the already inferred // shape. - StatusOr ReshapeInternal(const Shape& shape, XlaOp operand, - int64 inferred_dimension = -1); + virtual StatusOr ReshapeInternal(const Shape& shape, XlaOp operand, + int64 inferred_dimension); // Returns the (inferred) result for the program shape using the given root. StatusOr GetProgramShape(int64 root_id) const; @@ -1056,15 +1064,20 @@ class XlaBuilder { friend XlaOp GetDimensionSize(XlaOp operand, int64 dimension); friend XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension); + protected: + // Returns OK status if the given op was built using this builder. Otherwise, + // returns an error. + Status CheckOpBuilder(XlaOp op) const; + private: XlaOp ConditionalImpl( XlaOp branch_index, absl::Span branch_computations, absl::Span branch_operands); - // Returns OK status if the given op was built using this builder. Otherwise, - // returns an error. - Status CheckOpBuilder(XlaOp op) const; + // Creates an op with the given opcode and the output shape. + virtual StatusOr AddOpWithShape(HloOpcode opcode, const Shape& shape, + absl::Span operands); // Here, InstructionType is either const HloInstructionProto* or non-const // HloInstructionProto*. 
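The xla_builder.h changes above make the low-level emission helpers (UnaryOp, BinaryOpNoBroadcast, InDimBroadcast, ReshapeInternal, AddOpWithShape) virtual and give XlaBuilder a virtual destructor, so alternative builders such as the MlirHloBuilder named in the new friend declaration can subclass XlaBuilder and intercept instruction creation while reusing its shape-inference paths. A minimal sketch of that extension point, using a made-up subclass name (RejectingBuilder is not part of this change):

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

// Hypothetical subclass: keeps XlaBuilder's shape inference and broadcasting
// logic but refuses to emit HLO, the way a non-HLO builder might for ops it
// has not implemented yet.
class RejectingBuilder : public XlaBuilder {
 public:
  using XlaBuilder::XlaBuilder;

 private:
  StatusOr<XlaOp> AddOpWithShape(HloOpcode opcode, const Shape& shape,
                                 absl::Span<const XlaOp> operands) override {
    return Unimplemented("RejectingBuilder cannot emit %s",
                         HloOpcodeString(opcode));
  }
};

}  // namespace xla

AddOpWithShape is only one of the new hooks; a full replacement builder would also override BinaryOpNoBroadcast and the other virtual methods declared in the hunk above.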
diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc index fd227ea47f2..115a822b323 100644 --- a/tensorflow/compiler/xla/client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -329,6 +329,17 @@ TEST_F(XlaBuilderTest, BroadcastInDimWithDegeneratedDim) { op::Broadcast(op::Reshape(op::Broadcast()))); } +TEST_F(XlaBuilderTest, BroadcastInDimWithNegativeSize) { + XlaBuilder b(TestName()); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 1, 4}), "x"); + BroadcastInDim(x, {-3, 3, 4}, + /*broadcast_dimensions=*/{0, 1, 2}); + auto statusor = BuildHloModule(&b); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("shape's dimensions must not be < 0")); +} + TEST_F(XlaBuilderTest, OperandFromWrongBuilder) { XlaBuilder b1("b1"); auto p0 = Parameter(&b1, 0, ShapeUtil::MakeShape(F32, {}), "p0"); diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc index a0f60408296..1228ad527e3 100644 --- a/tensorflow/compiler/xla/executable_run_options.cc +++ b/tensorflow/compiler/xla/executable_run_options.cc @@ -30,6 +30,8 @@ std::string RunId::ToString() const { return "RunId: " + std::to_string(data_); } +int64 RunId::ToInt() const { return data_; } + ExecutableRunOptions& ExecutableRunOptions::set_device_ordinal( int device_ordinal) { device_ordinal_ = device_ordinal; diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h index b44d5f13b68..6981b35975f 100644 --- a/tensorflow/compiler/xla/executable_run_options.h +++ b/tensorflow/compiler/xla/executable_run_options.h @@ -55,6 +55,7 @@ class RunId { RunId& operator=(const RunId&) = default; friend bool operator==(const RunId& a, const RunId& b); std::string ToString() const; + int64 ToInt() const; template friend H AbslHashValue(H h, const RunId& id) { diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 6c7aff3b11e..44e6a3c7bdb 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -133,8 +133,9 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { } else if (shape.IsArray()) { if (allocate_arrays) { // Literals can be used as DMA targets, which can require alignment. We - // force a 16-byte minimum alignment. - constexpr int kMinimumAlignment = 16; + // force a tensorflow::Allocator::kAllocatorAlignment-byte minimum + // alignment. + constexpr int kMinimumAlignment = 64; piece->set_buffer(static_cast(tensorflow::port::AlignedMalloc( piece->size_bytes(), kMinimumAlignment))); } diff --git a/tensorflow/compiler/xla/parse_flags_from_env_test.cc b/tensorflow/compiler/xla/parse_flags_from_env_test.cc index f967a788dec..e3552470f63 100644 --- a/tensorflow/compiler/xla/parse_flags_from_env_test.cc +++ b/tensorflow/compiler/xla/parse_flags_from_env_test.cc @@ -143,6 +143,9 @@ TEST(ParseFlagsFromEnv, EnvAndFlag) { string stdout_str; int child_status = child.Communicate(nullptr, &stdout_str, nullptr); CHECK_EQ(child_status, 0) << "test " << i; + // On windows, we get CR characters. Remove them. 
+ stdout_str.erase(std::remove(stdout_str.begin(), stdout_str.end(), '\r'), + stdout_str.end()); CHECK_EQ(stdout_str, test[i].expected_value) << "test " << i; } } diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index d6c1a034859..3c93ec96113 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -1,7 +1,10 @@ load("//tensorflow/core/platform:build_config.bzl", "pyx_library") load("//tensorflow/compiler/xla:xla.bzl", "xla_py_test_deps") -load("//tensorflow:tensorflow.bzl", "pybind_extension") load("//tensorflow:tensorflow.bzl", "py_test", "tf_cc_test") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "pybind_extension") package( default_visibility = ["//tensorflow:internal"], @@ -59,6 +62,7 @@ cc_library( features = ["-use_header_modules"], deps = [ ":bfloat16", + ":local_client", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", @@ -165,6 +169,7 @@ cc_library( name = "local_client", srcs = ["local_client.cc"], hdrs = ["local_client.h"], + visibility = ["//tensorflow/compiler/xla:friends"], deps = [ ":local_device_state", ":shared_device_buffer", @@ -181,6 +186,7 @@ cc_library( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options", "//tensorflow/core:allocator", "//tensorflow/core:lib", "//tensorflow/core/profiler/lib:traceme", @@ -292,15 +298,19 @@ cc_library( name = "nvidia_gpu_device", srcs = ["nvidia_gpu_device.cc"], hdrs = ["nvidia_gpu_device.h"], + copts = if_cuda(["-DNCCL_ENABLED=1"]), deps = [ ":local_client", + "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/python/distributed:client", "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla:util", "//tensorflow/core:bfc_allocator", "//tensorflow/core:gpu_mem_allocator", "//tensorflow/stream_executor:tf_allocator_adapter", - ], + ] + if_cuda(["@local_config_nccl//:nccl"]), ) config_setting( @@ -355,6 +365,9 @@ pybind_extension( "//tensorflow/compiler/xla/client/lib:self_adjoint_eig", "//tensorflow/compiler/xla/client/lib:sorting", "//tensorflow/compiler/xla/client/lib:svd", + "//tensorflow/compiler/xla/python/distributed", + "//tensorflow/compiler/xla/python/distributed:client", + "//tensorflow/compiler/xla/python/distributed:service", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:custom_call_target_registry", "//tensorflow/compiler/xla/service:hlo", @@ -383,3 +396,25 @@ pybind_extension( "//conditions:default": [], }), ) + +tf_cc_test( + name = "gpu_multistream_test", + srcs = ["gpu_multistream_test.cc"], + tags = [ + # TODO(phawkins): figure out TF test infra such that this only runs under GPU. 
+ "no_oss", + "requires-gpu-nvidia", + ], + deps = [ + ":local_client", + ":nvidia_gpu_device", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/client:executable_build_options", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test_main", + "//tensorflow/core/platform:random", + ], +) diff --git a/tensorflow/compiler/xla/python/cpu_device.cc b/tensorflow/compiler/xla/python/cpu_device.cc index 6bb17f12b89..6b55eac0c08 100644 --- a/tensorflow/compiler/xla/python/cpu_device.cc +++ b/tensorflow/compiler/xla/python/cpu_device.cc @@ -37,21 +37,21 @@ StatusOr> GetCpuClient(bool asynchronous) { TF_ASSIGN_OR_RETURN(LocalClient * client, ClientLibrary::GetOrCreateLocalClient(options)); - std::vector> devices; + std::vector> devices; for (int i = 0; i < client->device_count(); ++i) { se::StreamExecutor* executor = client->backend().stream_executor(i).ValueOrDie(); auto device_state = absl::make_unique( executor, client, /*synchronous_deallocation=*/true, asynchronous, /*allow_event_reuse=*/false); - std::shared_ptr device = - std::make_shared(i, std::move(device_state)); + auto device = absl::make_unique(i, std::move(device_state)); devices.push_back(std::move(device)); } return std::make_shared( kCpuPlatformName, client, std::move(devices), /*host_id=*/0, - /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr); + /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr, + /*gpu_run_options=*/nullptr); } } // namespace xla diff --git a/tensorflow/compiler/xla/python/distributed/BUILD b/tensorflow/compiler/xla/python/distributed/BUILD new file mode 100644 index 00000000000..b38084c3395 --- /dev/null +++ b/tensorflow/compiler/xla/python/distributed/BUILD @@ -0,0 +1,122 @@ +load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library_cc") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +licenses(["notice"]) + +package(default_visibility = ["//tensorflow:internal"]) + +tf_proto_library_cc( + name = "protocol_proto", + srcs = ["protocol.proto"], + has_services = 1, + cc_api_version = 2, + cc_grpc_version = 1, + use_grpc_namespace = True, +) + +cc_library( + name = "protocol", + hdrs = ["protocol.h"], +) + +cc_library( + name = "key_value_store", + srcs = ["key_value_store.cc"], + hdrs = ["key_value_store.h"], + deps = [ + "//tensorflow:grpc++", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + ], +) + +cc_library( + name = "service", + srcs = ["service.cc"], + hdrs = ["service.h"], + deps = [ + ":key_value_store", + ":protocol", + ":protocol_proto_cc", + ":util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "@com_google_absl//absl/time", + ], +) + +tf_cc_test( + name = "service_test", + srcs = ["service_test.cc"], + deps = [ + ":protocol_proto_cc", + ":service", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "client", + srcs = [ + "client.cc", + ], + hdrs = [ + "client.h", + ], + deps = [ + ":protocol", + ":protocol_proto_cc", + ":util", + "//tensorflow:grpc++", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/core:lib", + "@com_google_absl//absl/synchronization", + 
"@com_google_absl//absl/time", + ], +) + +cc_library( + name = "util", + hdrs = ["util.h"], + deps = [ + "//tensorflow:grpc++", + "//tensorflow/compiler/xla:status", + ], +) + +cc_library( + name = "distributed", + srcs = ["distributed.cc"], + hdrs = ["distributed.h"], + deps = [ + ":client", + ":service", + "//tensorflow:grpc++", + "//tensorflow/compiler/xla:statusor", + ], +) + +tf_cc_test( + name = "client_server_test", + srcs = ["client_server_test.cc"], + deps = [ + ":client", + ":protocol_proto_cc", + ":service", + "//tensorflow:grpc++", + "//tensorflow/compiler/xla:protobuf_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@com_google_absl//absl/time", + ], +) diff --git a/tensorflow/compiler/xla/python/distributed/client.cc b/tensorflow/compiler/xla/python/distributed/client.cc new file mode 100644 index 00000000000..c50c3f50a9d --- /dev/null +++ b/tensorflow/compiler/xla/python/distributed/client.cc @@ -0,0 +1,82 @@ +/* Copyright 2020 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/python/distributed/client.h" + +#include // NOLINT + +#include "tensorflow/compiler/xla/python/distributed/protocol.h" +#include "tensorflow/compiler/xla/python/distributed/util.h" + +namespace xla { + +DistributedRuntimeClient::DistributedRuntimeClient( + std::shared_ptr<::grpc::Channel> channel) + : stub_(grpc::DistributedRuntimeService::NewStub(std::move(channel))) {} +DistributedRuntimeClient::~DistributedRuntimeClient() = default; + +xla::Status DistributedRuntimeClient::Connect( + const LocalTopologyProto& local_topology, + GlobalTopologyProto* global_topology) { + ::grpc::ClientContext ctx; + ctx.set_fail_fast(false); + ctx.set_deadline(absl::ToChronoTime(absl::Now() + rpc_timeout_)); + ConnectRequest request; + request.set_protocol_version(kDistributedRuntimeProtocolVersion); + *request.mutable_local_topology() = local_topology; + VLOG(10) << "Connect: " << request.DebugString(); + ConnectResponse response; + ::grpc::Status status = stub_->Connect(&ctx, request, &response); + if (!status.ok()) { + return FromGrpcStatus(status); + } + VLOG(10) << "Connect() response: " << response.DebugString(); + response.mutable_global_topology()->Swap(global_topology); + return xla::Status::OK(); +} + +xla::StatusOr DistributedRuntimeClient::BlockingKeyValueGet( + std::string key, absl::Duration timeout) { + ::grpc::ClientContext ctx; + ctx.set_fail_fast(false); + ctx.set_deadline(absl::ToChronoTime(absl::Now() + timeout)); + KeyValueGetRequest request; + request.set_key(std::move(key)); + timeout = std::min(timeout, absl::Minutes(10)); // Avoid overflow + request.set_timeout_milliseconds(timeout / absl::Milliseconds(1)); + VLOG(10) << "BlockingKeyValueGet: " << request.DebugString(); + KeyValueGetResponse response; 
+ ::grpc::Status status = stub_->KeyValueGet(&ctx, request, &response); + if (!status.ok()) { + return FromGrpcStatus(status); + } + return response.value(); +} + +xla::Status DistributedRuntimeClient::KeyValueSet(std::string key, + std::string value) { + ::grpc::ClientContext ctx; + ctx.set_fail_fast(false); + ctx.set_deadline(absl::ToChronoTime(absl::Now() + rpc_timeout_)); + KeyValueSetRequest request; + request.set_key(std::move(key)); + request.set_value(std::move(value)); + VLOG(10) << "KeyValueSet: " << request.DebugString(); + KeyValueSetResponse response; + ::grpc::Status status = stub_->KeyValueSet(&ctx, request, &response); + return FromGrpcStatus(status); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/python/distributed/client.h b/tensorflow/compiler/xla/python/distributed/client.h new file mode 100644 index 00000000000..1ab5292bea8 --- /dev/null +++ b/tensorflow/compiler/xla/python/distributed/client.h @@ -0,0 +1,50 @@ +/* Copyright 2020 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_CLIENT_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_CLIENT_H_ + +#include + +#include "grpcpp/channel.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" +#include "tensorflow/compiler/xla/python/distributed/protocol.grpc.pb.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/platform/env.h" + +namespace xla { + +class DistributedRuntimeClient { + public: + explicit DistributedRuntimeClient(std::shared_ptr<::grpc::Channel> channel); + ~DistributedRuntimeClient(); + + xla::Status Connect(const LocalTopologyProto& local_topology, + GlobalTopologyProto* global_topology); + + xla::StatusOr BlockingKeyValueGet(std::string key, + absl::Duration timeout); + + xla::Status KeyValueSet(std::string key, std::string value); + + private: + const std::unique_ptr stub_; + const absl::Duration rpc_timeout_ = absl::Seconds(120); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_CLIENT_H_ diff --git a/tensorflow/compiler/xla/python/distributed/client_server_test.cc b/tensorflow/compiler/xla/python/distributed/client_server_test.cc new file mode 100644 index 00000000000..e78949933a2 --- /dev/null +++ b/tensorflow/compiler/xla/python/distributed/client_server_test.cc @@ -0,0 +1,102 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "grpcpp/security/server_credentials.h" +#include "absl/time/time.h" +#include "tensorflow/compiler/xla/protobuf_util.h" +#include "tensorflow/compiler/xla/python/distributed/client.h" +#include "tensorflow/compiler/xla/python/distributed/protocol.pb.h" +#include "tensorflow/compiler/xla/python/distributed/service.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/threadpool.h" + +namespace xla { +namespace { + +TEST(ClientServerTest, ConnectToServer) { + DistributedRuntimeServiceImpl service(/*num_nodes=*/2); + ::grpc::ServerBuilder builder; + builder.RegisterService(&service); + auto server = builder.BuildAndStart(); + + std::vector locals(2); + locals[0].set_node_id(0); + locals[1].set_node_id(1); + DeviceProto* d0 = locals[0].add_devices(); + d0->set_local_device_ordinal(0); + DeviceProto* d1 = locals[0].add_devices(); + d1->set_local_device_ordinal(0); + DeviceProto* d2 = locals[0].add_devices(); + d2->set_local_device_ordinal(707); + DeviceProto* d3 = locals[1].add_devices(); + d3->set_local_device_ordinal(1); + + GlobalTopologyProto expected_topology; + auto* node0 = expected_topology.add_nodes(); + auto* node1 = expected_topology.add_nodes(); + *node0 = locals[0]; + node0->mutable_devices(0)->set_global_device_id(0); + node0->mutable_devices(1)->set_global_device_id(1); + node0->mutable_devices(2)->set_global_device_id(2); + *node1 = locals[1]; + node1->mutable_devices(0)->set_global_device_id(3); + + auto thread0_fn = [&]() -> xla::Status { + DistributedRuntimeClient client( + server->InProcessChannel(::grpc::ChannelArguments())); + GlobalTopologyProto topology; + TF_RETURN_IF_ERROR(client.Connect(locals[0], &topology)); + TF_RET_CHECK( + xla::protobuf_util::ProtobufEquals(topology, expected_topology)); + TF_RETURN_IF_ERROR(client.KeyValueSet("key1", "value1")); + TF_ASSIGN_OR_RETURN( + std::string value, + client.BlockingKeyValueGet("key2", absl::InfiniteDuration())); + TF_RET_CHECK(value == "value2"); + return xla::Status::OK(); + }; + auto thread1_fn = [&]() -> xla::Status { + DistributedRuntimeClient client( + server->InProcessChannel(::grpc::ChannelArguments())); + GlobalTopologyProto topology; + TF_RETURN_IF_ERROR(client.Connect(locals[1], &topology)); + TF_RET_CHECK( + xla::protobuf_util::ProtobufEquals(topology, expected_topology)); + TF_ASSIGN_OR_RETURN( + std::string value, + client.BlockingKeyValueGet("key1", absl::InfiniteDuration())); + TF_RET_CHECK(value == "value1"); + TF_RETURN_IF_ERROR(client.KeyValueSet("key2", "value2")); + return xla::Status::OK(); + }; + + std::vector> functions = {thread0_fn, + thread1_fn}; + std::vector statuses(functions.size()); + { + tensorflow::thread::ThreadPool thread_pool( + tensorflow::Env::Default(), "test_threads", functions.size()); + for (int i = 0; i < functions.size(); ++i) { + thread_pool.Schedule([&, i]() { statuses[i] = functions[i](); }); + } + } + TF_EXPECT_OK(statuses[0]); + TF_EXPECT_OK(statuses[1]); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/python/distributed/distributed.cc b/tensorflow/compiler/xla/python/distributed/distributed.cc new file mode 100644 index 00000000000..6afc7b1c4e9 --- /dev/null +++ b/tensorflow/compiler/xla/python/distributed/distributed.cc @@ -0,0 +1,37 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/python/distributed/distributed.h" + +#include "grpcpp/grpcpp.h" + +namespace xla { + +StatusOr> +GetDistributedRuntimeService(std::string address, int num_nodes) { + auto credentials = ::grpc::InsecureServerCredentials(); + return DistributedRuntimeService::Get(address, credentials, num_nodes); +} + +std::shared_ptr GetDistributedRuntimeClient( + std::string address) { + std::shared_ptr<::grpc::ChannelCredentials> creds = + ::grpc::InsecureChannelCredentials(); + std::shared_ptr<::grpc::Channel> channel = + ::grpc::CreateChannel(address, creds); + return absl::make_unique(channel); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/python/distributed/distributed.h b/tensorflow/compiler/xla/python/distributed/distributed.h new file mode 100644 index 00000000000..0475c3e9feb --- /dev/null +++ b/tensorflow/compiler/xla/python/distributed/distributed.h @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_DISTRIBUTED_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_DISTRIBUTED_H_ + +#include +#include + +#include "tensorflow/compiler/xla/python/distributed/client.h" +#include "tensorflow/compiler/xla/python/distributed/service.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// APIs for starting the distributed runtime service and client. Note that these +// variants use insecure credentials; the functions to build the service and +// client are kept separate so that other implementations using more secure +// credentials may be provided by the user. + +// Builds a distributed runtime service. `address` is the address on which +// the service should listen, e.g., [::]:1234 . `num_nodes` is the number +// of nodes in the cluster. +StatusOr> +GetDistributedRuntimeService(std::string address, int num_nodes); + +// Builds a distributed runtime client, connecting to a service at `address`, +// where address is a gRPC-style address such as `dns:///localhost:1234`. 
+std::shared_ptr GetDistributedRuntimeClient( + std::string address); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_DISTRIBUTED_H_ diff --git a/tensorflow/compiler/xla/python/distributed/key_value_store.cc b/tensorflow/compiler/xla/python/distributed/key_value_store.cc new file mode 100644 index 00000000000..5966d4ce12b --- /dev/null +++ b/tensorflow/compiler/xla/python/distributed/key_value_store.cc @@ -0,0 +1,44 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/python/distributed/key_value_store.h" + +namespace xla { + +KeyValueStore::KeyValueStore() = default; + +::grpc::Status KeyValueStore::Get(const std::string& key, + absl::Duration timeout, std::string* value) { + auto key_is_present = [&]() { + mu_.AssertHeld(); + return entries_.find(key) != entries_.end(); + }; + absl::MutexLock lock(&mu_); + // TODO(phawkins): the synchronization here is very coarse, but probably + // sufficient for its current application. + if (!mu_.AwaitWithTimeout(absl::Condition(&key_is_present), timeout)) { + return ::grpc::Status(::grpc::StatusCode::NOT_FOUND, ""); + } + *value = entries_.find(key)->second; + return ::grpc::Status::OK; +} + +::grpc::Status KeyValueStore::Set(const std::string& key, std::string value) { + absl::MutexLock lock(&mu_); + entries_[key] = std::move(value); + return ::grpc::Status::OK; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/python/distributed/key_value_store.h b/tensorflow/compiler/xla/python/distributed/key_value_store.h new file mode 100644 index 00000000000..8560305e6f6 --- /dev/null +++ b/tensorflow/compiler/xla/python/distributed/key_value_store.h @@ -0,0 +1,53 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_KEY_VALUE_STORE_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_KEY_VALUE_STORE_H_ + +#include "grpcpp/grpcpp.h" +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" + +namespace xla { + +// A simple blocking key-value store class. 
+class KeyValueStore { + public: + KeyValueStore(); + + KeyValueStore(const KeyValueStore&) = delete; + KeyValueStore(KeyValueStore&&) = delete; + KeyValueStore& operator=(const KeyValueStore&) = delete; + KeyValueStore&& operator=(KeyValueStore&&) = delete; + + // Looks up `key`. If present, returns its value. If the key is not present, + // waits until `timeout` expires for the key to arrive. If the key does not + // arrive by the expiry of `timeout`, returns NOT_FOUND. + ::grpc::Status Get(const std::string& key, absl::Duration timeout, + std::string* value); + + // Replaces the value of `key` with `value`. + ::grpc::Status Set(const std::string& key, std::string value); + + private: + absl::Mutex mu_; + absl::flat_hash_map entries_ ABSL_GUARDED_BY(mu_); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_KEY_VALUE_STORE_H_ diff --git a/tensorflow/lite/python/optimize/sparsification_wrapper.i b/tensorflow/compiler/xla/python/distributed/protocol.h similarity index 69% rename from tensorflow/lite/python/optimize/sparsification_wrapper.i rename to tensorflow/compiler/xla/python/distributed/protocol.h index d7db2854bc2..208c6dab8c5 100644 --- a/tensorflow/lite/python/optimize/sparsification_wrapper.i +++ b/tensorflow/compiler/xla/python/distributed/protocol.h @@ -13,14 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -%include "std_string.i" +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_PROTOCOL_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_PROTOCOL_H_ +namespace xla { -%{ -#define SWIG_FILE_WITH_INIT -#include "tensorflow/lite/model.h" -#include "tensorflow/lite/python/optimize/sparsification_wrapper.h" -%} +static constexpr int kDistributedRuntimeProtocolVersion = 1; +} // namespace xla -%include "tensorflow/lite/python/optimize/sparsification_wrapper.h" +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_PROTOCOL_H_ diff --git a/tensorflow/compiler/xla/python/distributed/protocol.proto b/tensorflow/compiler/xla/python/distributed/protocol.proto new file mode 100644 index 00000000000..18bfa221110 --- /dev/null +++ b/tensorflow/compiler/xla/python/distributed/protocol.proto @@ -0,0 +1,103 @@ +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================== +// +// Distributed XLA service protocol. +// +// This is a minimal distributed protocol intended primarily for sharing NCCL +// communicator state between distributed hosts. +// +// The intention is to replace this with a more capable distributed runtime at +// some point in the near future, but this suffices for simple multihost GPU +// use cases. +// +// The intention is that a service is started during cluster initialization and +// persists for the lifetime of the cluster. 
+// +// TODO(phawkins): add a health-checking mechanism. + +syntax = "proto3"; + +package xla; + +// Describes a device local to a host. +message DeviceProto { + int32 local_device_ordinal = 1; + string name = 2; + string vendor = 3; + + // The following fields are present in the GlobalTopologyProto message + // returned by Connect() but not in the LocalTopologyProto messages passed to + // Connect(). In other words, the master node determines the global device IDs + // during Connect(). + int32 global_device_id = 4; // Globally unique ID number. +} + +// Describes the set of devices local to a host. +message LocalTopologyProto { + // We assume that each node knows its globally-unique node ID, provided by + // whatever mechanism launches the tasks. Node IDs should form a dense range + // of integers [0, num_nodes). + int32 node_id = 1; + repeated DeviceProto devices = 2; +} + +message GlobalTopologyProto { + repeated LocalTopologyProto nodes = 1; +} + +message ConnectRequest { + int32 protocol_version = 1; // Always 1 at present. + + LocalTopologyProto local_topology = 2; +} + +message ConnectResponse { + GlobalTopologyProto global_topology = 2; +} + +message KeyValueGetRequest { + bytes key = 1; + int32 timeout_milliseconds = 2; +} + +message KeyValueGetResponse { + bool found = 1; + bytes value = 2; +} + +message KeyValueSetRequest { + bytes key = 1; + bytes value = 2; +} + +message KeyValueSetResponse {} + +service DistributedRuntimeService { + // Connects a node to the distributed master node. Blocks until all workers + // have connected. The service receives the number of nodes to expect as an + // option passed to its constructor. + rpc Connect(ConnectRequest) returns (ConnectResponse) {} + + // Simple key-value store used for sharing configuration data. + // For example, when using NCCL to communicate between multiple GPUs, + // the NCCL communicator IDs are stored here. + + // Looks up a key in the key-value service. Blocks until the key is present + // or until `timeout` expires. + rpc KeyValueGet(KeyValueGetRequest) returns (KeyValueGetResponse) {} + + // Updates the value associated with a key. + rpc KeyValueSet(KeyValueSetRequest) returns (KeyValueSetResponse) {} +} diff --git a/tensorflow/compiler/xla/python/distributed/service.cc b/tensorflow/compiler/xla/python/distributed/service.cc new file mode 100644 index 00000000000..cc2b3a5aca2 --- /dev/null +++ b/tensorflow/compiler/xla/python/distributed/service.cc @@ -0,0 +1,154 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/python/distributed/service.h" + +#include "tensorflow/compiler/xla/python/distributed/protocol.h" +#include "tensorflow/compiler/xla/python/distributed/util.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { + +DistributedRuntimeServiceImpl::DistributedRuntimeServiceImpl(int num_nodes) { + nodes_.resize(num_nodes); + local_topologies_.resize(num_nodes); +} + +// Steals the contents of `local_topologies`. +void BuildGlobalTopology(absl::Span local_topologies, + GlobalTopologyProto* global_topology) { + int next_global_device_id = 0; + for (LocalTopologyProto& local : local_topologies) { + for (DeviceProto& device : *local.mutable_devices()) { + device.set_global_device_id(next_global_device_id++); + } + global_topology->add_nodes()->Swap(&local); + } +} + +::grpc::Status DistributedRuntimeServiceImpl::Connect( + ::grpc::ServerContext* context, const ConnectRequest* request, + ConnectResponse* response) { + VLOG(10) << "Connect " << request->DebugString(); + if (request->protocol_version() != kDistributedRuntimeProtocolVersion) { + return ToGrpcStatus(xla::InvalidArgument("Invalid protocol version %d", + request->protocol_version())); + } + absl::MutexLock lock(&mu_); + if (state_ != State::kInitializing) { + return ToGrpcStatus(xla::FailedPrecondition( + "Connect() called when system is not initializing.")); + } + int node_id = request->local_topology().node_id(); + if (node_id < 0 || node_id >= nodes_.size()) { + return ToGrpcStatus( + xla::InvalidArgument("Invalid node ID %d, must be in the range [0, %d)", + node_id, nodes_.size())); + } + if (nodes_[node_id].present) { + return ToGrpcStatus(xla::InvalidArgument("Duplicate node ID %d", node_id)); + } + nodes_[node_id].present = true; + local_topologies_[node_id] = request->local_topology(); + ++num_nodes_present_; + + auto all_nodes_present = [&]() { + mu_.AssertHeld(); + return num_nodes_present_ == nodes_.size(); + }; + if (!mu_.AwaitWithTimeout(absl::Condition(&all_nodes_present), + kConnectTimeout)) { + return ToGrpcStatus(tensorflow::errors::DeadlineExceeded( + "Timed out after %s waiting for all nodes to call Connect()", + absl::FormatDuration(kConnectTimeout))); + } + + if (node_id == 0) { + BuildGlobalTopology(absl::Span(local_topologies_), + &topology_); + local_topologies_.clear(); + state_ = State::kRunning; + } else { + auto running = [&]() { + mu_.AssertHeld(); + return state_ == State::kRunning; + }; + mu_.Await(absl::Condition(&running)); + } + *response->mutable_global_topology() = topology_; + return ::grpc::Status::OK; +} + +::grpc::Status DistributedRuntimeServiceImpl::KeyValueGet( + ::grpc::ServerContext* context, const KeyValueGetRequest* request, + KeyValueGetResponse* response) { + VLOG(10) << "KeyValueGet " << request->DebugString(); + { + absl::MutexLock lock(&mu_); + if (state_ != State::kRunning) { + return ToGrpcStatus(xla::FailedPrecondition( + "KeyValueGet() called when system is not running.")); + } + } + return key_value_store_.Get( + request->key(), absl::Milliseconds(request->timeout_milliseconds()), + response->mutable_value()); +} + +::grpc::Status DistributedRuntimeServiceImpl::KeyValueSet( + ::grpc::ServerContext* context, const KeyValueSetRequest* request, + KeyValueSetResponse* response) { + VLOG(10) << "KeyValueSet " << request->DebugString(); + { + absl::MutexLock lock(&mu_); + if (state_ != State::kRunning) { + return 
ToGrpcStatus(xla::FailedPrecondition( + "KeyValueSet() called when system is not running; clients must call " + "Connect() first")); + } + } + return key_value_store_.Set(request->key(), request->value()); +} + +xla::StatusOr> +DistributedRuntimeService::Get( + const std::string& address, + std::shared_ptr<::grpc::ServerCredentials> credentials, int num_nodes) { + auto service = absl::make_unique(num_nodes); + ::grpc::ServerBuilder builder; + builder.AddListeningPort(address, credentials); + VLOG(1) << "Distributed runtime service address " << address; + builder.RegisterService(&service->impl_); + service->server_ = builder.BuildAndStart(); + if (!service->server_) { + return xla::Unknown("Failed to start RPC server"); + } + LOG(INFO) << "Jax service listening on " << address; + return service; +} + +DistributedRuntimeService::DistributedRuntimeService(int num_nodes) + : impl_(num_nodes) {} + +DistributedRuntimeService::~DistributedRuntimeService() { + if (server_) { + LOG(INFO) << "Jax service shutting down"; + server_->Shutdown(); + server_->Wait(); + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/python/distributed/service.h b/tensorflow/compiler/xla/python/distributed/service.h new file mode 100644 index 00000000000..baf470e4f13 --- /dev/null +++ b/tensorflow/compiler/xla/python/distributed/service.h @@ -0,0 +1,101 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_SERVICE_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_SERVICE_H_ + +#include "absl/time/time.h" +#include "tensorflow/compiler/xla/python/distributed/key_value_store.h" +#include "tensorflow/compiler/xla/python/distributed/protocol.grpc.pb.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +typedef int NodeId; + +class DistributedRuntimeServiceImpl final + : public grpc::DistributedRuntimeService::Service { + public: + explicit DistributedRuntimeServiceImpl(int num_nodes); + + DistributedRuntimeServiceImpl(const DistributedRuntimeServiceImpl&) = delete; + DistributedRuntimeServiceImpl(DistributedRuntimeServiceImpl&&) = delete; + DistributedRuntimeServiceImpl& operator=( + const DistributedRuntimeServiceImpl&) = delete; + DistributedRuntimeServiceImpl&& operator=(DistributedRuntimeServiceImpl&&) = + delete; + + ::grpc::Status Connect(::grpc::ServerContext* context, + const ConnectRequest* request, + ConnectResponse* response) override; + + ::grpc::Status KeyValueGet(::grpc::ServerContext* context, + const KeyValueGetRequest* request, + KeyValueGetResponse* response) override; + + ::grpc::Status KeyValueSet(::grpc::ServerContext* context, + const KeyValueSetRequest* request, + KeyValueSetResponse* response) override; + + private: + const absl::Duration kConnectTimeout = absl::Seconds(120); + + absl::Mutex mu_; + enum class State { kInitializing, kRunning }; + State state_ GUARDED_BY(mu_) = State::kInitializing; + + std::vector local_topologies_ GUARDED_BY(mu_); + GlobalTopologyProto topology_ GUARDED_BY(mu_); + struct Node { + bool present = false; + }; + int num_nodes_present_ GUARDED_BY(mu_) = 0; + std::vector nodes_ GUARDED_BY(mu_); + + KeyValueStore key_value_store_; +}; + +class DistributedRuntimeService { + public: + static xla::StatusOr> Get( + const std::string& address, + std::shared_ptr<::grpc::ServerCredentials> credentials, int num_nodes); + + explicit DistributedRuntimeService(int num_nodes); + ~DistributedRuntimeService(); + + DistributedRuntimeService(const DistributedRuntimeService&) = delete; + DistributedRuntimeService(DistributedRuntimeService&&) = delete; + DistributedRuntimeService& operator=(const DistributedRuntimeService&) = + delete; + DistributedRuntimeService& operator=(DistributedRuntimeService&&) = delete; + + ::grpc::Server* server() const { return server_.get(); } + + private: + DistributedRuntimeServiceImpl impl_; + std::unique_ptr<::grpc::Server> server_; +}; + +// Everything below this point is exposed only for tests. + +// Given a LocalTopologyProto object from each node, builds a +// GlobalTopologyProto that describes all nodes. +void BuildGlobalTopology(absl::Span local_topologies, + GlobalTopologyProto* global_topology); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_SERVICE_H_ diff --git a/tensorflow/compiler/xla/python/distributed/service_test.cc b/tensorflow/compiler/xla/python/distributed/service_test.cc new file mode 100644 index 00000000000..08326df2f38 --- /dev/null +++ b/tensorflow/compiler/xla/python/distributed/service_test.cc @@ -0,0 +1,44 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/python/distributed/service.h" + +#include "tensorflow/compiler/xla/python/distributed/protocol.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +TEST(TopologyTest, BuildGlobalTopology) { + std::vector locals(2); + DeviceProto* d0 = locals[0].add_devices(); + d0->set_local_device_ordinal(0); + DeviceProto* d1 = locals[0].add_devices(); + d1->set_local_device_ordinal(1); + DeviceProto* d2 = locals[1].add_devices(); + d2->set_local_device_ordinal(0); + DeviceProto* d3 = locals[1].add_devices(); + d3->set_local_device_ordinal(1); + + GlobalTopologyProto global; + BuildGlobalTopology(absl::Span(locals), &global); + EXPECT_EQ(global.nodes_size(), 2); + EXPECT_EQ(global.nodes()[0].devices_size(), 2); + EXPECT_EQ(global.nodes()[1].devices_size(), 2); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/python/distributed/util.h b/tensorflow/compiler/xla/python/distributed/util.h new file mode 100644 index 00000000000..07ae8d1f0ce --- /dev/null +++ b/tensorflow/compiler/xla/python/distributed/util.h @@ -0,0 +1,44 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_UTIL_H_ + +#include "grpcpp/support/status.h" +#include "tensorflow/compiler/xla/status.h" + +namespace xla { + +inline Status FromGrpcStatus(const ::grpc::Status& s) { + if (s.ok()) { + return Status::OK(); + } else { + return Status(static_cast(s.error_code()), + s.error_message()); + } +} + +inline ::grpc::Status ToGrpcStatus(const Status& s) { + if (s.ok()) { + return ::grpc::Status::OK; + } else { + return ::grpc::Status(static_cast<::grpc::StatusCode>(s.code()), + s.error_message()); + } +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_UTIL_H_ diff --git a/tensorflow/compiler/xla/python/dlpack.cc b/tensorflow/compiler/xla/python/dlpack.cc index b4ae503ba4c..ca34fb504bd 100644 --- a/tensorflow/compiler/xla/python/dlpack.cc +++ b/tensorflow/compiler/xla/python/dlpack.cc @@ -210,8 +210,8 @@ StatusOr DLContextForDevice(const Device& device) { return context; } -StatusOr> DeviceForDLContext( - const PyLocalClient& client, const DLContext& context) { +StatusOr DeviceForDLContext(const PyLocalClient& client, + const DLContext& context) { se::Platform::Id platform_id; switch (context.device_type) { case kDLCPU: @@ -224,13 +224,11 @@ StatusOr> DeviceForDLContext( return InvalidArgument("Unknown/unsupported DLPack device type %d", context.device_type); } - auto it = absl::c_find_if( - client.local_devices(), [&](const std::shared_ptr& device) { - return device->local_device_state()->executor()->platform()->id() == - platform_id && - device->local_device_state()->device_ordinal() == - context.device_id; - }); + auto it = absl::c_find_if(client.local_devices(), [&](Device* device) { + return device->local_device_state()->executor()->platform()->id() == + platform_id && + device->local_device_state()->device_ordinal() == context.device_id; + }); if (it == client.local_devices().end()) { return InvalidArgument( "No matching device found for DLPack device_type %d device_id %d", @@ -289,7 +287,7 @@ StatusOr BufferToDLPackManagedTensor(PyLocalBuffer* buffer) { } StatusOr> DLPackManagedTensorToBuffer( - const pybind11::capsule& tensor, std::shared_ptr client) { + const pybind11::capsule& tensor, PyLocalClient* client) { if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) { return InvalidArgument( "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". " @@ -302,7 +300,7 @@ StatusOr> DLPackManagedTensorToBuffer( "Number of dimensions in DLManagedTensor must be nonnegative, got %d", dlmt->dl_tensor.ndim); } - TF_ASSIGN_OR_RETURN(std::shared_ptr device, + TF_ASSIGN_OR_RETURN(Device * device, DeviceForDLContext(*client, dlmt->dl_tensor.ctx)); absl::Span dimensions( reinterpret_cast(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim); @@ -329,19 +327,19 @@ StatusOr> DLPackManagedTensorToBuffer( if (dlmt->deleter) { on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); }; } + absl::Span> definition_events; auto device_buffer = std::make_shared( /*allocator=*/nullptr, dlmt->dl_tensor.ctx.device_id, std::initializer_list{buffer}, /*children=*/std::vector>{}, - /*definition_event=*/nullptr, std::move(on_delete_callback)); + definition_events, std::move(on_delete_callback)); // We have taken ownership of the array inside the capsule; make sure the // capsule it cannot be used again. 
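  // Following the DLPack convention, the consumed capsule is renamed to
  // "used_dltensor" so that the name check above rejects a second import, and
  // its destructor is cleared because ownership of the tensor has moved to the
  // returned buffer (which releases it via on_delete_callback).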
PyCapsule_SetName(tensor.ptr(), "used_dltensor"); PyCapsule_SetDestructor(tensor.ptr(), nullptr); - return absl::make_unique(shape, shape, - std::move(device_buffer), - std::move(client), std::move(device)); + return absl::make_unique( + shape, shape, std::move(device_buffer), client, device); } } // namespace xla diff --git a/tensorflow/compiler/xla/python/dlpack.h b/tensorflow/compiler/xla/python/dlpack.h index 92eba687225..09700841ab4 100644 --- a/tensorflow/compiler/xla/python/dlpack.h +++ b/tensorflow/compiler/xla/python/dlpack.h @@ -24,7 +24,7 @@ namespace xla { StatusOr BufferToDLPackManagedTensor(PyLocalBuffer* buffer); StatusOr> DLPackManagedTensorToBuffer( - const pybind11::capsule& tensor, std::shared_ptr client); + const pybind11::capsule& tensor, PyLocalClient* client); } // namespace xla diff --git a/tensorflow/compiler/xla/python/event_pool.h b/tensorflow/compiler/xla/python/event_pool.h index 56787acd87e..f858b5edef8 100644 --- a/tensorflow/compiler/xla/python/event_pool.h +++ b/tensorflow/compiler/xla/python/event_pool.h @@ -68,7 +68,7 @@ class EventPool { const bool allow_reuse_; absl::Mutex mu_; - std::stack> free_events_ GUARDED_BY(mu_); + std::stack> free_events_ TF_GUARDED_BY(mu_); }; } // namespace xla diff --git a/tensorflow/compiler/xla/python/gpu_multistream_test.cc b/tensorflow/compiler/xla/python/gpu_multistream_test.cc new file mode 100644 index 00000000000..a633e4dd020 --- /dev/null +++ b/tensorflow/compiler/xla/python/gpu_multistream_test.cc @@ -0,0 +1,104 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/executable_build_options.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/python/local_client.h" +#include "tensorflow/compiler/xla/python/nvidia_gpu_device.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/core/platform/random.h" + +namespace xla { +namespace { + +// Regression test that verifies that substreams of a multistream GPU +// computation wait for the inputs to be produced before executing. 
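// The test below interleaves a large dummy host-to-device transfer with the
// real input transfers and compiles with xla_gpu_use_random_streams enabled,
// so across many iterations the executable's substreams are likely to be
// enqueued while the input transfers are still in flight; without correct
// waits on the input definition events the negated outputs checked at the
// end would not match.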
+TEST(GpuMultiStream, Basics) { + TF_ASSERT_OK_AND_ASSIGN( + std::shared_ptr client, + GetNvidiaGpuClient(/*asynchronous=*/true, GpuAllocatorConfig(), + /*distributed_client=*/nullptr, /*node_id=*/0)); + + Device* device = client->local_devices().at(0); + + int n = 1024; + Shape shape = ShapeUtil::MakeShape(S32, {n}); + std::vector inputs(n); + std::vector expected_outputs(n); + + XlaBuilder builder("acomputation"); + auto p0 = Parameter(&builder, 0, shape, "param"); + auto p1 = Parameter(&builder, 1, shape, "param"); + Tuple(&builder, {Neg(p0), Neg(p1)}); + TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, builder.Build()); + + CompileOptions compile_options; + compile_options.executable_build_options.mutable_debug_options() + ->set_xla_gpu_disable_multi_streaming(false); + compile_options.executable_build_options.mutable_debug_options() + ->set_xla_gpu_use_random_streams(true); + DeviceAssignment device_assignment(1, 1); + device_assignment(0, 0) = device->id(); + compile_options.executable_build_options.set_device_assignment( + device_assignment); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr executable, + PyLocalExecutable::Compile(computation, client.get(), + std::move(compile_options))); + + int64 dummy_size = 1 << 20; + std::vector dummy_inputs(dummy_size); + Shape dummy_shape = ShapeUtil::MakeShape(S32, {dummy_size}); + + for (int i = 0; i < 100; ++i) { + for (int i = 0; i < n; ++i) { + inputs[i] = tensorflow::random::New64(); + expected_outputs[i] = -inputs[i]; + } + // Transfer a large dummy buffer, behind which the inputs to the computation + // must wait. + TF_ASSERT_OK_AND_ASSIGN( + auto dummy_buffer, + PyLocalBuffer::FromHostBuffer( + dummy_inputs.data(), dummy_shape, /*force_copy=*/false, + /*buffer_reference=*/nullptr, client.get(), device)); + TF_ASSERT_OK_AND_ASSIGN( + auto in_buffer0, + PyLocalBuffer::FromHostBuffer( + inputs.data(), shape, /*force_copy=*/false, + /*buffer_reference=*/nullptr, client.get(), device)); + TF_ASSERT_OK_AND_ASSIGN( + auto in_buffer1, + PyLocalBuffer::FromHostBuffer( + inputs.data(), shape, /*force_copy=*/false, + /*buffer_reference=*/nullptr, client.get(), device)); + // The execution may be enqueued before the transfers complete, requiring + // adequate device-side synchronization. + ExecuteOptions options; + options.untuple_result = true; + TF_ASSERT_OK_AND_ASSIGN( + auto out_buffers, + executable->Execute({in_buffer0.get(), in_buffer1.get()}, options)); + + TF_ASSERT_OK_AND_ASSIGN(auto out_literal, out_buffers[0]->ToLiteral()); + LiteralTestUtil::ExpectR1Equal(expected_outputs, *out_literal); + TF_ASSERT_OK_AND_ASSIGN(out_literal, out_buffers[1]->ToLiteral()); + LiteralTestUtil::ExpectR1Equal(expected_outputs, *out_literal); + } +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/python/local_client.cc b/tensorflow/compiler/xla/python/local_client.cc index a35b20f6aa1..c721f3bea8b 100644 --- a/tensorflow/compiler/xla/python/local_client.cc +++ b/tensorflow/compiler/xla/python/local_client.cc @@ -95,6 +95,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/lib/traceme.h" @@ -111,17 +112,55 @@ std::string Device::DebugString() const { return absl::StrCat(platform_name(), ":", id()); } +StatusOr DevicesToDeviceAssignment( + absl::Span> devices) { + if (devices.empty()) { + return InvalidArgument( + "Device assignment passed to Compile() must be non-empty."); + } + if (devices[0].empty()) { + return InvalidArgument( + "Device assignment passed to Compile() must have a nonzero number of " + "partitions per replica; replica 0 had 0 partitions."); + } + DeviceAssignment xla_assignment(devices.size(), devices[0].size()); + for (int replica = 0; replica < devices.size(); ++replica) { + if (devices[replica].size() != devices[0].size()) { + return InvalidArgument( + "Device assignment passed to Compile() has different numbers of " + "partitions between replicas; %d partitions for replica %d versus %d " + "partitions for replica 0.", + devices[replica].size(), replica, devices[0].size()); + } + for (int partition = 0; partition < devices[replica].size(); ++partition) { + if (devices[0][0]->platform_name() != + devices[replica][partition]->platform_name()) { + return InvalidArgument( + "Device assignment passed to Compile() must have devices of a " + "single kind, got %s for replica 0 partition 0 and %s for replica " + "%d partition %d.", + devices[0][0]->platform_name(), + devices[replica][partition]->platform_name(), replica, partition); + } + xla_assignment(replica, partition) = devices[replica][partition]->id(); + } + } + return xla_assignment; +} + PyLocalClient::PyLocalClient( std::string platform_name, LocalClient* client, - std::vector> devices, int host_id, + std::vector> devices, int host_id, std::unique_ptr allocator, - std::unique_ptr host_memory_allocator) + std::unique_ptr host_memory_allocator, + std::unique_ptr gpu_run_options) : platform_name_(std::move(platform_name)), client_(client), devices_(std::move(devices)), host_id_(host_id), owned_allocator_(std::move(allocator)), host_memory_allocator_(std::move(host_memory_allocator)), + gpu_run_options_(std::move(gpu_run_options)), h2d_transfer_pool_(tensorflow::Env::Default(), "py_xla_h2d_transfer", client->device_count()) { if (owned_allocator_ != nullptr) { @@ -130,8 +169,8 @@ PyLocalClient::PyLocalClient( allocator_ = client_->backend().memory_allocator(); } - for (const std::shared_ptr& device : devices_) { - CHECK(id_to_device_.insert({device->id(), device}).second) + for (const std::unique_ptr& device : devices_) { + CHECK(id_to_device_.insert({device->id(), device.get()}).second) << "Duplicate device id: " << device->id(); if (device->local_device_state()) { @@ -140,7 +179,7 @@ PyLocalClient::PyLocalClient( local_devices_.resize(idx + 1); } CHECK(local_devices_[idx] == nullptr) << idx; - local_devices_[idx] = device; + local_devices_[idx] = device.get(); } } for (int idx = 0; idx < local_devices_.size(); ++idx) { @@ -157,8 +196,8 @@ StatusOr PyLocalClient::GetDefaultDeviceAssignment( /* static */ StatusOr> PyLocalBuffer::FromHostBuffer( const void* data, const Shape& shape, bool force_copy, - std::shared_ptr buffer_reference, - std::shared_ptr client, std::shared_ptr device) { + std::shared_ptr buffer_reference, PyLocalClient* client, + Device* device) { tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromLiterals"); 
VLOG(2) << "PyLocalBuffer::FromLiterals: shape: " << shape.ToString() << " device: " << device->DebugString(); @@ -180,14 +219,14 @@ StatusOr> PyLocalBuffer::FromHostBuffer( }; se::DeviceMemoryBase buffer(const_cast(data), ShapeUtil::ByteSizeOf(shape)); + absl::Span> definition_events; auto device_buffer = std::make_shared( /*allocator=*/nullptr, local_device->device_ordinal(), std::initializer_list{buffer}, /*children=*/std::vector>{}, - /*definition_event=*/nullptr, std::move(on_delete_callback)); + definition_events, std::move(on_delete_callback)); return absl::make_unique( - shape, shape, std::move(device_buffer), std::move(client), - std::move(device)); + shape, shape, std::move(device_buffer), client, device); } TransferManager* transfer_manager = @@ -216,7 +255,7 @@ StatusOr> PyLocalBuffer::FromHostBuffer( std::make_shared(); std::shared_ptr device_buffer = SharedDeviceBuffer::FromScopedShapedBuffer(&scoped_buffer, - definition_event); + {definition_event}); Shape on_device_shape = scoped_buffer.on_device_shape(); auto transfer_h2d = [client, transfer_manager, local_device, device_buffer, @@ -261,7 +300,7 @@ StatusOr> PyLocalBuffer::FromHostBuffer( // Sets the buffer definition event. Note: this has the side effect of // unblocking any host threads that may have been waiting to consume the // buffer. - device_buffer->definition_event()->SetDefinitionEvent( + device_buffer->definition_events()[0]->SetDefinitionEvent( std::move(event), local_device->host_to_device_stream()); if (local_device->synchronous_deallocation()) { @@ -276,12 +315,12 @@ StatusOr> PyLocalBuffer::FromHostBuffer( client->h2d_transfer_pool()->Schedule(transfer_h2d); return absl::make_unique( compact_shape, std::move(on_device_shape), std::move(device_buffer), - std::move(client), std::move(device)); + client, device); } /* static */ StatusOr> PyLocalBuffer::MakeTuple( - const std::vector buffers, - std::shared_ptr client, std::shared_ptr device) { + absl::Span buffers, PyLocalClient* client, + Device* device) { TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, device->GetLocalDeviceState()); std::vector host_shapes; @@ -291,7 +330,7 @@ StatusOr> PyLocalBuffer::FromHostBuffer( device_shapes.reserve(buffers.size()); device_buffers.reserve(buffers.size()); for (const PyLocalBuffer* buffer : buffers) { - if (buffer->device().get() != device.get()) { + if (buffer->device() != device) { return InvalidArgument( "Tuple elements must be on the same device; %s vs %s", buffer->device()->DebugString(), device->DebugString()); @@ -316,7 +355,7 @@ StatusOr> PyLocalBuffer::FromHostBuffer( std::shared_ptr tuple_buffer, SharedDeviceBuffer::MakeTuple( device_buffers, on_host_shape, transfer_manager, allocator, - local_device->device_ordinal(), definition_event)); + local_device->device_ordinal(), {definition_event})); auto buffer = absl::make_unique( std::move(on_host_shape), ShapeUtil::MakeTupleShape(device_shapes), tuple_buffer, std::move(client), std::move(device)); @@ -346,14 +385,84 @@ StatusOr> PyLocalBuffer::FromHostBuffer( return buffer; } +StatusOr>> +MakeCrossHostReceiveBuffersHelper(absl::Span shapes, + PyLocalClient* client, Device* device) { + TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, + device->GetLocalDeviceState()); + TransferManager* transfer_manager = + client->client()->backend().transfer_manager(); + std::vector> buffers; + buffers.reserve(shapes.size()); + se::Stream* host_to_device_stream = local_device->host_to_device_stream(); + for (const auto& shape : shapes) { + TF_ASSIGN_OR_RETURN( + 
ScopedShapedBuffer scoped_buffer, + transfer_manager->AllocateScopedShapedBuffer( + shape, client->allocator(), local_device->device_ordinal())); + + if (!transfer_manager->CanShapedBufferBeAccessedNow( + local_device->compute_stream()->parent(), scoped_buffer)) { + return Unimplemented( + "Cross host receive not enabled unless deallocations are deferred"); + } + + absl::InlinedVector, 2> + definition_events; + + if (scoped_buffer.on_device_shape().IsTuple()) { + TF_CHECK_OK(transfer_manager->WriteTupleIndexTablesAsync( + host_to_device_stream, scoped_buffer)); + definition_events = {std::make_shared(), + std::make_shared()}; + TF_ASSIGN_OR_RETURN(EventPool::Handle event, + local_device->event_pool().ThenAllocateAndRecordEvent( + host_to_device_stream)); + definition_events[1]->SetDefinitionEvent(std::move(event), + host_to_device_stream); + } else { + definition_events = {std::make_shared()}; + } + + std::shared_ptr device_buffer = + SharedDeviceBuffer::FromScopedShapedBuffer(&scoped_buffer, + definition_events); + Shape on_device_shape = scoped_buffer.on_device_shape(); + + auto buffer = absl::make_unique( + shape, std::move(on_device_shape), std::move(device_buffer), client, + device); + + buffers.push_back(std::move(buffer)); + } + return buffers; +} + +/*static*/ void PyLocalBuffer::MakeCrossHostReceiveBuffers( + absl::Span shapes, PyLocalClient* client, Device* device, + PyLocalCrossHostRecvNotifier&& notifier) { + if (shapes.empty()) { + notifier(InvalidArgument( + "shapes parameter empty in MakeCrossHostReceiveBuffers")); + return; + } + auto buffer_or = MakeCrossHostReceiveBuffersHelper(shapes, client, device); + if (!buffer_or.ok()) { + notifier(buffer_or.status()); + return; + } + + client->EnqueueCrossHostReceive(buffer_or.ConsumeValueOrDie(), + std::move(notifier)); +} + PyLocalBuffer::PyLocalBuffer(Shape on_host_shape, Shape on_device_shape, std::shared_ptr device_buffer, - std::shared_ptr client, - std::shared_ptr device) - : client_(std::move(client)), + PyLocalClient* client, Device* device) + : client_(client), on_host_shape_(std::move(on_host_shape)), on_device_shape_(std::move(on_device_shape)), - device_(std::move(device)), + device_(device), device_buffer_(std::move(device_buffer)) {} void PyLocalBuffer::Delete() { @@ -425,7 +534,7 @@ StatusOr PyLocalBuffer::AsShapedBuffer() const { } StatusOr>> -PyLocalBuffer::DestructureTuple() { +PyLocalBuffer::DestructureTuple() const { tensorflow::profiler::TraceMe traceme("PyLocalBuffer::DestructureTuple"); absl::MutexLock lock(&mu_); if (!on_host_shape_.IsTuple()) { @@ -449,13 +558,13 @@ PyLocalBuffer::DestructureTuple() { } StatusOr> PyLocalBuffer::CopyToDevice( - std::shared_ptr dst_device) { + Device* dst_device) { tensorflow::profiler::TraceMe traceme("PyLocalBuffer::CopyToDevice"); std::shared_ptr src_device_buffer = DeviceBuffer(); TF_ASSIGN_OR_RETURN(LocalDeviceState * dst_local_device, dst_device->GetLocalDeviceState()); - if (dst_device.get() == device_.get()) { + if (dst_device == device_) { return absl::make_unique( on_host_shape_, on_device_shape_, src_device_buffer, client_, device_); } @@ -488,9 +597,11 @@ StatusOr> PyLocalBuffer::CopyToDevice( TF_RET_CHECK(input_buffer.size() == output_buffer.size()) << "input: " << input_buffer.size() << " output: " << output_buffer.size(); - TF_RETURN_IF_ERROR(transfer_local_device->ThenMemcpyDeviceToDevice( - transfer_stream, dst_local_device->compute_stream(), input_buffer, - output_buffer)); + if (input_buffer.size() != 0) { + 
TF_RETURN_IF_ERROR(transfer_local_device->ThenMemcpyDeviceToDevice( + transfer_stream, dst_local_device->compute_stream(), input_buffer, + output_buffer)); + } } // We hold on to the `src_device_buffer` until the transfer is finished. @@ -517,12 +628,18 @@ StatusOr> PyLocalBuffer::CopyToDevice( definition_event->SetDefinitionEvent(std::move(event), transfer_stream); std::shared_ptr dst_device_buffer = - SharedDeviceBuffer::FromScopedShapedBuffer(&dst_buffer, definition_event); + SharedDeviceBuffer::FromScopedShapedBuffer(&dst_buffer, + {definition_event}); return absl::make_unique( dst_buffer.on_host_shape(), dst_buffer.on_device_shape(), std::move(dst_device_buffer), client_, dst_device); } +Status PyLocalBuffer::CopyToRemoteDevice( + absl::string_view serialized_descriptor, Device* dst_device) { + return client_->CopyToRemoteDevice(this, serialized_descriptor, dst_device); +} + Status PyLocalBuffer::BlockHostUntilReady() { tensorflow::profiler::TraceMe traceme("PyLocalBuffer::BlockHostUntilReady"); std::shared_ptr device_buffer = DeviceBuffer(); @@ -541,8 +658,7 @@ Status PyLocalBuffer::BlockHostUntilReady() { return stream->BlockHostUntilDone(); } -static std::shared_ptr LookupDevice(const PyLocalClient& client, - int device_id) { +static Device* LookupDevice(const PyLocalClient& client, int device_id) { auto it = client.id_to_device().find(device_id); CHECK(it != client.id_to_device().end()) << "Unknown device id: " << device_id; @@ -551,8 +667,8 @@ static std::shared_ptr LookupDevice(const PyLocalClient& client, PyLocalExecutable::PyLocalExecutable( std::vector> executables, - DeviceAssignment device_assignment, std::shared_ptr client) - : client_(std::move(client)), + DeviceAssignment device_assignment, PyLocalClient* client) + : client_(client), device_assignment_( std::make_shared(device_assignment)) { executables_.reserve(executables.size()); @@ -577,7 +693,7 @@ PyLocalExecutable::PyLocalExecutable( for (int replica = 0; replica < num_replicas; ++replica) { for (int partition = 0; partition < num_partitions; ++partition) { int device_id = (*device_assignment_)(replica, partition); - std::shared_ptr device = LookupDevice(*client_, device_id); + Device* device = LookupDevice(*client_, device_id); if (device->host_id() != client_->host_id()) { VLOG(3) << "Non-local device: " << device_id; continue; @@ -602,14 +718,27 @@ const std::string& PyLocalExecutable::name() const { } } -StatusOr> PyLocalExecutable::ExecuteHelper( +StatusOr>> +PyLocalExecutable::ExecuteHelper( absl::Span argument_handles, int replica, - int partition, const RunId& run_id) { + int partition, const RunId& run_id, const ExecuteOptions& options) const { const int device_id = (*device_assignment_)(replica, partition); - std::shared_ptr device = LookupDevice(*client_, device_id); + Device* device = LookupDevice(*client_, device_id); + + std::unique_ptr tuple_buffer; + std::vector tupled_arguments; + if (options.tuple_arguments) { + TF_ASSIGN_OR_RETURN(tuple_buffer, PyLocalBuffer::MakeTuple( + argument_handles, client_, device)); + tupled_arguments = {tuple_buffer.get()}; + argument_handles = tupled_arguments; + } CHECK_EQ(device->host_id(), client_->host_id()); int device_ordinal = device->local_device_state()->device_ordinal(); - tensorflow::profiler::TraceMe traceme("LocalExecutable::Execute"); + tensorflow::profiler::TraceMe traceme([&] { + return absl::StrCat("LocalExecutable::Execute#run_id=", run_id.ToInt(), + "#"); + }); VLOG(3) << "Replica " << replica << ", partition " << partition << " mapped to device 
ordinal for execution: " << device_ordinal; @@ -628,7 +757,7 @@ StatusOr> PyLocalExecutable::ExecuteHelper( "Deleted buffer passed to Execute() as argument %d to replica %d", i, replica); } - if (handle->device().get() != device.get()) { + if (handle->device() != device) { return InvalidArgument( "Buffer passed to Execute() as argument %d to replica %d is on " "device %s, but replica is assigned to device %s.", @@ -649,15 +778,16 @@ StatusOr> PyLocalExecutable::ExecuteHelper( event->WaitForEventOnStream(device_state->compute_stream()); } - ExecutableRunOptions options; - options.set_stream(device_state->compute_stream()); - options.set_host_to_device_stream(device_state->host_to_device_stream()); - options.set_allocator(client_->allocator()); - options.set_intra_op_thread_pool( + ExecutableRunOptions run_options; + run_options.set_stream(device_state->compute_stream()); + run_options.set_host_to_device_stream(device_state->host_to_device_stream()); + run_options.set_allocator(client_->allocator()); + run_options.set_intra_op_thread_pool( client_->client()->backend().eigen_intra_op_thread_pool_device()); - options.set_device_assignment(device_assignment_.get()); - options.set_run_id(run_id); - options.set_rng_seed(device_state->GetNewPrngSeed()); + run_options.set_device_assignment(device_assignment_.get()); + run_options.set_run_id(run_id); + run_options.set_rng_seed(device_state->GetNewPrngSeed()); + run_options.set_gpu_executable_run_options(client_->gpu_run_options()); // The choice of where we wait is arbitrary; the reason for the wait is pacing // to avoid problems such as memory fragmentation and running ahead too far, @@ -670,7 +800,7 @@ StatusOr> PyLocalExecutable::ExecuteHelper( int executable_idx = executables_.size() > 1 ? partition : 0; StatusOr result_buffer_or_status = - executables_[executable_idx]->RunAsync(argument_buffer_ptrs, options); + executables_[executable_idx]->RunAsync(argument_buffer_ptrs, run_options); VLOG(1) << "Replica " << replica << " partition " << partition << " completed; ok=" << result_buffer_or_status.ok(); @@ -690,7 +820,7 @@ StatusOr> PyLocalExecutable::ExecuteHelper( std::shared_ptr out_buffer = SharedDeviceBuffer::FromScopedShapedBuffer(&result_buffer, - definition_event); + {definition_event}); if (device_state->synchronous_deallocation()) { device_buffers.push_back(out_buffer); @@ -702,13 +832,19 @@ StatusOr> PyLocalExecutable::ExecuteHelper( device_state->compute_stream(), std::make_tuple(executables_[executable_idx], compute_reservation, device_assignment_)); - return absl::make_unique( + std::vector> outputs; + outputs.push_back(absl::make_unique( result_buffer.on_host_shape(), result_buffer.on_device_shape(), - std::move(out_buffer), client_, device); + std::move(out_buffer), client_, device)); + if (options.untuple_result && result_buffer.on_host_shape().IsTuple()) { + TF_ASSIGN_OR_RETURN(outputs, outputs.front()->DestructureTuple()); + } + return outputs; } -StatusOr> PyLocalExecutable::Execute( - absl::Span argument_handles) { +StatusOr>> +PyLocalExecutable::Execute(absl::Span argument_handles, + const ExecuteOptions& options) const { if (num_replicas() != 1) { return InvalidArgument( "Attempted to execute computation with %d replicas using Execute()", @@ -721,27 +857,18 @@ StatusOr> PyLocalExecutable::Execute( } VLOG(1) << "Executing computation " << name(); return ExecuteHelper(argument_handles, /*replica=*/0, /*partition=*/0, - RunId()); + RunId(), options); } -StatusOr>> -PyLocalExecutable::ExecutePerReplica( - absl::Span> 
argument_handles) { - tensorflow::profiler::TraceMe traceme("LocalExecutable::ExecutePerReplica"); - if (num_partitions() != 1) { - return InvalidArgument( - "Attempted to execute computation with %d partitions using " - "ExecutePerReplica()", - num_partitions()); - } - return ExecuteOnLocalDevices(argument_handles); -} - -StatusOr>> +StatusOr>>> PyLocalExecutable::ExecuteOnLocalDevices( - absl::Span> argument_handles) { - tensorflow::profiler::TraceMe traceme( - "LocalExecutable::ExecuteOnLocalDevices"); + absl::Span> argument_handles, + const ExecuteOptions& options) const { + RunId run_id; + tensorflow::profiler::TraceMe traceme([&] { + return absl::StrCat( + "LocalExecutable::ExecuteOnLocalDevices#run_id=", run_id.ToInt(), "#"); + }); const int num_local_devices = local_devices_.size(); @@ -755,9 +882,9 @@ PyLocalExecutable::ExecuteOnLocalDevices( VLOG(1) << "Executing computation " << name() << "; num_replicas=" << num_replicas() - << " num_partitions=" << num_partitions() - << " num_local_devices=" << num_local_devices; + << " num_partitions=" << num_partitions() << " num_local_devices=" + << num_local_devices; std::vector>>> results( num_local_devices); if (num_local_devices == 1) { // Fast-path if there is only one device — run the computation on the - const int replica = local_logical_device_ids_[0].first; const int partition = local_logical_device_ids_[0].second; results[0] = - ExecuteHelper(argument_handles[0], replica, partition, RunId()); + ExecuteHelper(argument_handles[0], replica, partition, run_id, options); } else { - RunId run_id; absl::Mutex mu; int running = num_local_devices; int failed = 0; @@ -776,11 +902,11 @@ PyLocalExecutable::ExecuteOnLocalDevices( for (int i = 0; i < num_local_devices; ++i) { const int replica = local_logical_device_ids_[i].first; const int partition = local_logical_device_ids_[i].second; - std::shared_ptr device = local_devices_[i]; + Device* device = local_devices_[i]; const LocalDeviceState& device_state = *device->local_device_state(); device_state.execute_thread()->Schedule([&, replica, partition, i] { - results[i] = - ExecuteHelper(argument_handles[i], replica, partition, run_id); + results[i] = ExecuteHelper(argument_handles[i], replica, partition, + run_id, options); absl::MutexLock lock(&mu); --running; @@ -821,7 +947,7 @@ PyLocalExecutable::ExecuteOnLocalDevices( } VLOG(1) << "Replicated execution complete."; - std::vector> wrapped_results( + std::vector>> wrapped_results( num_local_devices); for (int i = 0; i < num_local_devices; ++i) { const int replica = local_logical_device_ids_[i].first; @@ -840,107 +966,37 @@ PyLocalExecutable::ExecuteOnLocalDevices( return wrapped_results; } -/*static*/ StatusOr> -PyLocalExecutable::CompileForDevices( - const XlaComputation& computation, - absl::optional> argument_layouts, - const ExecutableBuildOptions* build_options, - std::shared_ptr client, - const std::vector>>& - device_assignment) { - if (device_assignment.empty()) { - return InvalidArgument( - "Device assignment passed to Compile() must be non-empty."); - } - if (device_assignment[0].empty()) { - return InvalidArgument( - "Device assignment passed to Compile() must have a nonzero number of " - "partitions per replica; replica 0 had 0 partitions."); - } - DeviceAssignment xla_assignment(device_assignment.size(), - device_assignment[0].size()); - for (int replica = 0; replica < device_assignment.size(); ++replica) { - if 
(device_assignment[replica].size() != device_assignment[0].size()) { - return InvalidArgument( - "Device assignment passed to Compile() has different numbers of " - "partitions between replicas; %d partitions for replica %d versus %d " - "partitions for replica 0.", - device_assignment[replica].size(), replica, - device_assignment[0].size()); - } - for (int partition = 0; partition < device_assignment[replica].size(); - ++partition) { - if (device_assignment[0][0]->platform_name() != - device_assignment[replica][partition]->platform_name()) { - return InvalidArgument( - "Device assignment passed to Compile() must have devices of a " - "single kind, got %s for replica 0 partition 0 and %s for replica " - "%d partition %d.", - device_assignment[0][0]->platform_name(), - device_assignment[replica][partition]->platform_name(), replica, - partition); - } - xla_assignment(replica, partition) = - device_assignment[replica][partition]->id(); - } - } - return Compile(computation, std::move(argument_layouts), build_options, - std::move(client), xla_assignment); -} - /*static*/ StatusOr> PyLocalExecutable::Compile(const XlaComputation& computation, - absl::optional> argument_layouts, - const ExecutableBuildOptions* build_options, - std::shared_ptr client, - absl::optional device_assignment) { + PyLocalClient* client, CompileOptions options) { tensorflow::profiler::TraceMe traceme("LocalExecutable::Compile"); - ExecutableBuildOptions options; - if (build_options != nullptr) { - options = *build_options; + ExecutableBuildOptions& build_options = options.executable_build_options; + if (!build_options.device_allocator()) { + build_options.set_device_allocator(client->allocator()); } - if (!options.device_allocator()) { - options.set_device_allocator(client->allocator()); + if (!build_options.has_device_assignment()) { + VLOG(2) << "PyLocalExecutable::Compile using default device_assignment."; + TF_ASSIGN_OR_RETURN( + DeviceAssignment device_assignment, + client->GetDefaultDeviceAssignment(build_options.num_replicas(), + build_options.num_partitions())); + build_options.set_device_assignment(device_assignment); } + VLOG(2) << "PyLocalExecutable::Compile device_assignment:\n" + << build_options.device_assignment().ToString(); - if (device_assignment) { - VLOG(2) << "PyLocalExecutable::Compile got device_assignment:\n" - << device_assignment->ToString(); - if (device_assignment->replica_count() != options.num_replicas()) { - return InvalidArgument( - "Mismatched number of replicas for device " - "assignment and computation (%d vs %d).\n%s", - device_assignment->replica_count(), options.num_replicas(), - device_assignment->ToString()); - } - if (device_assignment->computation_count() != options.num_partitions()) { - return InvalidArgument( - "Mismatched number of partitions for device " - "assignment and computation (%d vs %d).\n%s", - device_assignment->computation_count(), options.num_partitions(), - device_assignment->ToString()); - } - } else { - TF_ASSIGN_OR_RETURN(device_assignment, - client->GetDefaultDeviceAssignment( - options.num_replicas(), options.num_partitions())); - VLOG(2) << "PyLocalExecutable::Compile using default device_assignment:\n" - << device_assignment->ToString(); - } - options.set_device_assignment(device_assignment.value()); - - if (!argument_layouts) { + if (!options.argument_layouts) { TF_ASSIGN_OR_RETURN(ProgramShape program_shape, computation.GetProgramShape()); - argument_layouts = program_shape.parameters(); - for (Shape& shape : *argument_layouts) { + 
options.argument_layouts = program_shape.parameters(); + for (Shape& shape : *options.argument_layouts) { LayoutUtil::ClearLayout(&shape); } } std::vector argument_layout_pointers; - argument_layout_pointers.reserve(argument_layouts->size()); + argument_layout_pointers.reserve(options.argument_layouts->size()); // Assign a default layout to any array subshapes that are missing layouts. auto assign_layouts = [client](Shape* shape) { @@ -958,14 +1014,14 @@ PyLocalExecutable::Compile(const XlaComputation& computation, }); }; - for (Shape& layout : *argument_layouts) { + for (Shape& layout : *options.argument_layouts) { argument_layout_pointers.push_back(&layout); TF_RETURN_IF_ERROR(assign_layouts(&layout)); } Shape result_layout; - if (options.result_layout()) { - result_layout = *options.result_layout(); + if (build_options.result_layout()) { + result_layout = *build_options.result_layout(); } else { TF_ASSIGN_OR_RETURN(ProgramShape program_shape, computation.GetProgramShape()); @@ -973,16 +1029,15 @@ PyLocalExecutable::Compile(const XlaComputation& computation, LayoutUtil::ClearLayout(&result_layout); } TF_RETURN_IF_ERROR(assign_layouts(&result_layout)); - options.set_result_layout(result_layout); + build_options.set_result_layout(result_layout); TF_ASSIGN_OR_RETURN( std::vector> local_executables, client->client()->Compile(computation, argument_layout_pointers, - options)); + build_options)); - return absl::make_unique(std::move(local_executables), - std::move(*device_assignment), - std::move(client)); + return absl::make_unique( + std::move(local_executables), build_options.device_assignment(), client); } } // namespace xla diff --git a/tensorflow/compiler/xla/python/local_client.h b/tensorflow/compiler/xla/python/local_client.h index 51c8df90786..401064af77c 100644 --- a/tensorflow/compiler/xla/python/local_client.h +++ b/tensorflow/compiler/xla/python/local_client.h @@ -30,13 +30,20 @@ limitations under the License. #include "tensorflow/compiler/xla/python/local_device_state.h" #include "tensorflow/compiler/xla/python/shared_device_buffer.h" #include "tensorflow/compiler/xla/service/computation_placer.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/lib/core/status.h" +// API notes: +// Despite having the name "PyLocalClient", it is intended that this API may +// also be consumed from C++. Python/pybind11/NumPy logic should therefore not +// be used in this API. + namespace xla { class Device { @@ -80,15 +87,32 @@ class Device { const std::string platform_name_; }; +class PyLocalBuffer; +// Helper struct for cross host transfers, returned by the callback from a call +// to PyLocalBuffer::MakeCrossHostReceiveBuffers. +struct PyLocalCrossHostRecvBuffer { + // serialized_descriptor should be transmitted to the sender and passed to a + // call to src_buffer->CopyToRemoteDevice. + std::string serialized_descriptor; + // The buffer that will hold the result of the transfer. + std::unique_ptr buffer; +}; +using PyLocalCrossHostRecvNotifier = + std::function>&&)>; + // Encapsulates the state of Python session with XLA. 
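// A rough end-to-end sketch of the intended C++ usage, mirroring
// gpu_multistream_test.cc (error handling elided; the GPU client factory,
// computation, and buffer names are illustrative stand-ins):
//
//   TF_ASSERT_OK_AND_ASSIGN(
//       std::shared_ptr<PyLocalClient> client,
//       GetNvidiaGpuClient(/*asynchronous=*/true, GpuAllocatorConfig(),
//                          /*distributed_client=*/nullptr, /*node_id=*/0));
//   TF_ASSERT_OK_AND_ASSIGN(
//       std::unique_ptr<PyLocalExecutable> executable,
//       PyLocalExecutable::Compile(computation, client.get(),
//                                  CompileOptions()));
//   TF_ASSERT_OK_AND_ASSIGN(
//       auto outputs,
//       executable->Execute({input_buffer.get()}, ExecuteOptions()));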
-class PyLocalClient { +// +// It is the responsibility of the client of this API to keep the PyLocalClient +// alive as long as any of the other runtime objects are alive. +class PyLocalClient : public std::enable_shared_from_this { public: // `allocator` may null, in which case the platform default allocator is used. explicit PyLocalClient( std::string platform_name, LocalClient* client, - std::vector> devices, int host_id, + std::vector> devices, int host_id, std::unique_ptr allocator, - std::unique_ptr host_memory_allocator); + std::unique_ptr host_memory_allocator, + std::unique_ptr gpu_run_options); virtual ~PyLocalClient() = default; virtual StatusOr GetDefaultDeviceAssignment( @@ -96,15 +120,11 @@ class PyLocalClient { int device_count() const { return devices_.size(); } int local_device_count() const { return local_devices_.size(); } - const std::vector>& devices() const { + const std::vector>& devices() const { return devices_; } - const std::vector>& local_devices() const { - return local_devices_; - } - const std::map>& id_to_device() const { - return id_to_device_; - } + const std::vector& local_devices() const { return local_devices_; } + const std::map& id_to_device() const { return id_to_device_; } int host_id() const { return host_id_; } const std::string& platform_name() const { return platform_name_; } @@ -118,6 +138,10 @@ class PyLocalClient { return host_memory_allocator_.get(); } + GpuExecutableRunOptions* gpu_run_options() const { + return gpu_run_options_.get(); + } + tensorflow::thread::ThreadPool* h2d_transfer_pool() { return &h2d_transfer_pool_; } @@ -128,15 +152,28 @@ class PyLocalClient { virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; } protected: + friend class PyLocalBuffer; + virtual void EnqueueCrossHostReceive( + std::vector>&& buffers, + PyLocalCrossHostRecvNotifier&& notifier) const { + notifier(Unimplemented("Cross host receives not implemented.")); + } + + virtual Status CopyToRemoteDevice(PyLocalBuffer* buffer, + absl::string_view serialized_descriptor, + Device* device) const { + return Unimplemented("Cross host sends not implemented."); + } + std::string platform_name_; LocalClient* client_; // Includes all devices, including non-local devices on multi-host platforms. - std::vector> devices_; + std::vector> devices_; // Maps Device::id() to the corresponding Device. Includes all devices. - std::map> id_to_device_; + std::map id_to_device_; // Local devices indexed by local device ordinal. - std::vector> local_devices_; + std::vector local_devices_; int host_id_; se::DeviceMemoryAllocator* allocator_; @@ -147,9 +184,16 @@ class PyLocalClient { // device via a staging area of pinned memory. std::unique_ptr host_memory_allocator_; + std::unique_ptr gpu_run_options_; + tensorflow::thread::ThreadPool h2d_transfer_pool_; }; +// Converts a 2D set of Device objects indexed by [replica][partition] into an +// xla::DeviceAssignment. +StatusOr DevicesToDeviceAssignment( + absl::Span> devices); + // Holds a reference from Python to one or more device buffers. // A PyLocalBuffer can be either valid or invalid. An invalid buffer is one that // has never been initialized, or a buffer that has been deleted (e.g., by @@ -166,17 +210,29 @@ class PyLocalBuffer { // the runtime (may be nullptr). 
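  // Example (a sketch following the calls in gpu_multistream_test.cc;
  // host_data is an illustrative name):
  //   auto buffer = PyLocalBuffer::FromHostBuffer(
  //                     host_data.data(), shape, /*force_copy=*/false,
  //                     /*buffer_reference=*/nullptr, client, device)
  //                     .ValueOrDie();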
static StatusOr> FromHostBuffer( const void* data, const Shape& shape, bool force_copy, - std::shared_ptr buffer_reference, - std::shared_ptr client, std::shared_ptr device); + std::shared_ptr buffer_reference, PyLocalClient* client, + Device* device); static StatusOr> MakeTuple( - const std::vector buffers, - std::shared_ptr client, std::shared_ptr device); + absl::Span buffers, PyLocalClient* client, + Device* device); + + // Asynchronously makes a vector of PyLocalBuffers that can be used to receive + // cross host transfers using `client` on `device'. `shapes` must be the exact + // shapes, with identical layouts, corresponding to the buffers that will be + // sent. When resources for the transfer are available, notifier will be + // called with a vector of PyLocalCrossHostRecvBuffer structs, one for each + // shape in `shapes`. Each struct contains a buffer that will contain the + // received value, and an opaque string that should be transmitted to the + // sending host and used in a call to CopyToRemoteDevice. None of the recv + // buffers will become ready until *all* of the sends have completed. + static void MakeCrossHostReceiveBuffers( + absl::Span shapes, PyLocalClient* client, Device* device, + PyLocalCrossHostRecvNotifier&& notifier); PyLocalBuffer(Shape on_host_shape, Shape on_device_shape, std::shared_ptr device_buffer, - std::shared_ptr client, - std::shared_ptr device); + PyLocalClient* client, Device* device); PyLocalBuffer(const PyLocalBuffer&) = delete; PyLocalBuffer(PyLocalBuffer&&) = delete; @@ -185,9 +241,9 @@ class PyLocalBuffer { const Shape& on_host_shape() const { return on_host_shape_; } const Shape& on_device_shape() const { return on_device_shape_; } - std::shared_ptr device() const { return device_; } + Device* device() const { return device_; } const std::string& platform_name() const { return client_->platform_name(); } - std::shared_ptr client() const { return client_; } + PyLocalClient* client() const { return client_; } // Returns the buffer's value as a tuple DAG of Python arrays. If the value // has previously been prefetched to the host, then returns the prefetched @@ -213,23 +269,35 @@ class PyLocalBuffer { StatusOr AsShapedBuffer() const; // Destructures a tuple-valued PyLocalBuffer into its constituent elements. - StatusOr>> DestructureTuple(); + StatusOr>> DestructureTuple() + const; // Copies the buffer to device `dst_device`. - StatusOr> CopyToDevice( - std::shared_ptr dst_device); + StatusOr> CopyToDevice(Device* dst_device); + + // Copies the buffer to remote device `dst_device`. This call must be preceded + // by a call to MakeCrossHostReceiveBuffers on the remote host's + // dst_device. MakeCrossHostReceiveBuffers takes an array of shapes to + // construct the destination buffers, and a callback supplies an array + // containing both the destination buffers, and a serialized descriptor for + // each buffer. For each destination buffer there should be a matching call to + // src->CopyToRemoteDevice on a remote host for a src buffer of the + // corresponding shape. serialized_descriptor is the string returned by the + // callback along with the corresponding destination buffer. + Status CopyToRemoteDevice(absl::string_view serialized_descriptor, + Device* dst_device); // Blocks the host until the buffer's value has been computed and is ready for // immediate use on the device. Useful in particular for timing benchmarks. 
Status BlockHostUntilReady(); private: - const std::shared_ptr client_; + PyLocalClient* const client_; const Shape on_host_shape_; const Shape on_device_shape_; - const std::shared_ptr device_; + Device* const device_; mutable absl::Mutex mu_; - std::shared_ptr device_buffer_ GUARDED_BY(mu_); + std::shared_ptr device_buffer_ TF_GUARDED_BY(mu_); // The cached value of the buffer on the host, produced either from a call to // CopyToHost or from a call to ToLiteral. Once a value has been fetched to @@ -241,7 +309,25 @@ class PyLocalBuffer { Status status; std::shared_ptr value; }; - std::shared_ptr host_value_ GUARDED_BY(mu_); + std::shared_ptr host_value_ TF_GUARDED_BY(mu_); +}; + +struct CompileOptions { + // The layouts of the arguments that the computation should expect. + absl::optional> argument_layouts; + + // XLA's compilation time options. + ExecutableBuildOptions executable_build_options; +}; + +struct ExecuteOptions { + // If true, the arguments to the computation will be wrapped in a tuple and + // passed as a single parameter. + bool tuple_arguments = false; + + // If true, the computation must return a tuple, which will be destructured + // into its elements. + bool untuple_result = false; }; // Represents a compiled computation that can be executed given handles to @@ -249,27 +335,14 @@ class PyLocalBuffer { // partition, as specified by the build options). class PyLocalExecutable { public: - // Compiles a computation to an executable. - static StatusOr> CompileForDevices( - const XlaComputation& computation, - absl::optional> argument_layouts, - const ExecutableBuildOptions* build_options, - std::shared_ptr client, - const std::vector>>& - device_assignment); - - // TODO(phawkins): Deprecated. Delete once all callers have been updated to - // use the newer form. static StatusOr> Compile( - const XlaComputation& computation, - absl::optional> argument_layouts, - const ExecutableBuildOptions* build_options, - std::shared_ptr client, - absl::optional device_assignment); + const XlaComputation& computation, PyLocalClient* client, + CompileOptions options); PyLocalExecutable(std::vector> executables, - DeviceAssignment device_assignment, - std::shared_ptr client); + DeviceAssignment device_assignment, PyLocalClient* client); + + PyLocalClient* client() const { return client_; } int num_replicas() const { return executables_[0]->build_options().num_replicas(); @@ -299,41 +372,34 @@ class PyLocalExecutable { return local_logical_device_ids_; } - const std::vector>& local_devices() const { - return local_devices_; - } + const std::vector& local_devices() const { return local_devices_; } - StatusOr> Execute( - absl::Span argument_handles); - - // Execute on many replicas. Takes a sequence of argument lists (one argument - // list per replica) and returns a tuple of results (one result per replica). - // The number of argument lists must be equal to the replica count. - // The executable must have only one partition. - // TODO(cjfj): Remove this once JAX is moved to `ExecuteOnLocalDevices`. - StatusOr>> ExecutePerReplica( - absl::Span> argument_handles); + StatusOr>> Execute( + absl::Span argument_handles, + const ExecuteOptions& options) const; // Execute on local devices. Takes a sequence of argument lists (one argument // list per local device) and returns a tuple of results (one result per local // device). The number of argument lists must be equal to the local device // count. 
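  // Call-pattern sketch (argument names are illustrative only):
  //   ExecuteOptions opts;
  //   opts.untuple_result = true;  // split a tuple result into its elements
  //   auto per_device =
  //       executable->ExecuteOnLocalDevices(args, opts).ValueOrDie();
  //   // per_device[i] holds the output buffers for the i-th local device,
  //   // ordered consistently with local_devices().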
- StatusOr>> ExecuteOnLocalDevices( - absl::Span> argument_handles); + StatusOr>>> + ExecuteOnLocalDevices( + absl::Span> argument_handles, + const ExecuteOptions& options) const; void Delete() { executables_.clear(); } const string& name() const; private: - StatusOr> ExecuteHelper( + StatusOr>> ExecuteHelper( absl::Span argument_handles, int replica, - int partition, const RunId& run_id); + int partition, const RunId& run_id, const ExecuteOptions& options) const; // Create shared pointers so we can free them after the execution: with // asynchronous execution, the process being executed can outlive the // executable itself. - std::shared_ptr const client_; + PyLocalClient* const client_; // One executable per partition. std::vector> executables_; std::shared_ptr device_assignment_; @@ -350,7 +416,7 @@ class PyLocalExecutable { // assigned. // shared_ptrs instead of unique_ptrs to play well with the Python bindings // (see xla.cc). - std::vector> local_devices_; + std::vector local_devices_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/python/local_device_state.h b/tensorflow/compiler/xla/python/local_device_state.h index a64176294e0..fa73c832c57 100644 --- a/tensorflow/compiler/xla/python/local_device_state.h +++ b/tensorflow/compiler/xla/python/local_device_state.h @@ -129,12 +129,12 @@ class LocalDeviceState { static constexpr int kNumDeviceToDeviceStreams = 4; absl::Mutex mu_; - int next_device_to_host_stream_ GUARDED_BY(mu_) = 0; - int next_device_to_device_stream_ GUARDED_BY(mu_) = 0; + int next_device_to_host_stream_ TF_GUARDED_BY(mu_) = 0; + int next_device_to_device_stream_ TF_GUARDED_BY(mu_) = 0; - std::random_device prng_seed_device_ GUARDED_BY(mu_); - std::mt19937 prng_seed_generator_ GUARDED_BY(mu_); - std::uniform_int_distribution<> prng_seed_distribution_ GUARDED_BY(mu_); + std::random_device prng_seed_device_ TF_GUARDED_BY(mu_); + std::mt19937 prng_seed_generator_ TF_GUARDED_BY(mu_); + std::uniform_int_distribution<> prng_seed_distribution_ TF_GUARDED_BY(mu_); // Callback stream is used for running short host-side callbacks after device // side events, without preventing the device-side stream from doing useful diff --git a/tensorflow/compiler/xla/python/nvidia_gpu_device.cc b/tensorflow/compiler/xla/python/nvidia_gpu_device.cc index b7b2faef8d7..26ea727dee7 100644 --- a/tensorflow/compiler/xla/python/nvidia_gpu_device.cc +++ b/tensorflow/compiler/xla/python/nvidia_gpu_device.cc @@ -15,28 +15,82 @@ limitations under the License. #include "tensorflow/compiler/xla/python/nvidia_gpu_device.h" +#ifdef NCCL_ENABLED +#include "third_party/nccl/nccl.h" +#endif // NCCL_ENABLED #include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h" #include "tensorflow/compiler/xla/service/platform_util.h" -#include "tensorflow/core/common_runtime/bfc_allocator.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/gpu/gpu_host_allocator.h" #include "tensorflow/core/common_runtime/gpu/gpu_mem_allocator.h" #include "tensorflow/stream_executor/tf_allocator_adapter.h" namespace xla { +namespace { static const char kGpuPlatformName[] = "gpu"; -GpuDevice::GpuDevice(int id, - std::unique_ptr local_device_state) - : Device(id, std::move(local_device_state), kGpuPlatformName) {} +// A custom PyLocalClient that overrides the device assignment method. 
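// When the requested replica count fits on this host's local devices, the
// override below maps replica i to the i-th local device with a single
// partition; otherwise it falls back to PyLocalClient's default assignment.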
+class GpuClient : public xla::PyLocalClient { + public: + using xla::PyLocalClient::PyLocalClient; -static StatusOr> CreateBFCAllocator( - se::Platform* platform, - absl::Span> local_devices, - LocalClient* client, double memory_fraction, bool preallocate) { - CHECK_GT(client->backend().device_count(), 0); + xla::StatusOr GetDefaultDeviceAssignment( + int num_replicas, int num_partitions) const override; +}; + +xla::StatusOr GpuClient::GetDefaultDeviceAssignment( + int num_replicas, int num_partitions) const { + // XLA:GPU does not support multiple partitions yet. + TF_RET_CHECK(num_partitions == 1) << num_partitions; + if (num_replicas <= local_devices().size()) { + xla::DeviceAssignment assignment(num_replicas, 1); + for (int i = 0; i < num_replicas; ++i) { + assignment(i, 0) = local_devices().at(i)->id(); + } + return assignment; + } + // Fallback to default global device assignment if we can't run locally. + return PyLocalClient::GetDefaultDeviceAssignment(num_replicas, + num_partitions); +} + +// Builds an xla::LocalClient for the GPU platform. +StatusOr GetGpuXlaClient() { + TF_ASSIGN_OR_RETURN(se::Platform * platform, + PlatformUtil::GetPlatform("CUDA")); + if (platform->VisibleDeviceCount() <= 0) { + return FailedPrecondition("No visible NVidia GPU devices."); + } + LocalClientOptions options; + options.set_platform(platform); + return ClientLibrary::GetOrCreateLocalClient(options); +} + +// Builds a LocalDeviceState for each GPU present. +StatusOr>> BuildLocalDeviceStates( + LocalClient* xla_client, bool asynchronous) { + std::vector> local_devices; + for (int i = 0; i < xla_client->device_count(); ++i) { + se::StreamExecutor* executor = + xla_client->backend().stream_executor(i).ValueOrDie(); + local_devices.push_back(absl::make_unique( + executor, xla_client, /*synchronous_deallocation=*/false, asynchronous, + /*allow_event_reuse=*/true)); + } + return std::move(local_devices); +} + +// Builds a BFCAllocator for all local GPUs. 
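// One device-memory sub-allocator is created per local StreamExecutor and
// wrapped in a BFC arena named "GPU_<ordinal>_bfc"; growth is allowed unless
// `preallocate` is requested, and the per-device arenas are combined into a
// single multi-device adapter (tf_allocator_adapter.h) keyed by each device's
// compute stream.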
+StatusOr> CreateBFCAllocator( + absl::Span const> local_devices, + double memory_fraction, bool preallocate) { + CHECK_GT(local_devices.size(), 0); + const se::Platform* platform = local_devices.front()->executor()->platform(); std::vector allocators; - for (se::StreamExecutor* executor : client->backend().stream_executors()) { + for (auto& local_device : local_devices) { + se::StreamExecutor* executor = local_device->executor(); int device_ordinal = executor->device_ordinal(); auto sub_allocator = absl::make_unique( executor, tensorflow::PlatformGpuId(device_ordinal), @@ -65,60 +119,201 @@ static StatusOr> CreateBFCAllocator( /*allow_growth=*/!preallocate, absl::StrCat("GPU_", device_ordinal, "_bfc")); allocators.emplace_back(std::move(gpu_bfc_allocator), - local_devices.at(device_ordinal) - ->local_device_state() - ->compute_stream()); + local_device->compute_stream()); } return absl::make_unique(platform, std::move(allocators)); } -StatusOr> GetNvidiaGpuClient( - bool asynchronous, const GpuAllocatorConfig& allocator_config) { - TF_ASSIGN_OR_RETURN(se::Platform * platform, - PlatformUtil::GetPlatform("CUDA")); - if (platform->VisibleDeviceCount() <= 0) { - return FailedPrecondition("No visible NVidia GPU devices."); - } - LocalClientOptions options; - options.set_platform(platform); - TF_ASSIGN_OR_RETURN(LocalClient * client, - ClientLibrary::GetOrCreateLocalClient(options)); - - std::vector> devices; - for (int i = 0; i < client->device_count(); ++i) { - se::StreamExecutor* executor = - client->backend().stream_executor(i).ValueOrDie(); - auto device_state = absl::make_unique( - executor, client, /*synchronous_deallocation=*/false, asynchronous, - /*allow_event_reuse=*/true); - std::shared_ptr device = - std::make_shared(i, std::move(device_state)); - devices.push_back(std::move(device)); - } - +// Constructs a GPU device memory allocator to use, according to the allocator +// configuration the client requested. +StatusOr> GetGpuDeviceAllocator( + const GpuAllocatorConfig& allocator_config, + absl::Span const> local_devices) { std::unique_ptr allocator; - std::unique_ptr host_memory_allocator; if (allocator_config.kind != GpuAllocatorConfig::Kind::kPlatform) { - TF_ASSIGN_OR_RETURN(allocator, - CreateBFCAllocator(platform, devices, client, - allocator_config.memory_fraction, - allocator_config.preallocate)); + TF_ASSIGN_OR_RETURN( + allocator, + CreateBFCAllocator(local_devices, allocator_config.memory_fraction, + allocator_config.preallocate)); } + return std::move(allocator); +} +// Returns a GPU pinned host memory allocator to use when staging host->GPU +// transfers. We use a fixed 64MB pool of pinned memory. +std::unique_ptr GetGpuHostAllocator( + se::StreamExecutor* executor) { tensorflow::SubAllocator* sub_allocator = new tensorflow::GpuHostAllocator( - client->backend().stream_executor(0).ValueOrDie(), /*numa_node=*/0, - /*alloc_visitors=*/{}, - /*free_visitors=*/{}); + executor, /*numa_node=*/0, /*alloc_visitors=*/{}, /*free_visitors=*/{}); // TODO(phawkins): allow the user to tune this. const int64 kGpuHostMemoryLimitBytes = 64 * (1LL << 30); - host_memory_allocator = absl::make_unique( + return absl::make_unique( sub_allocator, kGpuHostMemoryLimitBytes, /*allow_growth=*/true, /*name=*/"xla_gpu_host_bfc"); +} - return std::make_shared("gpu", client, std::move(devices), - /*host_id=*/0, std::move(allocator), - std::move(host_memory_allocator)); +// A table mapping NcclCliqueKeys to ncclUniqueId values encoded as strings. 
+// In a distributed setup the table of NCCL IDs is kept on the master node +// (node 0). Currently node 0 is the only node that generates ncclUniqueIds; +// see the TODO below. +class NcclIdStore { + public: + NcclIdStore(int node_id, std::shared_ptr client) + : node_id_(node_id), client_(std::move(client)) {} + + StatusOr GetNcclUniqueId(const NcclCliqueKey& key); + + private: + const int node_id_; + const std::shared_ptr client_; + + absl::Mutex mu_; + absl::flat_hash_map cache_ GUARDED_BY(mu_); +}; + +StatusOr NcclIdStore::GetNcclUniqueId(const NcclCliqueKey& key) { + std::string key_string = GlobalDeviceIdsToString(key.devices()); + { + absl::MutexLock lock(&mu_); + auto it = cache_.find(key_string); + if (it != cache_.end()) { + return it->second; + } + } + auto result = [&]() -> StatusOr { + // TODO(phawkins): this will deadlock if node 0 is not involved in the + // computation. Add support for computations that only use a subset of + // replicas. + if (node_id_ == 0) { +#ifdef NCCL_ENABLED + ncclUniqueId id; + ncclResult_t r = ncclGetUniqueId(&id); + TF_RET_CHECK(r == ncclSuccess); + std::string value(id.internal, NCCL_UNIQUE_ID_BYTES); + TF_RETURN_IF_ERROR(client_->KeyValueSet(key_string, value)); + return value; +#else + return FailedPrecondition("NCCL support was not built into XLA binary."); +#endif + } else { + return client_->BlockingKeyValueGet(key_string, absl::Minutes(5)); + } + }(); + if (!result.ok()) { + return result.status(); + } + absl::MutexLock lock(&mu_); + return cache_.emplace(key_string, result.ValueOrDie()).first->second; +} + +std::vector> BuildLocalDevices( + std::vector> local_device_states) { + std::vector> devices; + for (auto& local_device : local_device_states) { + int device_ordinal = local_device->device_ordinal(); + auto device = absl::make_unique( + device_ordinal, std::move(local_device), /*node_id=*/0); + devices.push_back(std::move(device)); + } + return devices; +} + +Status BuildDistributedDevices( + std::vector> local_device_states, + std::shared_ptr distributed_client, int node_id, + std::vector>* devices, + GpuExecutableRunOptions* gpu_executable_run_options) { + LocalTopologyProto local_topology; + local_topology.set_node_id(node_id); + for (const auto& local_device : local_device_states) { + const se::Platform* platform = local_device->executor()->platform(); + TF_ASSIGN_OR_RETURN( + std::unique_ptr desc, + platform->DescriptionForDevice(local_device->device_ordinal())); + TF_RET_CHECK(local_device->device_ordinal() == + local_topology.devices_size()); + DeviceProto* device_proto = local_topology.add_devices(); + device_proto->set_local_device_ordinal(local_device->device_ordinal()); + device_proto->set_name(desc->name()); + device_proto->set_vendor(desc->device_vendor()); + } + + GlobalTopologyProto global_topology; + TF_RETURN_IF_ERROR( + distributed_client->Connect(local_topology, &global_topology)); + + std::vector gpu_device_ids(local_device_states.size()); + for (const LocalTopologyProto& node : global_topology.nodes()) { + for (const DeviceProto& device_proto : node.devices()) { + std::unique_ptr local_device; + if (node.node_id() == node_id) { + TF_RET_CHECK(device_proto.local_device_ordinal() >= 0 && + device_proto.local_device_ordinal() < + local_device_states.size()); + TF_RET_CHECK(local_device_states[device_proto.local_device_ordinal()] != + nullptr); + local_device = + std::move(local_device_states[device_proto.local_device_ordinal()]); + gpu_device_ids[device_proto.local_device_ordinal()] = + 
GlobalDeviceId(device_proto.global_device_id()); + } + auto device = + absl::make_unique(device_proto.global_device_id(), + std::move(local_device), node.node_id()); + devices->push_back(std::move(device)); + } + } + for (const auto& device : local_device_states) { + TF_RET_CHECK(device == nullptr); + } + gpu_executable_run_options->set_gpu_global_device_ids( + std::move(gpu_device_ids)); + auto nccl_id_store = + std::make_shared(node_id, distributed_client); + gpu_executable_run_options->set_nccl_unique_id_callback( + [nccl_id_store](const NcclCliqueKey& key) { + return nccl_id_store->GetNcclUniqueId(key); + }); + return Status::OK(); +} + +} // namespace + +GpuDevice::GpuDevice(int id, + std::unique_ptr local_device_state, + int node_id) + : Device(id, std::move(local_device_state), kGpuPlatformName, node_id) {} + +StatusOr> GetNvidiaGpuClient( + bool asynchronous, const GpuAllocatorConfig& allocator_config, + std::shared_ptr distributed_client, int node_id) { + TF_ASSIGN_OR_RETURN(LocalClient * xla_client, GetGpuXlaClient()); + TF_ASSIGN_OR_RETURN( + std::vector> local_device_states, + BuildLocalDeviceStates(xla_client, asynchronous)); + TF_ASSIGN_OR_RETURN( + auto allocator, + GetGpuDeviceAllocator(allocator_config, local_device_states)); + auto host_memory_allocator = + GetGpuHostAllocator(local_device_states.front()->executor()); + + std::vector> devices; + auto gpu_run_options = absl::make_unique(); + if (distributed_client) { + TF_RETURN_IF_ERROR(BuildDistributedDevices( + std::move(local_device_states), std::move(distributed_client), node_id, + &devices, gpu_run_options.get())); + } else { + devices = BuildLocalDevices(std::move(local_device_states)); + } + + std::shared_ptr pyclient = std::make_shared( + "gpu", xla_client, std::move(devices), + /*node_id=*/node_id, std::move(allocator), + std::move(host_memory_allocator), + /*gpu_run_options=*/std::move(gpu_run_options)); + return pyclient; } } // namespace xla diff --git a/tensorflow/compiler/xla/python/nvidia_gpu_device.h b/tensorflow/compiler/xla/python/nvidia_gpu_device.h index a89f8044d4f..333a82a2d78 100644 --- a/tensorflow/compiler/xla/python/nvidia_gpu_device.h +++ b/tensorflow/compiler/xla/python/nvidia_gpu_device.h @@ -18,14 +18,17 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/python/distributed/client.h" #include "tensorflow/compiler/xla/python/local_client.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/common_runtime/bfc_allocator.h" namespace xla { class GpuDevice : public Device { public: - GpuDevice(int id, std::unique_ptr local_device_state); + GpuDevice(int id, std::unique_ptr local_device_state, + int node_id); }; struct GpuAllocatorConfig { @@ -48,8 +51,11 @@ struct GpuAllocatorConfig { bool preallocate = true; }; +// distributed_client may be nullptr in non-distributed settings. +// distributed_client should not be Open()ed before calling this function. StatusOr> GetNvidiaGpuClient( - bool asynchronous, const GpuAllocatorConfig& allocator_config); + bool asynchronous, const GpuAllocatorConfig& allocator_config, + std::shared_ptr distributed_client, int node_id); } // namespace xla diff --git a/tensorflow/compiler/xla/python/shared_device_buffer.cc b/tensorflow/compiler/xla/python/shared_device_buffer.cc index ca6da645024..91f2b434a61 100644 --- a/tensorflow/compiler/xla/python/shared_device_buffer.cc +++ b/tensorflow/compiler/xla/python/shared_device_buffer.cc @@ -15,6 +15,7 @@ limitations under the License. 
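For context, a hedged sketch of how a single-node caller might use the reworked factory; the argument list follows the GetNvidiaGpuClient declaration in nvidia_gpu_device.h, the field values are example choices, and the surrounding function is illustrative only.

#include <memory>

#include "tensorflow/compiler/xla/python/nvidia_gpu_device.h"

// Illustrative single-node caller: a null distributed client and node_id 0
// select the non-distributed path.
void CreateLocalGpuClient() {
  xla::GpuAllocatorConfig allocator_config;
  allocator_config.memory_fraction = 0.9;  // example value
  allocator_config.preallocate = true;

  auto client_or = xla::GetNvidiaGpuClient(
      /*asynchronous=*/true, allocator_config,
      /*distributed_client=*/nullptr, /*node_id=*/0);
  if (!client_or.ok()) {
    // E.g. "No visible NVidia GPU devices."
    LOG(ERROR) << client_or.status();
    return;
  }
  auto client = client_or.ValueOrDie();  // shared_ptr to the PyLocalClient
  VLOG(1) << "GPU devices: " << client->device_count();
}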
#include "tensorflow/compiler/xla/python/shared_device_buffer.h" +#include #include #include "tensorflow/stream_executor/device_memory.h" @@ -60,7 +61,8 @@ static std::shared_ptr BufferFromScopedShapedBufferIterator( int device_ordinal, se::DeviceMemoryAllocator* allocator, ShapeTree::iterator* iterator, const ShapeTree::iterator& end, - const std::shared_ptr& definition_event) { + absl::Span> + definition_events) { std::vector buffers; buffers.reserve(1); std::vector> children; @@ -78,7 +80,7 @@ static std::shared_ptr BufferFromScopedShapedBufferIterator( for (int i = 0; i < num_children; ++i) { children.push_back(BufferFromScopedShapedBufferIterator( on_host_shape.tuple_shapes(i), on_device_shape.tuple_shapes(i), - device_ordinal, allocator, iterator, end, definition_event)); + device_ordinal, allocator, iterator, end, definition_events)); } } else { // An on-host array may be an on-device tuple. For example, a complex tensor @@ -88,20 +90,21 @@ static std::shared_ptr BufferFromScopedShapedBufferIterator( [&](const Shape&, const ShapeIndex&) { consume_buffer(); }); } return std::make_shared( - absl::Span(buffers), children, definition_event); + absl::Span(buffers), children, definition_events); } /* static */ std::shared_ptr SharedDeviceBuffer::FromScopedShapedBuffer( ScopedShapedBuffer* shaped_buffer, - const std::shared_ptr& definition_event) { + absl::Span> + definition_events) { ShapeTree::iterator iterator = shaped_buffer->buffers().begin(); std::shared_ptr output = BufferFromScopedShapedBufferIterator( shaped_buffer->on_host_shape(), shaped_buffer->on_device_shape(), shaped_buffer->device_ordinal(), shaped_buffer->memory_allocator(), - &iterator, shaped_buffer->buffers().end(), definition_event); + &iterator, shaped_buffer->buffers().end(), definition_events); CHECK(iterator == shaped_buffer->buffers().end()); return output; } @@ -111,7 +114,8 @@ SharedDeviceBuffer::MakeTuple( std::vector> children, const Shape& on_host_shape, TransferManager* transfer_manager, se::DeviceMemoryAllocator* allocator, int device_ordinal, - std::shared_ptr definition_event) { + absl::Span> + definition_events) { CHECK(on_host_shape.IsTuple() && on_host_shape.tuple_shapes_size() == children.size()); TF_ASSIGN_OR_RETURN( @@ -122,7 +126,7 @@ SharedDeviceBuffer::MakeTuple( return std::make_shared( allocator, device_ordinal, std::initializer_list{device_memory.Release()}, - std::move(children), std::move(definition_event), + std::move(children), definition_events, /*on_delete_callback=*/nullptr); } @@ -130,7 +134,8 @@ SharedDeviceBuffer::MakeTuple( SharedDeviceBuffer::MakeArray( Shape on_device_shape, TransferManager* transfer_manager, se::DeviceMemoryAllocator* allocator, int device_ordinal, - std::shared_ptr definition_event) { + absl::Span> + definition_events) { std::vector device_buffers; TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( on_device_shape, [&](const Shape& subshape, const ShapeIndex&) -> Status { @@ -145,7 +150,7 @@ SharedDeviceBuffer::MakeArray( return std::make_shared( absl::Span(device_buffers), /*children=*/std::vector>{}, - std::move(definition_event)); + definition_events); } // Populates a buffer tree from a ShapeTree iterator. 
@@ -176,25 +181,36 @@ ShapedBuffer SharedDeviceBuffer::AsShapedBuffer(const Shape& on_host_shape, return shaped_buffer; } +namespace { + +using MoveIterator = + absl::Span>::iterator; + +} // namespace + SharedDeviceBuffer::SharedDeviceBuffer( se::DeviceMemoryAllocator* allocator, int device_ordinal, absl::Span device_memory, std::vector> children, - std::shared_ptr definition_event, + absl::Span> definition_events, std::function on_delete_callback) : allocator_(allocator), device_ordinal_(device_ordinal), device_memory_(device_memory.begin(), device_memory.end()), children_(std::move(children)), - definition_event_(std::move(definition_event)), + definition_events_( + std::move_iterator(definition_events.begin()), + std::move_iterator(definition_events.end())), on_delete_callback_(std::move(on_delete_callback)) {} SharedDeviceBuffer::SharedDeviceBuffer( absl::Span device_memory, std::vector> children, - std::shared_ptr definition_event) + absl::Span> definition_events) : children_(std::move(children)), - definition_event_(std::move(definition_event)) { + definition_events_( + std::move_iterator(definition_events.begin()), + std::move_iterator(definition_events.end())) { CHECK(!device_memory.empty()); allocator_ = device_memory.front().allocator(); device_ordinal_ = device_memory.front().device_ordinal(); @@ -222,8 +238,8 @@ SharedDeviceBuffer::~SharedDeviceBuffer() { void GetDeviceBufferDefinitionEvents( const SharedDeviceBuffer& buffer, absl::flat_hash_set* events) { - if (buffer.definition_event()) { - events->insert(buffer.definition_event().get()); + for (const auto& e : buffer.definition_events()) { + events->insert(e.get()); } for (const auto& child : buffer.children()) { GetDeviceBufferDefinitionEvents(*child, events); diff --git a/tensorflow/compiler/xla/python/shared_device_buffer.h b/tensorflow/compiler/xla/python/shared_device_buffer.h index 8d9d8278d33..3aa122c535d 100644 --- a/tensorflow/compiler/xla/python/shared_device_buffer.h +++ b/tensorflow/compiler/xla/python/shared_device_buffer.h @@ -66,7 +66,7 @@ class BufferDefinitionEvent { void WaitForEventOnStream(se::Stream* stream); private: - bool EventHasBeenRecorded() EXCLUSIVE_LOCKS_REQUIRED(mu_); + bool EventHasBeenRecorded() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // An event that is triggered when the content of one or more buffers is // ready. If this event is nullptr, it is assumed that the buffer's content is @@ -77,7 +77,7 @@ class BufferDefinitionEvent { // A list of all streams for which the buffer's content is known to be defined // at the tail of the queue, i.e., for any newly enqueued command. - absl::InlinedVector streams_defined_on_ GUARDED_BY(mu_); + absl::InlinedVector streams_defined_on_ TF_GUARDED_BY(mu_); }; // Class that represents a node in a reference-counted DAG of device buffers. @@ -93,20 +93,23 @@ class SharedDeviceBuffer { // buffers of the shaped_buffer. static std::shared_ptr FromScopedShapedBuffer( ScopedShapedBuffer* shaped_buffer, - const std::shared_ptr& definition_event); + absl::Span> + definition_events); // Makes a tuple buffer. Does not initialize the tuple table. static StatusOr> MakeTuple( std::vector> children, const Shape& on_host_shape, TransferManager* transfer_manager, se::DeviceMemoryAllocator* allocator, int device_ordinal, - std::shared_ptr definition_event); + absl::Span> + definition_events); // Makes an uninitialized array buffer. 
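The constructors above fill definition_events_ through std::move_iterator, so the shared_ptrs are moved out of the caller's span rather than copied. A self-contained illustration of that idiom with plain std types; TakeAll is an illustrative name.

#include <iterator>
#include <memory>
#include <string>
#include <vector>

// Constructing the destination from move_iterators steals the shared_ptrs
// out of the source range instead of bumping their reference counts.
std::vector<std::shared_ptr<std::string>> TakeAll(
    std::vector<std::shared_ptr<std::string>>* in) {
  return std::vector<std::shared_ptr<std::string>>(
      std::make_move_iterator(in->begin()),
      std::make_move_iterator(in->end()));
}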
static StatusOr> MakeArray( Shape on_device_shape, TransferManager* transfer_manager, se::DeviceMemoryAllocator* allocator, int device_ordinal, - std::shared_ptr definition_event); + absl::Span> + definition_events); // Builds a ShapedBuffer view onto the buffers of 'tree'. We require but do // not verify that TransferManager::HostShapeToDeviceShape(on_host_shape) == @@ -126,19 +129,22 @@ class SharedDeviceBuffer { const absl::InlinedVector& device_memory() const { return device_memory_; } - const std::shared_ptr definition_event() const { - return definition_event_; + absl::Span> definition_events() + const { + return definition_events_; } SharedDeviceBuffer() = default; SharedDeviceBuffer(se::DeviceMemoryAllocator* allocator, int device_ordinal, absl::Span device_memory, std::vector> children, - std::shared_ptr definition_event, + absl::Span> + definition_events, std::function on_delete_callback); SharedDeviceBuffer(absl::Span device_memory, std::vector> children, - std::shared_ptr definition_event); + absl::Span> + definition_events); ~SharedDeviceBuffer(); private: @@ -155,7 +161,8 @@ class SharedDeviceBuffer { // ready during multistream execution. May be nullptr, which is used in the // single-stream execution case where events are not necessary for buffer // event sequencing. - std::shared_ptr definition_event_; + absl::InlinedVector, 2> + definition_events_; // A callback to call when the SharedDeviceBuffer is about to be destroyed. std::function on_delete_callback_; diff --git a/tensorflow/compiler/xla/python/shared_device_buffer_test.cc b/tensorflow/compiler/xla/python/shared_device_buffer_test.cc index b39767a0d46..05842c52a0c 100644 --- a/tensorflow/compiler/xla/python/shared_device_buffer_test.cc +++ b/tensorflow/compiler/xla/python/shared_device_buffer_test.cc @@ -28,10 +28,10 @@ TEST(SharedDeviceBufferTest, MakeArray) { LocalClient* client = ClientLibrary::LocalClientOrDie(); Shape shape = ShapeUtil::MakeShape(F32, {3, 101, 4}); - TF_ASSERT_OK_AND_ASSIGN( - auto buffer, SharedDeviceBuffer::MakeArray( - shape, client->backend().transfer_manager(), - client->backend().memory_allocator(), 0, nullptr)); + TF_ASSERT_OK_AND_ASSIGN(auto buffer, + SharedDeviceBuffer::MakeArray( + shape, client->backend().transfer_manager(), + client->backend().memory_allocator(), 0, {})); EXPECT_EQ(buffer->children().size(), 0); EXPECT_EQ(buffer->device_ordinal(), 0); EXPECT_EQ(buffer->allocator(), client->backend().memory_allocator()); @@ -45,19 +45,19 @@ TEST(SharedDeviceBufferTest, MakeTuple) { Shape a_shape = ShapeUtil::MakeShape(F32, {3, 101, 4}); Shape b_shape = ShapeUtil::MakeShape(S8, {77}); Shape tuple_shape = ShapeUtil::MakeTupleShape({a_shape, b_shape}); - TF_ASSERT_OK_AND_ASSIGN( - auto a_buffer, SharedDeviceBuffer::MakeArray( - a_shape, client->backend().transfer_manager(), - client->backend().memory_allocator(), 0, nullptr)); - TF_ASSERT_OK_AND_ASSIGN( - auto b_buffer, SharedDeviceBuffer::MakeArray( - b_shape, client->backend().transfer_manager(), - client->backend().memory_allocator(), 0, nullptr)); - TF_ASSERT_OK_AND_ASSIGN( - auto tuple_buffer, SharedDeviceBuffer::MakeTuple( - {a_buffer, b_buffer}, tuple_shape, - client->backend().transfer_manager(), - client->backend().memory_allocator(), 0, nullptr)); + TF_ASSERT_OK_AND_ASSIGN(auto a_buffer, + SharedDeviceBuffer::MakeArray( + a_shape, client->backend().transfer_manager(), + client->backend().memory_allocator(), 0, {})); + TF_ASSERT_OK_AND_ASSIGN(auto b_buffer, + SharedDeviceBuffer::MakeArray( + b_shape, 
client->backend().transfer_manager(), + client->backend().memory_allocator(), 0, {})); + TF_ASSERT_OK_AND_ASSIGN(auto tuple_buffer, + SharedDeviceBuffer::MakeTuple( + {a_buffer, b_buffer}, tuple_shape, + client->backend().transfer_manager(), + client->backend().memory_allocator(), 0, {})); ASSERT_EQ(tuple_buffer->children().size(), 2); EXPECT_EQ(tuple_buffer->children()[0], a_buffer); EXPECT_EQ(tuple_buffer->children()[1], b_buffer); @@ -75,30 +75,28 @@ TEST(SharedDeviceBufferTest, AsShapedBuffer) { Shape ab_tuple_shape = ShapeUtil::MakeTupleShape({a_shape, b_shape}); Shape c_shape = ShapeUtil::MakeShape(S64, {}); Shape abc_tuple_shape = ShapeUtil::MakeTupleShape({c_shape, ab_tuple_shape}); - TF_ASSERT_OK_AND_ASSIGN( - auto a_buffer, SharedDeviceBuffer::MakeArray( - a_shape, client->backend().transfer_manager(), - client->backend().memory_allocator(), 0, nullptr)); - TF_ASSERT_OK_AND_ASSIGN( - auto b_buffer, SharedDeviceBuffer::MakeArray( - b_shape, client->backend().transfer_manager(), - client->backend().memory_allocator(), 0, nullptr)); - TF_ASSERT_OK_AND_ASSIGN( - auto ab_tuple_buffer, - SharedDeviceBuffer::MakeTuple({a_buffer, b_buffer}, ab_tuple_shape, - client->backend().transfer_manager(), - client->backend().memory_allocator(), 0, - nullptr)); - TF_ASSERT_OK_AND_ASSIGN( - auto c_buffer, SharedDeviceBuffer::MakeArray( - c_shape, client->backend().transfer_manager(), - client->backend().memory_allocator(), 0, nullptr)); - TF_ASSERT_OK_AND_ASSIGN( - auto abc_tuple_buffer, - SharedDeviceBuffer::MakeTuple( - {c_buffer, ab_tuple_buffer}, abc_tuple_shape, - client->backend().transfer_manager(), - client->backend().memory_allocator(), 0, nullptr)); + TF_ASSERT_OK_AND_ASSIGN(auto a_buffer, + SharedDeviceBuffer::MakeArray( + a_shape, client->backend().transfer_manager(), + client->backend().memory_allocator(), 0, {})); + TF_ASSERT_OK_AND_ASSIGN(auto b_buffer, + SharedDeviceBuffer::MakeArray( + b_shape, client->backend().transfer_manager(), + client->backend().memory_allocator(), 0, {})); + TF_ASSERT_OK_AND_ASSIGN(auto ab_tuple_buffer, + SharedDeviceBuffer::MakeTuple( + {a_buffer, b_buffer}, ab_tuple_shape, + client->backend().transfer_manager(), + client->backend().memory_allocator(), 0, {})); + TF_ASSERT_OK_AND_ASSIGN(auto c_buffer, + SharedDeviceBuffer::MakeArray( + c_shape, client->backend().transfer_manager(), + client->backend().memory_allocator(), 0, {})); + TF_ASSERT_OK_AND_ASSIGN(auto abc_tuple_buffer, + SharedDeviceBuffer::MakeTuple( + {c_buffer, ab_tuple_buffer}, abc_tuple_shape, + client->backend().transfer_manager(), + client->backend().memory_allocator(), 0, {})); Shape abc_tuple_device_shape = client->backend().transfer_manager()->HostShapeToDeviceShape( abc_tuple_shape); @@ -140,7 +138,7 @@ TEST(SharedDeviceBufferTest, FromScopedShapedBuffer) { ScopedShapedBuffer shaped_buffer, client->LiteralToShapedBuffer(literal, /*device_ordinal=*/0)); std::shared_ptr device_buffer = - SharedDeviceBuffer::FromScopedShapedBuffer(&shaped_buffer, nullptr); + SharedDeviceBuffer::FromScopedShapedBuffer(&shaped_buffer, {}); ASSERT_EQ(device_buffer->device_memory().size(), 1); ASSERT_EQ(device_buffer->children().size(), 2); diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD index 148822f3ba7..b5f1a831d4a 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD +++ b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD @@ -90,8 +90,3 @@ cc_library( name = "libtpu", hdrs = ["libtpu.h"], ) - -cc_library( - name = 
"libtftpu", - hdrs = ["libtftpu.h"], -) diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc index 33573c1c8d8..706db57c4ac 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc @@ -35,8 +35,6 @@ limitations under the License. namespace xla { -constexpr char kTpuPlatform[] = "tpu"; - TpuDevice::TpuDevice(int id, int host_id, const std::array& coords, int core_on_chip) : xla::Device(id, /*local_device_state=*/nullptr, kTpuPlatform, host_id), @@ -154,7 +152,7 @@ StatusOr PyTpuClient::GetDefaultDeviceAssignment( Status PyTpuClient::CheckDeviceId(int device_id, absl::string_view caller_name) { if (device_id < 0 || device_id >= device_count()) { - return InvalidArgument("%s got bad device_id: %d (num_devices=%d)", + return InvalidArgument("%s got bad device_id: %d (num_devices=%d).", caller_name, device_id, device_count()); } return Status::OK(); @@ -174,12 +172,12 @@ static Status CheckDataType(xla::PrimitiveType dtype) { StatusOr> PyTpuBuffer::FromLiterals( std::vector leaves, const Shape& tuple_shape, std::shared_ptr leaves_references, - std::shared_ptr client, int device_id) { + std::shared_ptr client, std::shared_ptr device) { tensorflow::profiler::TraceMe traceme("PyTpuBuffer::FromLiterals"); VLOG(1) << "PyTpuBuffer::FromLiterals: shape: " << tuple_shape.DebugString() - << " device id: " << device_id; + << " device: " << device->DebugString(); TF_RETURN_IF_ERROR( - client->CheckDeviceId(device_id, "PyTpuBuffer::FromLiterals")); + client->CheckDeviceId(device->id(), "PyTpuBuffer::FromLiterals")); tpu_driver::TpuDriver* driver = client->driver(); if (!tuple_shape.IsTuple()) { @@ -193,7 +191,7 @@ StatusOr> PyTpuBuffer::FromLiterals( event->AddCallback([leaves_references](Status) {}); return event; }, - std::move(client), device_id); + std::move(client), std::move(device)); } std::vector> child_buffers; @@ -213,7 +211,7 @@ StatusOr> PyTpuBuffer::FromLiterals( [driver, &leaf, &indexed_shape](tpu_driver::BufferHandle* handle) { return driver->TransferToDevice(leaf.untyped_data(), handle, {}); }, - client, device_id)); + client, device)); child_buffer_ptrs.push_back(child_buffer.get()); child_buffers.push_back(std::move(child_buffer)); ++it_leaf; @@ -223,13 +221,14 @@ StatusOr> PyTpuBuffer::FromLiterals( // `MakeTuple` will extract and make the tuple buffer hold onto the // `device_buffer_` contained in each `child_buffer`, so it's safe for // `child_buffers` to get destroyed before this call returns. 
- return MakeTuple(std::move(child_buffer_ptrs), std::move(client), device_id); + return MakeTuple(std::move(child_buffer_ptrs), std::move(client), + std::move(device)); } /* static */ StatusOr> PyTpuBuffer::MakeTuple( const std::vector buffers, - std::shared_ptr client, int device_id) { + std::shared_ptr client, std::shared_ptr device) { std::vector child_shapes; std::vector> child_device_buffers; std::vector child_handle_ptrs; @@ -252,11 +251,11 @@ StatusOr> PyTpuBuffer::MakeTuple( Shape tuple_shape = ShapeUtil::MakeTupleShape(child_shapes); std::unique_ptr tuple_handle = - client->driver()->AllocateTuple(device_id, tpu_driver::MemoryRegion::HBM, - child_handle_ptrs, {}); + client->driver()->AllocateTuple( + device->id(), tpu_driver::MemoryRegion::HBM, child_handle_ptrs, {}); auto tuple_device_buffer = std::make_shared( client->driver(), std::move(tuple_handle), std::move(child_events), - device_id); + std::move(device)); return absl::make_unique( tuple_shape, std::move(tuple_device_buffer), std::move(child_device_buffers), std::move(client)); @@ -268,7 +267,7 @@ PyTpuBuffer::PyTpuBuffer( std::shared_ptr client) : client_(std::move(client)), on_host_shape_(std::move(on_host_shape)), - device_id_(device_buffer->device_id), + device_(device_buffer->device), device_buffer_(std::move(device_buffer)), child_buffers_(std::move(child_buffers)) {} @@ -368,11 +367,11 @@ PyTpuBuffer::DestructureTuple() { if (!on_host_shape_.IsTuple()) { return InvalidArgument( "Attempted to destructure a PyTpuBuffer that did not have a tuple " - "shape; shape: %s", + "shape; shape: %s.", ShapeUtil::HumanString(on_host_shape_)); } if (DeviceBuffer() == nullptr) { - return InvalidArgument("Attempted to destructure a deleted buffer"); + return InvalidArgument("Attempted to destructure a deleted buffer."); } absl::MutexLock lock(&mu_); @@ -388,14 +387,14 @@ PyTpuBuffer::DestructureTuple() { } StatusOr> PyTpuBuffer::CopyToDevice( - int dst_device_id) { + std::shared_ptr dst_device) { tensorflow::profiler::TraceMe traceme("PyTpuBuffer::CopyToDevice"); if (on_host_shape_.IsTuple()) { return Unimplemented("CopyToDevice for tuples is not supported."); } std::shared_ptr src_device_buffer = DeviceBuffer(); - if (dst_device_id == device_id_) { + if (dst_device->id() == device_->id()) { return absl::make_unique( on_host_shape_, src_device_buffer, std::vector>(), client_); @@ -414,7 +413,7 @@ StatusOr> PyTpuBuffer::CopyToDevice( return driver->TransferFromDeviceToDevice( src_device_buffer->handle.get(), dst_handle, src_wait_for_use); }, - client_, dst_device_id)); + client_, std::move(dst_device))); // TODO(jiawenhao): This may be too pessimistic: it prevents future readers // from reading `src_device_buffer` until the device-to-device copy is done. // Should this go into a new `TpuSharedBuffer::wait_for_dealloc` field? 
@@ -432,13 +431,15 @@ Status PyTpuBuffer::BlockHostUntilReady() { /* static */ StatusOr> PyTpuBuffer::AllocateBuffer( - const Shape& shape, std::shared_ptr client, int device_id) { + const Shape& shape, std::shared_ptr client, + std::shared_ptr device) { tensorflow::profiler::TraceMe traceme("PyTpuBuffer::AllocateBuffer"); VLOG(1) << "PyTpuBuffer::AllocateBuffer: shape: " << shape.DebugString() - << " device ordinal: " << device_id; + << " device: " << device->DebugString(); if (!shape.IsTuple()) { - return CreateBuffer(shape, absl::nullopt, std::move(client), device_id); + return CreateBuffer(shape, absl::nullopt, std::move(client), + std::move(device)); } std::vector> child_buffers; @@ -448,7 +449,7 @@ StatusOr> PyTpuBuffer::AllocateBuffer( for (const auto& child_shape : shape.tuple_shapes()) { TF_ASSIGN_OR_RETURN(std::unique_ptr child_buffer, - AllocateBuffer(child_shape, client, device_id)); + AllocateBuffer(child_shape, client, device)); child_buffer_ptrs.push_back(child_buffer.get()); child_buffers.push_back(std::move(child_buffer)); } @@ -457,21 +458,23 @@ StatusOr> PyTpuBuffer::AllocateBuffer( // `device_buffer_` contained in each `child_buffer`, so it's safe for // `child_buffers` to get destroyed before this call returns. return PyTpuBuffer::MakeTuple(child_buffer_ptrs, std::move(client), - device_id); + std::move(device)); } /*static*/ StatusOr> PyTpuBuffer::CreateBuffer( const Shape& non_tuple_shape, absl::optional initializer, - std::shared_ptr client, int device_id) { + std::shared_ptr client, std::shared_ptr device) { tensorflow::profiler::TraceMe traceme("PyTpuBuffer::CreateBuffer"); VLOG(1) << "PyTpuBuffer::CreateBuffer: shape: " - << non_tuple_shape.DebugString() << " device id: " << device_id; + << non_tuple_shape.DebugString() + << " device: " << device->DebugString(); TF_RET_CHECK(!non_tuple_shape.IsTuple()); TF_RETURN_IF_ERROR(CheckDataType(non_tuple_shape.element_type())); - std::unique_ptr handle = client->driver()->Allocate( - device_id, tpu_driver::MemoryRegion::HBM, non_tuple_shape.ToProto(), {}); + std::unique_ptr handle = + client->driver()->Allocate(device->id(), tpu_driver::MemoryRegion::HBM, + non_tuple_shape.ToProto(), {}); // If this buffer needs to be initialized, anyone using this buffer must wait // for the initialization event in `wait_for_use` to finish first. 
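The comment above describes the wait_for_use contract: anyone reading a freshly created buffer must first wait on its initialization event. A generic sketch of that pattern using std::shared_future in place of tpu_driver events; SketchBuffer is illustrative only.

#include <future>
#include <memory>
#include <vector>

// A buffer carries the events that must complete before its contents may be
// read; a consumer blocks on all of them before use.
struct SketchBuffer {
  std::vector<std::shared_future<void>> wait_for_use;

  void BlockUntilReady() const {
    for (const std::shared_future<void>& event : wait_for_use) {
      event.wait();
    }
  }
};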
@@ -481,7 +484,8 @@ StatusOr> PyTpuBuffer::CreateBuffer( wait_for_use.push_back(std::move(init)); } auto device_buffer = std::make_shared( - client->driver(), std::move(handle), std::move(wait_for_use), device_id); + client->driver(), std::move(handle), std::move(wait_for_use), + std::move(device)); return absl::make_unique( non_tuple_shape, std::move(device_buffer), @@ -542,7 +546,8 @@ PyTpuExecutable::ExecuteResult PyTpuExecutable::ExecuteHelper( << " mapped to device id for execution: " << device_id; std::unique_ptr<::xla::PyTpuBuffer> output_buffer = - ::xla::PyTpuBuffer::AllocateBuffer(result_shape_, client_, device_id) + ::xla::PyTpuBuffer::AllocateBuffer(result_shape_, client_, + std::move(device)) .ValueOrDie(); VLOG(1) << "Created output buffer: " << result_shape_.DebugString(); @@ -610,12 +615,12 @@ StatusOr> PyTpuExecutable::Execute( absl::Span argument_handles) { if (num_replicas() != 1) { return InvalidArgument( - "Attempted to execute computation with %d replicas using Execute()", + "Attempted to execute computation with %d replicas using Execute().", num_replicas()); } if (num_partitions() != 1) { return InvalidArgument( - "Attempted to execute computation with %d partitions using Execute()", + "Attempted to execute computation with %d partitions using Execute().", num_partitions()); } @@ -636,19 +641,6 @@ StatusOr> PyTpuExecutable::Execute( return std::move(result.buffer); } -StatusOr>> -PyTpuExecutable::ExecutePerReplica( - absl::Span> argument_handles) { - tensorflow::profiler::TraceMe traceme("PyTpuExecutable::ExecutePerReplica"); - if (num_partitions() != 1) { - return InvalidArgument( - "Attempted to execute computation with %d partitions using " - "ExecutePerReplica()", - num_partitions()); - } - return ExecuteOnLocalDevices(argument_handles); -} - StatusOr>> PyTpuExecutable::ExecuteOnLocalDevices( absl::Span> argument_handles) { @@ -660,7 +652,7 @@ PyTpuExecutable::ExecuteOnLocalDevices( if (argument_handles.size() != num_local_devices) { return InvalidArgument( "Attempted to execute with %d argument lists when local device " - "count is %d (total replica count: %d, partition count: %d)", + "count is %d (total replica count: %d, partition count: %d).", argument_handles.size(), num_local_devices, num_replicas(), num_partitions()); } @@ -717,54 +709,6 @@ PyTpuExecutable::ExecuteOnLocalDevices( return wrapped_results; } -/*static*/ StatusOr> -PyTpuExecutable::CompileForDevices( - const XlaComputation& computation, - absl::optional> argument_layouts, - const ExecutableBuildOptions* build_options, - std::shared_ptr client, - const std::vector>>& - device_assignment) { - if (device_assignment.empty()) { - return InvalidArgument( - "Device assignment passed to Compile() must be non-empty."); - } - if (device_assignment[0].empty()) { - return InvalidArgument( - "Device assignment passed to Compile() must have a nonzero number of " - "partitions per replica; replica 0 had 0 partitions."); - } - DeviceAssignment xla_assignment(device_assignment.size(), - device_assignment[0].size()); - for (int replica = 0; replica < device_assignment.size(); ++replica) { - if (device_assignment[replica].size() != device_assignment[0].size()) { - return InvalidArgument( - "Device assignment passed to Compile() has different numbers of " - "partitions between replicas; %d partitions for replica %d versus %d " - "partitions for replica 0.", - device_assignment[replica].size(), replica, - device_assignment[0].size()); - } - for (int partition = 0; partition < device_assignment[replica].size(); - 
++partition) { - if (device_assignment[0][0]->platform_name() != - device_assignment[replica][partition]->platform_name()) { - return InvalidArgument( - "Device assignment passed to Compile() must have devices of a " - "single kind, got %s for replica 0 partition 0 and %s for replica " - "%d partition %d.", - device_assignment[0][0]->platform_name(), - device_assignment[replica][partition]->platform_name(), replica, - partition); - } - xla_assignment(replica, partition) = - device_assignment[replica][partition]->id(); - } - } - return Compile(computation, std::move(argument_layouts), build_options, - std::move(client), xla_assignment); -} - /*static*/ StatusOr> PyTpuExecutable::Compile( const XlaComputation& computation, absl::optional> argument_layouts, diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h index f4815b44183..4b7670707fb 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h @@ -36,6 +36,8 @@ limitations under the License. namespace xla { +constexpr char kTpuPlatform[] = "tpu"; + class TpuDevice : public Device { public: TpuDevice(int id, int host_id, const std::array& coords, @@ -126,9 +128,9 @@ struct TpuSharedBuffer final { TpuSharedBuffer(tpu_driver::TpuDriver* driver, std::unique_ptr handle, std::vector> wait_for_use, - int device_id) + std::shared_ptr src_device) : driver(driver), - device_id(device_id), + device(std::move(src_device)), handle(std::move(handle)), wait_for_use(std::move(wait_for_use)) {} @@ -141,7 +143,7 @@ struct TpuSharedBuffer final { } tpu_driver::TpuDriver* const driver; - const int device_id; + const std::shared_ptr device; std::unique_ptr handle; std::vector> wait_for_use; @@ -160,12 +162,12 @@ class PyTpuBuffer { static StatusOr> FromLiterals( std::vector leaves_literals, const Shape& tuple_shape, std::shared_ptr leaves_reference, - std::shared_ptr client, int device_id); + std::shared_ptr client, std::shared_ptr device); // Supports nested tuple creation. static StatusOr> MakeTuple( const std::vector buffers, - std::shared_ptr client, int device_id); + std::shared_ptr client, std::shared_ptr device); PyTpuBuffer() = delete; PyTpuBuffer(Shape on_host_shape, @@ -179,7 +181,7 @@ class PyTpuBuffer { PyTpuBuffer& operator=(PyTpuBuffer&&) = delete; const Shape& on_host_shape() const { return on_host_shape_; } - int device_id() const { return device_id_; } + std::shared_ptr device() const { return device_; } const std::string& platform_name() const { return client_->platform_name(); } std::shared_ptr client() const { return client_; } @@ -205,8 +207,10 @@ class PyTpuBuffer { // Destructures a tuple-valued PyTpuBuffer into its constituent elements. StatusOr>> DestructureTuple(); - // Copies the buffer to device `dst_device_id`. - StatusOr> CopyToDevice(int dst_device_id); + // Copies the buffer to target device `dst_device` and returns a PyTpuBuffer + // object holding the context to the target device buffer. + StatusOr> CopyToDevice( + std::shared_ptr dst_device); // Blocks the host until the buffer's value has been computed and is ready for // immediate use on the device. Useful in particular for timing benchmarks. @@ -215,7 +219,8 @@ class PyTpuBuffer { // Allocates uninitialized buffers on device `device_id`. If `shape` is a // tuple, the returned buffer corresponds to the root tuple buffer. 
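The removed CompileForDevices helper validated the replica-by-partition device grid before building a DeviceAssignment; callers now go through DevicesToDeviceAssignment in the Python binding instead. A standalone restatement of those checks with a stand-in device struct, for reference only.

#include <cstddef>
#include <string>
#include <vector>

// Stand-in for xla::Device, for this sketch only.
struct DeviceInfo {
  int id;
  std::string platform_name;
};

// Non-empty grid, equal partition counts per replica, and a single device
// kind; returns an error message, or the empty string on success.
std::string ValidateDeviceGrid(
    const std::vector<std::vector<DeviceInfo>>& grid) {
  if (grid.empty()) return "device assignment must be non-empty";
  if (grid.front().empty()) return "replica 0 has 0 partitions";
  for (std::size_t r = 0; r < grid.size(); ++r) {
    if (grid[r].size() != grid.front().size()) {
      return "replicas have differing numbers of partitions";
    }
    for (const DeviceInfo& device : grid[r]) {
      if (device.platform_name != grid[0][0].platform_name) {
        return "devices must all be of a single kind";
      }
    }
  }
  return "";
}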
static StatusOr> AllocateBuffer( - const Shape& shape, std::shared_ptr client, int device_id); + const Shape& shape, std::shared_ptr client, + std::shared_ptr device); private: // Initializes a just allocated device buffer. The returned event will be @@ -226,18 +231,19 @@ class PyTpuBuffer { static StatusOr> CreateBuffer( const Shape& non_tuple_shape, absl::optional initializer, - std::shared_ptr client, int device_id); + std::shared_ptr client, std::shared_ptr device); const std::shared_ptr client_; const Shape on_host_shape_; - const int device_id_; + const std::shared_ptr device_; // If this is a tuple, `device_buffer_` stores the tuple buffer and // `child_buffers_` stores the child buffers; else, `device_buffer_` stores // the data content and `child_buffers_` is empty. mutable absl::Mutex mu_; - std::shared_ptr device_buffer_ GUARDED_BY(mu_); - std::vector> child_buffers_ GUARDED_BY(mu_); + std::shared_ptr device_buffer_ TF_GUARDED_BY(mu_); + std::vector> child_buffers_ + TF_GUARDED_BY(mu_); // The cached value of the buffer on the host, produced either from a call to // CopyToHost or from a call to ToLiteral. Once a value has been fetched to // the host, it persists Delete() is called or the PyTpuBuffer is destroyed. @@ -250,23 +256,13 @@ class PyTpuBuffer { Status status; std::shared_ptr value; }; - std::shared_ptr host_value_ GUARDED_BY(mu_); + std::shared_ptr host_value_ TF_GUARDED_BY(mu_); }; // Represents a compiled computation that can be executed given handles to // device-allocated literals. Wraps an XLA LocalExecutable. class PyTpuExecutable { public: - // Compiles a computation to an executable. - static StatusOr> CompileForDevices( - const XlaComputation& computation, - absl::optional> argument_layouts, - const ExecutableBuildOptions* build_options, - std::shared_ptr client, - const std::vector>>& - device_assignment); - - // TODO(phawkins): remove after changing callers to use the first overload. static StatusOr> Compile( const XlaComputation& computation, absl::optional> argument_layouts, @@ -309,20 +305,12 @@ class PyTpuExecutable { return local_devices_; } - // TODO(power): Both Execute and ExecutePerReplica block and wait inside for - // computation to finish. Coordinate with JAX code change to see if we can - // make both Execute and ExecutePerReplica non-blocking. + // TODO(power): Both Execute and ExecutePerOnLocalDevices block and wait + // inside for computation to finish. Coordinate with JAX code change to see if + // we can make both Execute and ExecutePerReplica non-blocking. StatusOr> Execute( absl::Span argument_handles); - // Execute on many replicas. Takes a sequence of argument lists (one argument - // list per replica) and returns a tuple of results (one result per replica). - // The number of argument lists must be equal to the replica count. - // The executable must have only one partition. - // TODO(cjfj): Remove this once JAX is moved to `ExecuteOnLocalDevices`. - StatusOr>> ExecutePerReplica( - absl::Span> argument_handles); - // Execute on local devices. Takes a sequence of argument lists (one argument // list per local device) and returns a tuple of results (one result per local // device). 
The number of argument lists must be equal to the local device diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc index f6e2fab7ef0..0dcb9dc4c84 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc @@ -38,7 +38,7 @@ PYBIND11_MODULE(tpu_client_extension, m) { .def("local_devices", &PyTpuClient::local_devices) .def("host_id", &PyTpuClient::host_id) .def("GetDefaultDeviceAssignment", - [](PyLocalClient* client, int num_replicas, int num_partitions) + [](PyTpuClient* client, int num_replicas, int num_partitions) -> StatusOr>>> { TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, client->GetDefaultDeviceAssignment( @@ -121,9 +121,9 @@ PYBIND11_MODULE(tpu_client_extension, m) { std::make_move_iterator(tree.leaves.end())); py::gil_scoped_release gil_release; - return PyTpuBuffer::FromLiterals(std::move(leaves), tree.shape, - std::move(py_buffer_ref), - std::move(client), device->id()); + return PyTpuBuffer::FromLiterals( + std::move(leaves), tree.shape, std::move(py_buffer_ref), + std::move(client), std::move(device)); }) .def_static("make_tuple", [](const std::vector buffers, @@ -137,15 +137,15 @@ PYBIND11_MODULE(tpu_client_extension, m) { "Cannot make tuple on device '%s' with '%s' backend", device->DebugString(), client->platform_name()); } - return PyTpuBuffer::MakeTuple(buffers, client, - device->id()); + return PyTpuBuffer::MakeTuple(buffers, std::move(client), + std::move(device)); }) .def("copy_to_device", [](PyTpuBuffer* buffer, std::shared_ptr dst_device) { CHECK(dst_device != nullptr); GlobalPyRefManager()->CollectGarbage(); py::gil_scoped_release gil_release; - return buffer->CopyToDevice(dst_device->id()); + return buffer->CopyToDevice(std::move(dst_device)); }) .def("delete", &PyTpuBuffer::Delete) .def("destructure", &PyTpuBuffer::DestructureTuple) @@ -168,10 +168,7 @@ PYBIND11_MODULE(tpu_client_extension, m) { return LiteralToPython(std::move(literal)); }) .def("shape", &PyTpuBuffer::on_host_shape) - .def("device", - [](PyTpuBuffer* buffer) -> std::shared_ptr { - return buffer->client()->devices()[buffer->device_id()]; - }) + .def("device", &PyTpuBuffer::device) .def("platform", &PyTpuBuffer::platform_name) .def("is_deleted", [](const PyTpuBuffer& buffer) { return buffer.DeviceBuffer() == nullptr; @@ -180,8 +177,25 @@ PYBIND11_MODULE(tpu_client_extension, m) { py::class_(m, "TpuExecutable") .def_static("Compile", &PyTpuExecutable::Compile, py::call_guard()) - .def_static("Compile", &PyTpuExecutable::CompileForDevices, - py::call_guard()) + .def_static("Compile", + [](const XlaComputation& computation, + absl::optional> argument_layouts, + const ExecutableBuildOptions* build_options, + std::shared_ptr client, + absl::optional>> + device_assignment) + -> StatusOr> { + py::gil_scoped_release gil_release; + absl::optional xla_device_assignment; + if (device_assignment) { + TF_ASSIGN_OR_RETURN( + xla_device_assignment, + DevicesToDeviceAssignment(*device_assignment)); + } + return PyTpuExecutable::Compile( + computation, argument_layouts, build_options, client, + std::move(xla_device_assignment)); + }) .def("local_logical_device_ids", &PyTpuExecutable::local_logical_device_ids) .def("local_devices", &PyTpuExecutable::local_devices) @@ -190,8 +204,6 @@ PYBIND11_MODULE(tpu_client_extension, m) { .def("Delete", &PyTpuExecutable::Delete) .def("Execute", 
&PyTpuExecutable::Execute, py::call_guard(), py::arg("arguments")) - .def("ExecutePerReplica", &PyTpuExecutable::ExecutePerReplica, - py::call_guard(), py::arg("arguments")) .def("ExecuteOnLocalDevices", &PyTpuExecutable::ExecuteOnLocalDevices, py::call_guard(), py::arg("arguments")); diff --git a/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h b/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h index 9127f0342fa..72e55b1d11e 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h +++ b/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h @@ -23,7 +23,6 @@ #include #include -#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" #include "absl/synchronization/mutex.h" diff --git a/tensorflow/compiler/xla/python/types.h b/tensorflow/compiler/xla/python/types.h index ceefbda4f90..564933c14f2 100644 --- a/tensorflow/compiler/xla/python/types.h +++ b/tensorflow/compiler/xla/python/types.h @@ -26,6 +26,7 @@ limitations under the License. #include "include/pybind11/pybind11.h" #include "include/pybind11/stl.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/python/local_client.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" @@ -35,6 +36,97 @@ limitations under the License. namespace xla { +// Custom holder types. +// +// We must keep the PyLocalClient object alive as long as any of the runtime +// objects are alive. Since we don't have a lot of control over Python +// destructor ordering, we keep the PyLocalClient object as a std::shared_ptr<>, +// and ensure that each Python runtime object holds a reference to the +// PyLocalClient. An alternative design would be to keep a single global +// singleton PyLocalClient, although this seems less flexible, especially for +// writing tests. +// +// To maintain PyLocalClient references, we define pybind11 holder classes that +// are custom smart pointers that also keep a reference to a PyLocalClient. +// pybind11 has a `keep_alive` feature that has a similar goal, but it doesn't +// seem sufficiently flexible to describe ownership relationships in cases where +// the ownership doesn't pertain to a direct argument or return value of a +// function. Another alternative to the holder classes would be to create proxy +// objects that contain both a reference and a runtime class; holder classes +// seem less tedious to define. + +// A pair of a PyLocalClient reference and an unowned pointer to T. +template +struct ClientAndPtr { + ClientAndPtr() = default; + // pybind11 requires that we define a constructor that takes a raw pointer, + // but it should be unreachable. + explicit ClientAndPtr(T*) { + LOG(FATAL) << "ClientAndPtr should constructed via WrapWithClient."; + } + + ClientAndPtr(const ClientAndPtr&) = default; + ClientAndPtr(ClientAndPtr&&) = default; + ClientAndPtr& operator=(const ClientAndPtr&) = default; + ClientAndPtr& operator=(ClientAndPtr&&) = default; + + std::shared_ptr client; + T* contents; + + T* get() const { return contents; } + T* operator->() const { return contents; } + T& operator*() const { return *contents; } +}; + +// By defining a templated helper function, we can use return type deduction +// and avoid specifying types at the caller. 
+template +ClientAndPtr WrapWithClient(std::shared_ptr client, + T* contents) { + ClientAndPtr result; + result.client = std::move(client); + result.contents = contents; + return result; +} + +// A pair of a PyLocalClient reference and an owned pointer to T. +template +struct ClientAndUniquePtr { + ClientAndUniquePtr() = default; + // pybind11 requires that we define a constructor that takes a raw pointer, + // but it should be unreachable. + explicit ClientAndUniquePtr(T*) { + LOG(FATAL) << "ClientAndUniquePtr should constructed via WrapWithClient."; + } + ClientAndUniquePtr(const ClientAndUniquePtr&) = delete; + ClientAndUniquePtr(ClientAndUniquePtr&&) = default; + ClientAndUniquePtr& operator=(const ClientAndUniquePtr&) = delete; + ClientAndUniquePtr& operator=(ClientAndUniquePtr&&) = default; + + std::shared_ptr client; + std::unique_ptr contents; + + T* get() const { return contents.get(); } + T* operator->() const { return contents.get(); } + T& operator*() const { return *contents; } +}; + +template +ClientAndUniquePtr WrapWithClient(std::shared_ptr client, + std::unique_ptr contents) { + ClientAndUniquePtr result; + result.client = std::move(client); + result.contents = std::move(contents); + return result; +} + +} // namespace xla + +PYBIND11_DECLARE_HOLDER_TYPE(T, xla::ClientAndPtr); +PYBIND11_DECLARE_HOLDER_TYPE(T, xla::ClientAndUniquePtr); + +namespace xla { + // Initializes the NumPy API for the use of the types module. bool InitializeNumpyAPIForTypes(); diff --git a/tensorflow/compiler/xla/python/worker_thread.h b/tensorflow/compiler/xla/python/worker_thread.h index bc7dd396f88..598f7b1d4ae 100644 --- a/tensorflow/compiler/xla/python/worker_thread.h +++ b/tensorflow/compiler/xla/python/worker_thread.h @@ -40,11 +40,11 @@ class WorkerThread { void Schedule(std::function fn); private: - bool WorkAvailable() EXCLUSIVE_LOCKS_REQUIRED(mu_); + bool WorkAvailable() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); void WorkLoop(); absl::Mutex mu_; - std::queue> work_queue_ GUARDED_BY(mu_); + std::queue> work_queue_ TF_GUARDED_BY(mu_); std::unique_ptr thread_; }; diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index 4be375ac15a..b42202ca838 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/types/optional.h" #include "absl/types/span.h" +#include "include/pybind11/cast.h" #include "include/pybind11/numpy.h" #include "include/pybind11/pybind11.h" #include "include/pybind11/pytypes.h" @@ -40,6 +41,9 @@ limitations under the License. 
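The ClientAndPtr and ClientAndUniquePtr holders above are registered with pybind11 through PYBIND11_DECLARE_HOLDER_TYPE. A minimal, self-contained sketch of that registration recipe; TrackingPtr and Widget are illustrative names, and only the get() accessor plus the raw-pointer constructor noted above are pybind11 requirements.

#include <pybind11/pybind11.h>

// A custom smart-pointer holder, following the same recipe as the holders
// defined in types.h.
template <typename T>
struct TrackingPtr {
  explicit TrackingPtr(T* p = nullptr) : ptr(p) {}
  T* get() const { return ptr; }
  T* operator->() const { return ptr; }
  T* ptr = nullptr;
};

PYBIND11_DECLARE_HOLDER_TYPE(T, TrackingPtr<T>);

// A class bound with this holder would then be declared as:
//   py::class_<Widget, TrackingPtr<Widget>>(m, "Widget");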
#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/python/bfloat16.h" #include "tensorflow/compiler/xla/python/cpu_device.h" +#include "tensorflow/compiler/xla/python/distributed/client.h" +#include "tensorflow/compiler/xla/python/distributed/distributed.h" +#include "tensorflow/compiler/xla/python/distributed/service.h" #include "tensorflow/compiler/xla/python/dlpack.h" #include "tensorflow/compiler/xla/python/local_client.h" #include "tensorflow/compiler/xla/python/nvidia_gpu_device.h" @@ -70,7 +74,7 @@ namespace { struct Uniquer { absl::Mutex mu; - NameUniquer name_uniquer GUARDED_BY(mu); + NameUniquer name_uniquer TF_GUARDED_BY(mu); }; Uniquer* GetUniquer() { @@ -153,16 +157,6 @@ Status PyRegisterCustomCallTarget(const std::string& fn_name, return Status::OK(); } -StatusOr> LookupDeviceOrdinal( - PyLocalClient* client, int device_ordinal, absl::string_view caller_name) { - if (device_ordinal < 0 || device_ordinal >= client->local_device_count()) { - return InvalidArgument( - "%s got bad device_ordinal: %d (num_local_devices=%d)", caller_name, - device_ordinal, client->local_device_count()); - } - return client->local_devices()[device_ordinal]; -} - // PEP 3118 buffer protocol implementation. // Extra data to be kept alive by the consumer of the buffer protocol. @@ -552,9 +546,7 @@ class TraceMeContextManager { void Enter() { if (IsEnabled()) { std::string name(name_); - // TODO(skye): we can use kwargs_.empty() once we upgrade to pybind11 2.4 - // in workspace.bzl - if (kwargs_.size() != 0) { + if (!kwargs_.empty()) { absl::StrAppend(&name, "#"); bool first = true; for (const auto& entry : kwargs_) { @@ -764,7 +756,7 @@ PYBIND11_MODULE(xla_extension, m) { // Literals py::class_>(m, "Literal") .def("__repr__", &Literal::ToString); - py::class_(m, "LiteralSlice"); + py::class_ literal_slice(m, "LiteralSlice"); py::implicitly_convertible(); py::implicitly_convertible(); @@ -790,7 +782,7 @@ PYBIND11_MODULE(xla_extension, m) { .def("computation_count", &DeviceAssignment::computation_count) .def("__repr__", &DeviceAssignment::ToString); - py::class_>( + py::class_>( m, "Device", "A descriptor of an available device.\n\nSubclasses are used to " "represent specific types of devices, e.g. CPUs, GPUs. 
Subclasses may " @@ -832,12 +824,12 @@ PYBIND11_MODULE(xla_extension, m) { return LiteralToPython(std::move(literal_shared)); }); - py::class_>(m, "CpuDevice") + py::class_>(m, "CpuDevice") .def("__repr__", [](const CpuDevice& device) { return absl::StrFormat("CpuDevice(id=%i)", device.id()); }); - py::class_>(m, "GpuDevice") + py::class_>(m, "GpuDevice") .def("__repr__", [](const GpuDevice& device) { return absl::StrFormat("GpuDevice(id=%i)", device.id()); }); @@ -860,16 +852,33 @@ PYBIND11_MODULE(xla_extension, m) { py::class_>(m, "LocalClient") .def("device_count", &PyLocalClient::device_count) .def("local_device_count", &PyLocalClient::local_device_count) - .def("devices", &PyLocalClient::devices) - .def("local_devices", &PyLocalClient::local_devices) + .def("devices", + [](std::shared_ptr client) { + std::vector> devices; + devices.reserve(client->devices().size()); + for (const auto& device : client->devices()) { + devices.push_back(WrapWithClient(client, device.get())); + } + return devices; + }) + .def("local_devices", + [](std::shared_ptr client) { + std::vector> devices; + devices.reserve(client->local_devices().size()); + for (Device* device : client->local_devices()) { + devices.push_back(WrapWithClient(client, device)); + } + return devices; + }) .def("host_id", &PyLocalClient::host_id) .def("GetDefaultDeviceAssignment", - [](PyLocalClient* client, int num_replicas, int num_partitions) - -> StatusOr>>> { + [](std::shared_ptr client, int num_replicas, + int num_partitions) + -> StatusOr>>> { TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, client->GetDefaultDeviceAssignment( num_replicas, num_partitions)); - std::vector>> result; + std::vector>> result; result.resize(num_replicas); for (int r = 0; r < num_replicas; ++r) { result[r].resize(num_partitions); @@ -877,24 +886,24 @@ PYBIND11_MODULE(xla_extension, m) { int device_id = device_assignment(r, p); auto iter = client->id_to_device().find(device_id); CHECK(iter != client->id_to_device().end()) << device_id; - result[r][p] = iter->second; + result[r][p] = WrapWithClient(client, iter->second); } } return result; }) // TODO(skye): delete after all callers can handle 2D output .def("GetDefaultDeviceAssignment", - [](PyLocalClient* client, int num_replicas) - -> StatusOr>> { + [](std::shared_ptr client, + int num_replicas) -> StatusOr>> { TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, client->GetDefaultDeviceAssignment( num_replicas, /*num_partitions=*/1)); - std::vector> result; + std::vector> result; for (int i = 0; i < num_replicas; ++i) { int device_id = device_assignment(i, 0); auto iter = client->id_to_device().find(device_id); CHECK(iter != client->id_to_device().end()) << device_id; - result.push_back(iter->second); + result.push_back(WrapWithClient(client, iter->second)); } return result; }) @@ -913,16 +922,17 @@ PYBIND11_MODULE(xla_extension, m) { m.def("get_cpu_client", &GetCpuClient, py::arg("asynchronous") = true); m.def("get_nvidia_gpu_client", &GetNvidiaGpuClient, py::arg("asynchronous") = true, - py::arg("allocator_config") = GpuAllocatorConfig()); + py::arg("allocator_config") = GpuAllocatorConfig(), + py::arg("distributed_client") = nullptr, py::arg("node_id") = 0); - py::class_ buffer(m, "PyLocalBuffer"); + py::class_> buffer( + m, "PyLocalBuffer"); buffer .def_static( "from_python", [](const pybind11::object& argument, - std::shared_ptr client, - std::shared_ptr device, - bool force_copy) -> StatusOr> { + std::shared_ptr client, Device* device, + bool force_copy) -> StatusOr> { CHECK(device != 
nullptr); auto iter = client->id_to_device().find(device->id()); if (iter->second != device) { @@ -943,36 +953,57 @@ PYBIND11_MODULE(xla_extension, m) { GlobalPyRefManager()->ManageReference(std::move(c->array)); py::gil_scoped_release gil_release; - return PyLocalBuffer::FromHostBuffer( - c->buf_ptr, c->shape, force_copy, std::move(py_buffer_ref), - std::move(client), std::move(device)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr buffer, + PyLocalBuffer::FromHostBuffer(c->buf_ptr, c->shape, force_copy, + std::move(py_buffer_ref), + client.get(), device)); + return WrapWithClient(std::move(client), std::move(buffer)); }, py::arg("argument"), py::arg("client"), py::arg("device"), py::arg("force_copy") = false) - .def_static("make_tuple", - [](const std::vector buffers, - std::shared_ptr client, - std::shared_ptr device) - -> StatusOr> { - CHECK(device != nullptr); - auto iter = client->id_to_device().find(device->id()); - if (iter->second != device) { - return InvalidArgument( - "Cannot make tuple on device '%s' with '%s' backend", - device->DebugString(), client->platform_name()); - } - return PyLocalBuffer::MakeTuple(buffers, std::move(client), - std::move(device)); - }) + .def_static( + "make_tuple", + [](std::vector buffers, + std::shared_ptr client, + Device* device) -> StatusOr> { + CHECK(device != nullptr); + auto iter = client->id_to_device().find(device->id()); + if (iter->second != device) { + return InvalidArgument( + "Cannot make tuple on device '%s' with '%s' backend", + device->DebugString(), client->platform_name()); + } + TF_ASSIGN_OR_RETURN( + std::unique_ptr buffer, + PyLocalBuffer::MakeTuple(buffers, client.get(), device)); + return WrapWithClient(std::move(client), std::move(buffer)); + }) .def("copy_to_device", - [](PyLocalBuffer* buffer, std::shared_ptr dst_device) { - CHECK(dst_device != nullptr); + [](PyLocalBuffer* buffer, const ClientAndPtr& dst_device) + -> StatusOr> { + CHECK(dst_device.get() != nullptr); GlobalPyRefManager()->CollectGarbage(); py::gil_scoped_release gil_release; - return buffer->CopyToDevice(std::move(dst_device)); + TF_ASSIGN_OR_RETURN(std::unique_ptr out, + buffer->CopyToDevice(dst_device.get())); + return WrapWithClient(dst_device.client, std::move(out)); }) .def("delete", &PyLocalBuffer::Delete) - .def("destructure", &PyLocalBuffer::DestructureTuple) + .def("destructure", + [](const PyLocalBuffer& buffer) + -> StatusOr>> { + TF_ASSIGN_OR_RETURN( + std::vector> parts, + buffer.DestructureTuple()); + std::vector> output; + output.reserve(parts.size()); + for (auto& part : parts) { + output.push_back(WrapWithClient( + buffer.client()->shared_from_this(), std::move(part))); + } + return std::move(output); + }) .def("block_host_until_ready", [](PyLocalBuffer* buffer) { GlobalPyRefManager()->CollectGarbage(); @@ -1004,7 +1035,11 @@ PYBIND11_MODULE(xla_extension, m) { return LiteralToPython(std::move(literal)); }) .def("shape", &PyLocalBuffer::on_host_shape) - .def("device", &PyLocalBuffer::device) + .def("device", + [](const PyLocalBuffer& buffer) { + return WrapWithClient(buffer.client()->shared_from_this(), + buffer.device()); + }) .def("platform", &PyLocalBuffer::platform_name) .def("is_deleted", [](const PyLocalBuffer& buffer) { @@ -1030,24 +1065,160 @@ PYBIND11_MODULE(xla_extension, m) { PyTypeObject* buffer_type = reinterpret_cast(buffer.ptr()); buffer_type->tp_as_buffer = &PyLocalBufferProcs; - py::class_(m, "LocalExecutable") - .def_static("Compile", &PyLocalExecutable::Compile, - py::call_guard()) - .def_static("Compile", 
&PyLocalExecutable::CompileForDevices, - py::call_guard()) + py::class_> + executable(m, "LocalExecutable"); + executable + .def_static("Compile", + [](const XlaComputation& computation, + absl::optional> argument_layouts, + const ExecutableBuildOptions* build_options, + std::shared_ptr client, + absl::optional device_assignment) + -> StatusOr> { + py::gil_scoped_release gil_release; + CompileOptions options; + options.argument_layouts = std::move(argument_layouts); + if (build_options) { + options.executable_build_options = *build_options; + } + if (device_assignment) { + options.executable_build_options.set_device_assignment( + *device_assignment); + } + TF_ASSIGN_OR_RETURN( + std::unique_ptr executable, + PyLocalExecutable::Compile(computation, client.get(), + std::move(options))); + return WrapWithClient(std::move(client), + std::move(executable)); + }) + .def_static("Compile", + [](const XlaComputation& computation, + absl::optional> argument_layouts, + const ExecutableBuildOptions* build_options, + std::shared_ptr client, + absl::optional>> + device_assignment) + -> StatusOr> { + py::gil_scoped_release gil_release; + CompileOptions options; + options.argument_layouts = std::move(argument_layouts); + if (build_options) { + options.executable_build_options = *build_options; + } + if (device_assignment) { + TF_ASSIGN_OR_RETURN( + DeviceAssignment xla_assignment, + DevicesToDeviceAssignment(*device_assignment)); + options.executable_build_options.set_device_assignment( + xla_assignment); + } + TF_ASSIGN_OR_RETURN( + std::unique_ptr executable, + PyLocalExecutable::Compile(computation, client.get(), + std::move(options))); + return WrapWithClient(std::move(client), + std::move(executable)); + }) .def("local_logical_device_ids", &PyLocalExecutable::local_logical_device_ids) - .def("local_devices", &PyLocalExecutable::local_devices) + .def("local_devices", + [](const PyLocalExecutable& executable) { + std::vector> devices; + devices.reserve(executable.local_devices().size()); + for (Device* device : executable.local_devices()) { + devices.push_back(WrapWithClient( + executable.client()->shared_from_this(), device)); + } + return devices; + }) .def("SizeOfGeneratedCodeInBytes", &PyLocalExecutable::SizeOfGeneratedCodeInBytes) .def("Delete", &PyLocalExecutable::Delete) - .def("Execute", &PyLocalExecutable::Execute, - py::call_guard(), py::arg("arguments")) - // TODO(phawkins): remove when all callers switch to ExecuteOnLocalDevices - .def("ExecutePerReplica", &PyLocalExecutable::ExecutePerReplica, - py::call_guard(), py::arg("arguments")) - .def("ExecuteOnLocalDevices", &PyLocalExecutable::ExecuteOnLocalDevices, - py::call_guard(), py::arg("arguments")) + .def( + "Execute", + [](const PyLocalExecutable& executable, + absl::Span args) + -> StatusOr> { + py::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN( + std::vector> output, + executable.Execute(args, ExecuteOptions())); + return WrapWithClient(executable.client()->shared_from_this(), + std::move(output.front())); + }, + py::arg("arguments")) + // TODO(phawkins): remove in favor of overload that returns a vector. 
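The bindings in the hunks above stop handing raw Device* / PyLocalBuffer* / PyLocalExecutable* pointers to Python and instead return them paired with a shared_ptr to the owning PyLocalClient via the WrapWithClient / ClientAndPtr helpers (their definitions are not shown in this excerpt), so any handle held on the Python side keeps the client alive. A minimal sketch of the idea, with hypothetical names that are not part of this change:

#include <memory>
#include <utility>

// Pairs a non-owning pointer with a shared_ptr to its owner so that whoever
// holds the pair also keeps the owner alive (illustration only; the real
// ClientAndPtr type in this change may differ in layout and API).
template <typename Owner, typename T>
struct OwnedPtrSketch {
  std::shared_ptr<Owner> owner;
  T* ptr = nullptr;
  T* get() const { return ptr; }
};

template <typename Owner, typename T>
OwnedPtrSketch<Owner, T> WrapWithOwner(std::shared_ptr<Owner> owner, T* ptr) {
  return OwnedPtrSketch<Owner, T>{std::move(owner), ptr};
}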
+ .def( + "Execute", + [](const PyLocalExecutable& executable, + absl::Span args, bool tuple_arguments) + -> StatusOr>> { + py::gil_scoped_release gil_release; + ExecuteOptions options; + options.tuple_arguments = tuple_arguments; + options.untuple_result = true; + TF_ASSIGN_OR_RETURN( + std::vector> output_buffers, + executable.Execute(args, options)); + std::vector> outputs; + outputs.reserve(output_buffers.size()); + for (auto& buffer : output_buffers) { + outputs.push_back(WrapWithClient( + executable.client()->shared_from_this(), std::move(buffer))); + } + return outputs; + }, + py::arg("arguments"), py::arg("tuple_arguments")) + // TODO(phawkins): remove in favor of overload that returns a vector. + .def( + "ExecuteOnLocalDevices", + [](const PyLocalExecutable& executable, + absl::Span> args) + -> StatusOr>> { + py::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN( + std::vector>> + output_buffers, + executable.ExecuteOnLocalDevices(args, ExecuteOptions())); + std::vector> outputs; + outputs.reserve(output_buffers.size()); + for (auto& buffers : output_buffers) { + outputs.push_back( + WrapWithClient(executable.client()->shared_from_this(), + std::move(buffers.front()))); + } + return outputs; + }, + py::arg("arguments")) + .def( + "ExecuteOnLocalDevices", + [](const PyLocalExecutable& executable, + absl::Span> args, + bool tuple_arguments) + -> StatusOr< + std::vector>>> { + py::gil_scoped_release gil_release; + ExecuteOptions options; + options.tuple_arguments = tuple_arguments; + options.untuple_result = true; + TF_ASSIGN_OR_RETURN( + std::vector>> + output_buffers, + executable.ExecuteOnLocalDevices(args, options)); + std::vector>> outputs; + outputs.resize(output_buffers.size()); + for (int computation = 0; computation < output_buffers.size(); + ++computation) { + for (auto& buffer : output_buffers[computation]) { + outputs[computation].push_back( + WrapWithClient(executable.client()->shared_from_this(), + std::move(buffer))); + } + } + return outputs; + }, + py::arg("arguments"), py::arg("tuple_arguments")) .def( "get_hlo_modules", [](const PyLocalExecutable& executable) @@ -1208,7 +1379,14 @@ PYBIND11_MODULE(xla_extension, m) { .def("ClearSharding", &XlaBuilder::ClearSharding); m.def("BufferToDLPackManagedTensor", BufferToDLPackManagedTensor); - m.def("DLPackManagedTensorToBuffer", DLPackManagedTensorToBuffer); + m.def("DLPackManagedTensorToBuffer", + [](const py::capsule& tensor, std::shared_ptr client) + -> StatusOr> { + TF_ASSIGN_OR_RETURN( + std::unique_ptr buffer, + DLPackManagedTensorToBuffer(tensor, client.get())); + return WrapWithClient(std::move(client), std::move(buffer)); + }); py::enum_( m, "TriangularSolveOptions_Transpose") @@ -1247,6 +1425,16 @@ PYBIND11_MODULE(xla_extension, m) { BuildOpsSubmodule(&m); BuildProfilerSubmodule(&m); + + py::class_> + distributed_runtime_service(m, "DistributedRuntimeService"); + py::class_> + distributed_runtime_client(m, "DistributedRuntimeClient"); + + m.def("get_distributed_runtime_service", &GetDistributedRuntimeService); + m.def("get_distributed_runtime_client", &GetDistributedRuntimeClient); } // NOLINT(readability/fn_size) } // namespace xla diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 9d53f9bd082..f1f31a5eb89 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -177,7 +177,7 @@ def _cpu_backend_factory(): return LocalBackend(platform='cpu', client=client) -def _gpu_backend_factory(): +def 
_gpu_backend_factory(distributed_client=None, node_id=0): """Returns a GPU backend. BFC allocator is used by default.""" allocator = os.getenv('XLA_PYTHON_CLIENT_ALLOCATOR', 'default').lower() memory_fraction = os.getenv('XLA_PYTHON_CLIENT_MEM_FRACTION') @@ -197,8 +197,11 @@ def _gpu_backend_factory(): config.memory_fraction = float(memory_fraction) config.preallocate = preallocate not in ('0', 'false', 'False') - client = _xla.get_nvidia_gpu_client(asynchronous=True, - allocator_config=config) + client = _xla.get_nvidia_gpu_client( + asynchronous=True, + allocator_config=config, + distributed_client=distributed_client, + node_id=node_id) return LocalBackend(platform='gpu', client=client) @@ -604,17 +607,17 @@ class Computation(object): # def SizeOfGeneratedCodeInBytes(self) -> int: # """Return generated binary size, or -1 if not known.""" # -# def ExecutePerReplica(self, arguments: [[Buffer]]) -> [Buffer]: +# def ExecuteOnLocalDevices(self, arguments: [[Buffer]]) -> [Buffer]: # """Execute on many replicas with Buffer arguments and return value. # # Args: # arguments: A sequence of sequences of Buffers. The i'th inner sequence -# comprises the arguments for execution on the i'th replica. +# comprises the arguments for execution on the i'th local device. # # Returns: -# A list of the computation's outputs for each replica, as a Buffer. If -# a shallow sequence of arguments was passed in for `arguments`, then the -# sole, zero'th replica's output is returned instead, as a Buffer. +# A list of the computation's outputs for each local device, as a Buffer. +# If a shallow sequence of arguments was passed in for `arguments`, then +# the sole, zero'th device's output is returned instead, as a Buffer. # """ # # There are different implementations of Executable for different backends. @@ -658,7 +661,7 @@ def execute_with_python_values_replicated(executable, arguments, backend=None): for replica_args in arguments: arg_buffers.append(flat_arg_buffers[:len(replica_args)]) flat_arg_buffers = flat_arg_buffers[len(replica_args):] - return [out.to_py() for out in executable.ExecutePerReplica(arg_buffers)] + return [out.to_py() for out in executable.ExecuteOnLocalDevices(arg_buffers)] class PaddingType(enum.Enum): diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 502f0fa7927..98851fddd2d 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -3133,6 +3133,7 @@ cc_library( hdrs = ["hlo_dce.h"], deps = [ ":hlo", + ":hlo_casting_utils", ":hlo_pass", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index fd373671b97..1f36d906e73 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -3647,7 +3647,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { // A reshape that collapses multiple dimensions into a dimension being // reduced can just reduce all of those dimensions instead of doing a // collapsing reshape before a reduction. 
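The algebraic_simplifier hunk just below gates the reduce-of-reshape rewrite behind a new AlgebraicSimplifierOptions knob, declared further down in algebraic_simplifier.h. A usage sketch under the usual pass-pipeline setup (the helper and pipeline names are illustrative, not from this change):

#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"

void AddSimplifierWithoutReduceOfReshape(xla::HloPassPipeline* pipeline) {
  xla::AlgebraicSimplifierOptions options;
  // Keep reshape-feeding-reduce patterns as written instead of folding the
  // collapsed dimensions into the reduction.
  options.set_enable_reduce_of_reshape(false);
  pipeline->AddPass<xla::AlgebraicSimplifier>(options);
}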
- if (arg->opcode() == HloOpcode::kReshape) { + if (options_.enable_reduce_of_reshape() && + arg->opcode() == HloOpcode::kReshape) { std::vector> unmodified_dims = ShapeUtil::DimensionsUnmodifiedByReshape(arg->operand(0)->shape(), arg->shape()); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index ce364a16134..4251e7eb846 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -107,6 +107,12 @@ class AlgebraicSimplifierOptions { return metadata_.cudnn_batchnorm_forward_training_metadata; } + void set_enable_reduce_of_reshape(bool enable_reduce_of_reshape) { + enable_reduce_of_reshape_ = enable_reduce_of_reshape; + } + + bool enable_reduce_of_reshape() const { return enable_reduce_of_reshape_; } + private: // Metadata struct can be used to store any metadata information encapsulated // with the AlgebraicSimplierOptions that can be later used in an @@ -126,6 +132,7 @@ class AlgebraicSimplifierOptions { bool enable_dot_to_multiply_rewrite_{true}; bool enable_conv_simplification_{true}; bool enable_window_reduce_to_reduce_replacement_{true}; + bool enable_reduce_of_reshape_{true}; int64 very_small_gather_size_{4}; Metadata metadata_; }; diff --git a/tensorflow/compiler/xla/service/allocation_tracker.h b/tensorflow/compiler/xla/service/allocation_tracker.h index 6e7f9fdfc13..06b55e24b69 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.h +++ b/tensorflow/compiler/xla/service/allocation_tracker.h @@ -87,7 +87,7 @@ class AllocationTracker { // Internal helper which resolves the given GlobalDataHandle to a // list of ScopedShapedBuffers. StatusOr> ResolveInternal( - const GlobalDataHandle& data) const EXCLUSIVE_LOCKS_REQUIRED(mutex_); + const GlobalDataHandle& data) const TF_EXCLUSIVE_LOCKS_REQUIRED(mutex_); // Internal helper which registers a vector of shaped buffers, one per // replica. ShapedBufferTy is either ScopedShapedBuffer or ShapedBuffer. If @@ -96,18 +96,19 @@ class AllocationTracker { template StatusOr RegisterInternal( std::vector replicated_buffers, const string& tag) - EXCLUSIVE_LOCKS_REQUIRED(mutex_); + TF_EXCLUSIVE_LOCKS_REQUIRED(mutex_); // Adds the given device address to the allocation tracker, or if it already // exists, then increment its reference count. void AddAllocationOrIncrementRefCount(se::DeviceMemoryBase device_memory, int device_ordinal) - EXCLUSIVE_LOCKS_REQUIRED(mutex_); + TF_EXCLUSIVE_LOCKS_REQUIRED(mutex_); // Decrements the reference count of the given device memory. Then, if it is // zero, deallocate the memory. Status DecrementRefCount(se::DeviceMemoryBase device_memory, - int device_ordinal) EXCLUSIVE_LOCKS_REQUIRED(mutex_); + int device_ordinal) + TF_EXCLUSIVE_LOCKS_REQUIRED(mutex_); // A map from device memory opaque value to allocation. One such map is // maintained per device ordinal. @@ -121,11 +122,11 @@ class AllocationTracker { // The next handle to assign to an allocation, guarded by the same mutex as // the mapping as they'll be mutated at the same time. - int64 next_handle_ GUARDED_BY(mutex_); + int64 next_handle_ TF_GUARDED_BY(mutex_); // A map from device ordinal to AllocationMap. absl::flat_hash_map opaque_to_allocation_map_ - GUARDED_BY(mutex_); + TF_GUARDED_BY(mutex_); // A map from data handle to a vector of shaped buffers that represent the // buffers for different replicas. 
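Most of the header changes in this part of the patch are a mechanical migration from the bare GUARDED_BY / EXCLUSIVE_LOCKS_REQUIRED macros to the TF_-prefixed equivalents from tensorflow/core/platform/thread_annotations.h. A self-contained illustration of the annotated style (hypothetical class, not part of this change):

#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"

class CounterSketch {
 public:
  void Increment() {
    tensorflow::mutex_lock lock(mu_);
    ++value_;
  }

 private:
  tensorflow::mutex mu_;
  // Clang's thread-safety analysis warns if value_ is touched without mu_.
  int value_ TF_GUARDED_BY(mu_) = 0;
};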
@@ -145,7 +146,7 @@ class AllocationTracker { // free'd when both the view *and* the original tuple are Unregistered. This // refcounting is managed in opaque_to_allocation_map_. absl::flat_hash_map>> - handle_to_shaped_buffers_ GUARDED_BY(mutex_); + handle_to_shaped_buffers_ TF_GUARDED_BY(mutex_); TF_DISALLOW_COPY_AND_ASSIGN(AllocationTracker); }; diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h index 79fdeb2b0bc..2e2284a3e23 100644 --- a/tensorflow/compiler/xla/service/backend.h +++ b/tensorflow/compiler/xla/service/backend.h @@ -176,7 +176,7 @@ class Backend { // Mapping from stream executor to stream pools, used by `BorrowStream` above. absl::flat_hash_map> - stream_pools_ GUARDED_BY(mu_); + stream_pools_ TF_GUARDED_BY(mu_); // The default memory allocator to use. std::unique_ptr memory_allocator_; diff --git a/tensorflow/compiler/xla/service/channel_tracker.h b/tensorflow/compiler/xla/service/channel_tracker.h index 89e17eba36f..a02cda91a69 100644 --- a/tensorflow/compiler/xla/service/channel_tracker.h +++ b/tensorflow/compiler/xla/service/channel_tracker.h @@ -68,24 +68,24 @@ class ChannelTracker { // Bumps the next_channel_ number and returns the allocated number // wrapped in a ChannelHandle. ChannelHandle AllocateHandle(ChannelHandle::ChannelType type) - EXCLUSIVE_LOCKS_REQUIRED(channel_mutex_); + TF_EXCLUSIVE_LOCKS_REQUIRED(channel_mutex_); Status RegisterSendInternal(const ChannelHandle& handle) - EXCLUSIVE_LOCKS_REQUIRED(channel_mutex_); + TF_EXCLUSIVE_LOCKS_REQUIRED(channel_mutex_); Status RegisterRecvInternal(const ChannelHandle& handle) - EXCLUSIVE_LOCKS_REQUIRED(channel_mutex_); + TF_EXCLUSIVE_LOCKS_REQUIRED(channel_mutex_); // Guards the channel mapping. tensorflow::mutex channel_mutex_; // The next sequence number to assign to a channel. - int64 next_channel_ GUARDED_BY(channel_mutex_); + int64 next_channel_ TF_GUARDED_BY(channel_mutex_); // Mapping from ChannelHandle value to the corresponding registered // Channel object. absl::flat_hash_map opaque_to_channel_ - GUARDED_BY(channel_mutex_); + TF_GUARDED_BY(channel_mutex_); TF_DISALLOW_COPY_AND_ASSIGN(ChannelTracker); }; diff --git a/tensorflow/compiler/xla/service/collective_ops_utils.h b/tensorflow/compiler/xla/service/collective_ops_utils.h index d9b6c48685b..6af05d925aa 100644 --- a/tensorflow/compiler/xla/service/collective_ops_utils.h +++ b/tensorflow/compiler/xla/service/collective_ops_utils.h @@ -241,9 +241,9 @@ class Rendezvous { tensorflow::mutex mu_; - bool initialized_ GUARDED_BY(mu_) = false; + bool initialized_ TF_GUARDED_BY(mu_) = false; - std::vector participants_ GUARDED_BY(mu_); + std::vector participants_ TF_GUARDED_BY(mu_); private: // Runs the all-reduce on the given thread. 
If successful, returns diff --git a/tensorflow/compiler/xla/service/compilation_cache.h b/tensorflow/compiler/xla/service/compilation_cache.h index 5f94def509d..22c1af8fdab 100644 --- a/tensorflow/compiler/xla/service/compilation_cache.h +++ b/tensorflow/compiler/xla/service/compilation_cache.h @@ -51,7 +51,7 @@ class CompilationCache { using CacheKey = int64; absl::flat_hash_map> cache_ - GUARDED_BY(mutex_); + TF_GUARDED_BY(mutex_); private: TF_DISALLOW_COPY_AND_ASSIGN(CompilationCache); diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index c07c3eb3c3b..6bfd8c4db46 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -975,8 +975,9 @@ Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) { TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, HloAliasAnalysis::Run(module, can_share_buffer_)); - for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { + for (HloComputation* computation : module->MakeComputationPostOrder()) { + for (HloInstruction* instruction : + computation->MakeInstructionPostOrder()) { if (instruction->opcode() == HloOpcode::kWhile) { TF_RETURN_IF_ERROR(AddCopiesForWhile(*alias_analysis, instruction)); } else if (instruction->opcode() == HloOpcode::kConditional) { diff --git a/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.cc b/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.cc index e624e5cc7eb..244d7d4c539 100644 --- a/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.cc +++ b/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.cc @@ -23,7 +23,7 @@ namespace orc_jit_memory_mapper { static tensorflow::mutex mapper_instance_mutex(tensorflow::LINKER_INITIALIZED); static llvm::SectionMemoryManager::MemoryMapper* mapper_instance - GUARDED_BY(mapper_instance_mutex) = nullptr; + TF_GUARDED_BY(mapper_instance_mutex) = nullptr; llvm::SectionMemoryManager::MemoryMapper* GetInstance() { tensorflow::mutex_lock lock(mapper_instance_mutex); diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h index 2f8be8c111b..cbbc4d7bf34 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.h +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h @@ -273,7 +273,8 @@ class VectorSupportLibrary { llvm::Value* GetConstantFloat(llvm::Type* type, const llvm::APFloat& f) { llvm::Constant* scalar_value = llvm::ConstantFP::get(type->getContext(), f); if (llvm::isa(type)) { - return llvm::ConstantVector::getSplat(vector_size(), scalar_value); + return llvm::ConstantVector::getSplat( + llvm::ElementCount(vector_size(), /*Scalable=*/false), scalar_value); } return scalar_value; } diff --git a/tensorflow/compiler/xla/service/dump.cc b/tensorflow/compiler/xla/service/dump.cc index 3cb0eb78c5b..ca6fadc2e23 100644 --- a/tensorflow/compiler/xla/service/dump.cc +++ b/tensorflow/compiler/xla/service/dump.cc @@ -274,7 +274,7 @@ static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); // dies. But we only add an entry if dumping is enabled for this module, and // dumping a module leaks buffer space in stdout or bytes on disk *way* faster // than this hashtable leaks memory. 
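The dump.cc hunk here keeps its per-module counters in function-local, never-destructed hash maps guarded by a linker-initialized mutex; the deliberate `*new` leak is the usual way to sidestep static destruction-order problems, and the surrounding comment explains why the leak is acceptable. The same idiom in isolation (hypothetical counter, not from this change):

#include <cstdint>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"

static tensorflow::mutex counter_mu(tensorflow::LINKER_INITIALIZED);

// Returns how many times NextCount() has previously been called for `id`.
static int64_t NextCount(int64_t id) {
  // Intentionally leaked so the map stays valid for the life of the process.
  static auto& counts TF_GUARDED_BY(counter_mu) =
      *new absl::flat_hash_map<int64_t, int64_t>();
  tensorflow::mutex_lock lock(counter_mu);
  return counts[id]++;
}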
-static auto& module_id_to_step_number GUARDED_BY(mu) = +static auto& module_id_to_step_number TF_GUARDED_BY(mu) = *new absl::flat_hash_map(); // Maps a module's unique ID to a timestamp indicating when we've first dumped @@ -285,7 +285,7 @@ static auto& module_id_to_step_number GUARDED_BY(mu) = // dies. But we only add an entry if dumping is enabled for this module, and // dumping a module leaks buffer space in stdout or bytes on disk *way* faster // than this hashtable leaks memory. -static auto& module_id_to_timestamp GUARDED_BY(mu) = +static auto& module_id_to_timestamp TF_GUARDED_BY(mu) = *new absl::flat_hash_map(); int64 StepNumberForModule(const HloModule& module) { @@ -432,7 +432,7 @@ void DumpHloSnapshotIfEnabled(const HloModule& module, int64 execution_count; uint64 timestamp; { - static auto& module_id_to_execution_count GUARDED_BY(mu) = + static auto& module_id_to_execution_count TF_GUARDED_BY(mu) = *new absl::flat_hash_map(); tensorflow::mutex_lock lock(mu); execution_count = module_id_to_execution_count[module.unique_id()]++; @@ -469,7 +469,7 @@ void DumpHloSnapshotIfEnabled(const HloSnapshot& snapshot, // have to use its name. int64 execution_count; { - static auto& module_name_to_execution_count GUARDED_BY(mu) = + static auto& module_name_to_execution_count TF_GUARDED_BY(mu) = *new absl::flat_hash_map(); tensorflow::mutex_lock lock(mu); execution_count = module_name_to_execution_count[name]++; diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc index dfcf50a3108..34d144ea1e9 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc @@ -584,10 +584,13 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) { HloInstruction* operand_dynamic_size, DimensionConstraint constraint) -> Status { HloInstruction* reshape = hlo; - TF_RET_CHECK(reshape->shape().rank() > 0) - << "Reshaping a dynamic dimension into a scalar, which has " - "undefined behavior. The offending instruction is: " - << reshape->ToString(); + if (reshape->shape().rank() == 0) { + VLOG(0) << "Reshaping a dynamic dimension into a scalar, which has " + "undefined behavior when input size is 0. The offending " + "instruction is: " + << reshape->ToString(); + return Status::OK(); + } auto common_factors = CommonFactors(operand->shape().dimensions(), reshape->shape().dimensions()); int64 input_dim_start = -1; diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc index 3a69a684b86..d2913f9d2a1 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc @@ -560,6 +560,30 @@ TEST_F(DynamicDimensionInferenceTest, ReshapeTestMajorDimension) { EXPECT_NE(inference_->GetDynamicSize(reshape, {}, 0), nullptr); } +TEST_F(DynamicDimensionInferenceTest, ReshapeIntoScalar) { + // Test the ability to a reshape into scalar. 
+ auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1}); + auto output_shape = ShapeUtil::MakeShape(F32, {}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, input_shape, "A")); + + builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + + builder.AddInstruction(HloInstruction::CreateReshape(output_shape, a_param)); + + module_->AddEntryComputation(builder.Build()); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + SCOPED_TRACE(module_->ToString()); + TF_CHECK_OK(RunInference()); +} + TEST_F(DynamicDimensionInferenceTest, GatherTest) { const string hlo_text = R"( HloModule TensorFlowGatherV2 diff --git a/tensorflow/compiler/xla/service/execution_tracker.h b/tensorflow/compiler/xla/service/execution_tracker.h index 4e9b9f883e2..8819b9da922 100644 --- a/tensorflow/compiler/xla/service/execution_tracker.h +++ b/tensorflow/compiler/xla/service/execution_tracker.h @@ -86,12 +86,12 @@ class ExecutionTracker { private: // The next handle to assign to an execution. - int64 next_handle_ GUARDED_BY(execution_mutex_); + int64 next_handle_ TF_GUARDED_BY(execution_mutex_); // Mapping from ExecutionHandle handle to the corresponding registered // AsyncExecution object. std::map> handle_to_execution_ - GUARDED_BY(execution_mutex_); + TF_GUARDED_BY(execution_mutex_); tensorflow::mutex execution_mutex_; // Guards the execution mapping. diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index ba5d7e9d788..4a903548c22 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -156,6 +156,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_reachability", + "//tensorflow/core/platform:random", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", diff --git a/tensorflow/compiler/xla/service/gpu/cholesky_thunk.h b/tensorflow/compiler/xla/service/gpu/cholesky_thunk.h index 8ef5a46b3e3..50ecca51588 100644 --- a/tensorflow/compiler/xla/service/gpu/cholesky_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/cholesky_thunk.h @@ -16,7 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CHOLESKY_THUNK_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CHOLESKY_THUNK_H_ -#include "absl/base/thread_annotations.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" @@ -29,6 +28,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/stream_executor/blas.h" namespace xla { @@ -66,7 +66,8 @@ class CholeskyThunk : public Thunk { const int64 n_; tensorflow::mutex mu_; - absl::flat_hash_map contexts_ GUARDED_BY(mu_); + absl::flat_hash_map contexts_ + TF_GUARDED_BY(mu_); }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc index 2a071cd658d..a5001d5168d 100644 --- a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc @@ -127,12 +127,12 @@ class Rendezvous { std::make_shared(key_.num_participants)}; tensorflow::mutex mu_; - bool initialized_ GUARDED_BY(mu_) = false; + bool initialized_ TF_GUARDED_BY(mu_) = false; // We use an std::map so that we can iterate over it below in a guaranteed // order. The order shouldn't actually matter, but why be nondeterministic if // we don't have to be? - std::map participants_ GUARDED_BY(mu_); + std::map participants_ TF_GUARDED_BY(mu_); }; void EnqueueCopy(se::DeviceMemoryBase src, se::Stream* src_stream, diff --git a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc index de67b115ff7..8316cb7d12d 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc @@ -45,11 +45,11 @@ using GemmCacheKey = std::tuple; static tensorflow::mutex autotune_cache_mu(tensorflow::LINKER_INITIALIZED); -static auto& autotune_cache GUARDED_BY(autotune_cache_mu) = +static auto& autotune_cache TF_GUARDED_BY(autotune_cache_mu) = *new absl::flat_hash_map>(); -static int64 cache_hits GUARDED_BY(autotune_cache_mu) = 0; -static int64 cache_misses GUARDED_BY(autotune_cache_mu) = 0; +static int64 cache_hits TF_GUARDED_BY(autotune_cache_mu) = 0; +static int64 cache_misses TF_GUARDED_BY(autotune_cache_mu) = 0; // Experimentally tries to pick the best algorithm for the given gemm. // @@ -58,15 +58,48 @@ static int64 cache_misses GUARDED_BY(autotune_cache_mu) = 0; // than sm_50 -- in both cases, cublas doesn't support gemm-with-algorithm at // all. 
static StatusOr> DoUncachedGemmAutotune( - const HloInstruction* gemm, se::DeviceMemoryBase lhs_buffer, - se::DeviceMemoryBase rhs_buffer, se::DeviceMemoryBase output_buffer, - se::DeviceMemoryBase reference_result_buffer, se::Stream* stream, - const se::RedzoneAllocator& allocator, const BufferComparator& comparator, - bool crash_on_checking_failure) { + const HloInstruction* gemm, se::Stream* stream, + se::DeviceMemoryAllocator* allocator) { if (!stream->parent()->SynchronizeAllActivity()) { return InternalError("Failed to synchronize GPU for autotuning."); } + const HloModuleConfig& hlo_module_config = gemm->GetModule()->config(); + const bool init_cublas_data = + hlo_module_config.debug_options().xla_gpu_autotune_level() > 1; + se::RedzoneAllocator input_output_allocator( + stream, allocator, PtxOptsFromConfig(hlo_module_config), + /*memory_limit=*/std::numeric_limits::max()); + + BufferComparator comparator(gemm->shape(), hlo_module_config); + + int64 rng_state = 0; + auto get_initialized_buffer = + [&](const HloInstruction* op) -> StatusOr { + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buffer, + input_output_allocator.AllocateBytes( + ShapeUtil::ByteSizeOf(op->shape()))); + if (init_cublas_data) { + InitializeBuffer(stream, op->shape().element_type(), &rng_state, buffer); + } + return buffer; + }; + + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase lhs_buffer, + get_initialized_buffer(gemm->operand(0))); + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase rhs_buffer, + get_initialized_buffer(gemm->operand(1))); + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase output_buffer, + get_initialized_buffer(gemm)); + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase reference_result_buffer, + get_initialized_buffer(gemm)); + + const DebugOptions& debug_options = + gemm->GetModule()->config().debug_options(); + + const bool crash_on_checking_failure = + debug_options.xla_gpu_crash_on_verification_failures(); + GemmBackendConfig backend_config = gemm->backend_config().ValueOrDie(); const int32 cublas_autotune_level = @@ -124,7 +157,7 @@ static StatusOr> DoUncachedGemmAutotune( TF_ASSIGN_OR_RETURN( se::RedzoneAllocator::RedzoneCheckStatus rz_check_status, - allocator.CheckRedzones()); + input_output_allocator.CheckRedzones()); if (!rz_check_status.ok()) { result.mutable_failure()->set_kind(AutotuneResult::REDZONE_MODIFIED); *result.mutable_failure()->mutable_msg() = @@ -194,17 +227,14 @@ static StatusOr> DoUncachedGemmAutotune( } static StatusOr> DoGemmAutotune( - const HloInstruction* instr, const HloInstruction* lhs, - const HloInstruction* rhs, se::DeviceMemoryBase lhs_buffer, - se::DeviceMemoryBase rhs_buffer, se::DeviceMemoryBase output_buffer, - se::DeviceMemoryBase reference_result_buffer, se::Stream* stream, - bool crash_on_checking_failure, const se::RedzoneAllocator& allocator, - const BufferComparator& comparator) { + const HloInstruction* instr, const GemmBackendConfig& gemm_config, + se::DeviceMemoryAllocator* allocator, se::Stream* stream) { + const HloInstruction* lhs = instr->operand(0); + const HloInstruction* rhs = instr->operand(1); + // Don't run autotuning concurrently on the same GPU. 
tensorflow::mutex_lock gpu_lock = LockGpu(stream->parent()); - GemmBackendConfig gemm_config = - instr->backend_config().ValueOrDie(); GemmCacheKey key = std::make_tuple(stream->parent(), lhs->shape(), rhs->shape(), @@ -235,11 +265,8 @@ static StatusOr> DoGemmAutotune( VLOG(2) << "Batch size is non-singular, using generic algorithm"; result = absl::nullopt; } else { - TF_ASSIGN_OR_RETURN( - result, - DoUncachedGemmAutotune(instr, lhs_buffer, rhs_buffer, output_buffer, - reference_result_buffer, stream, allocator, - comparator, crash_on_checking_failure)); + TF_ASSIGN_OR_RETURN(result, + DoUncachedGemmAutotune(instr, stream, allocator)); } CHECK(autotune_cache.emplace(key, result).second); @@ -255,52 +282,11 @@ static StatusOr RunOnInstruction(HloInstruction* instr, TF_ASSIGN_OR_RETURN(se::Stream* const stream, allocator->GetStream(executor->device_ordinal())); - const HloModuleConfig& hlo_module_config = instr->GetModule()->config(); - const bool init_cublas_data = - hlo_module_config.debug_options().xla_gpu_autotune_level() > 1; - se::RedzoneAllocator input_output_allocator( - stream, allocator, PtxOptsFromConfig(hlo_module_config), - /*memory_limit=*/std::numeric_limits::max()); - - BufferComparator comparator(instr->shape(), hlo_module_config); - - int64 rng_state = 0; - auto get_initialized_buffer = - [&](const HloInstruction* op) -> StatusOr { - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buffer, - input_output_allocator.AllocateBytes( - ShapeUtil::ByteSizeOf(op->shape()))); - if (init_cublas_data) { - InitializeBuffer(stream, op->shape().element_type(), &rng_state, buffer); - } - return buffer; - }; - GemmBackendConfig gemm_config = instr->backend_config().ValueOrDie(); - const HloInstruction* lhs = instr->operand(0); - const HloInstruction* rhs = instr->operand(1); - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase lhs_buffer, - get_initialized_buffer(lhs)); - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase rhs_buffer, - get_initialized_buffer(rhs)); - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase output_buffer, - get_initialized_buffer(instr)); - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase reference_result_buffer, - get_initialized_buffer(instr)); - - const DebugOptions& debug_options = - instr->GetModule()->config().debug_options(); - - const bool crash_on_checking_failure = - debug_options.xla_gpu_crash_on_verification_failures(); - - TF_ASSIGN_OR_RETURN( - absl::optional gemm_algorithm, - DoGemmAutotune(instr, lhs, rhs, lhs_buffer, rhs_buffer, output_buffer, - reference_result_buffer, stream, crash_on_checking_failure, - input_output_allocator, comparator)); + TF_ASSIGN_OR_RETURN(absl::optional gemm_algorithm, + DoGemmAutotune(instr, gemm_config, allocator, stream)); // We update instruction->backend_config(); if no algorithms are supported, // a different API is used, which does not require specifying an algorithm. diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 51b30e238e9..3a8c3321e24 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -343,11 +343,14 @@ Status GpuCompiler::PrepareHloModuleForIrEmitting(HloModule* hlo_module) { // TODO(cheshire): Duplication with gpu_conv_algorithm picker, figure out a // right way to share this. 
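The RequireDeterminism() change just below memoizes the TF_DETERMINISTIC_OPS lookup in a function-local static initialized by an immediately invoked lambda, so the environment variable is read once per process instead of on every call. The idiom in isolation (the flag name is a placeholder, not from this change):

#include <cstdlib>

static bool ExampleFlagEnabled() {
  static const bool enabled = [] {
    const char* value = std::getenv("EXAMPLE_FLAG");
    return value != nullptr && value[0] == '1';
  }();
  return enabled;
}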
static bool RequireDeterminism() { - bool deterministic_ops = false; - TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS", - /*default_val=*/false, - &deterministic_ops)); - return deterministic_ops; + static bool require_determinism = [] { + bool deterministic_ops = false; + TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS", + /*default_val=*/false, + &deterministic_ops)); + return deterministic_ops; + }(); + return require_determinism; } Status GpuCompiler::OptimizeHloPostLayoutAssignment( diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc index 31ace1a416e..7a7d2e1e1b1 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc @@ -268,9 +268,9 @@ ConvCacheKey AutotuneCacheKeyfromInstruction( } tensorflow::mutex autotune_cache_lock(tensorflow::LINKER_INITIALIZED); -auto& autotune_cache GUARDED_BY(autotune_cache_lock) = +auto& autotune_cache TF_GUARDED_BY(autotune_cache_lock) = *new absl::flat_hash_map(); -auto& autotune_cache_stats GUARDED_BY(autotune_cache_lock) = +auto& autotune_cache_stats TF_GUARDED_BY(autotune_cache_lock) = *new ConvCacheStats(); } // anonymous namespace diff --git a/tensorflow/compiler/xla/service/gpu/gpu_debug_info_manager.h b/tensorflow/compiler/xla/service/gpu/gpu_debug_info_manager.h index 41825a33174..c1b83158dfb 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_debug_info_manager.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_debug_info_manager.h @@ -126,17 +126,17 @@ class GpuDebugInfoManager { }; tensorflow::mutex mutex_; - bool tracing_active_ GUARDED_BY(mutex_) = false; + bool tracing_active_ TF_GUARDED_BY(mutex_) = false; // Modules that was running currently. Because multiple instances of the // modules can be running in the same time, a reference count is maintained // as map value. absl::flat_hash_map running_module_ids_ - GUARDED_BY(mutex_); + TF_GUARDED_BY(mutex_); // Active modules are those still tracked by us. There could be much more // active modules than running modules, we will try to reduce the trace size // by only transfer those modules that were running during tracing period. absl::flat_hash_map active_modules_ - GUARDED_BY(mutex_); + TF_GUARDED_BY(mutex_); }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 1f601712038..50d27182df1 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -161,6 +161,9 @@ Status GpuExecutable::ExecuteThunks( sub_streams.emplace_back(); TF_ASSIGN_OR_RETURN(sub_streams.back(), run_options->BorrowStream(executor->device_ordinal())); + // Require substreams to wait for the main stream, otherwise substreams may + // execute before the program is scheduled to start on the main stream. + sub_streams.back()->ThenWaitFor(main_stream); } HloExecutionProfiler profiler(do_profile, hlo_execution_profile, main_stream, diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index 3d3afe6168b..33642a7dc3d 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -159,9 +159,9 @@ class GpuExecutable : public Executable { // `ResolveConstantGlobals`. 
tensorflow::mutex module_handle_mutex_; std::map - module_handles_ GUARDED_BY(module_handle_mutex_); + module_handles_ TF_GUARDED_BY(module_handle_mutex_); std::map - module_globals_ GUARDED_BY(module_handle_mutex_); + module_globals_ TF_GUARDED_BY(module_handle_mutex_); TF_DISALLOW_COPY_AND_ASSIGN(GpuExecutable); }; diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h index 50ef3905495..88351881f3a 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h @@ -84,7 +84,7 @@ class KernelThunk : public Thunk { // Loaded kernels for each `StreamExecutor`. Requires pointer stability of // values. std::unordered_map> - kernel_cache_ GUARDED_BY(mutex_); + kernel_cache_ TF_GUARDED_BY(mutex_); }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc index 52c4fb93199..8d568e7f5d4 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc @@ -492,7 +492,7 @@ void RendezvousNcclAllReduce::CleanupImpl(std::shared_ptr handle, // lives, which is how we avoid expensive reinitialization of NCCL cliques. struct NcclAllReduceThunk::AuxData { tensorflow::mutex mu; - absl::flat_hash_set> cliques GUARDED_BY(mu); + absl::flat_hash_set> cliques TF_GUARDED_BY(mu); }; /*static*/ bool NcclAllReduceThunk::CanImplement(const HloInstruction* crs) { diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index 6d036094a69..4f46e292210 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -148,18 +148,23 @@ Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization( Status NVPTXCompiler::OptimizeHloPostLayoutAssignment( HloModule* hlo_module, se::StreamExecutor* stream_exec, se::DeviceMemoryAllocator* device_allocator) { + HloPassPipeline pre_pipeline("nvptx post-layout_assignment part 1"); + // Pad the dimensions of matrices in dot operations to multiples of 8. + // This needs to run before GemmRewriter, which is part of + // OptimizeHloPostLayoutAssignment(). + if (IsVoltaOrLater(*stream_exec)) { + pre_pipeline.AddPass(); + } + TF_RETURN_IF_ERROR(pre_pipeline.Run(hlo_module).status()); + TF_RETURN_IF_ERROR(GpuCompiler::OptimizeHloPostLayoutAssignment( hlo_module, stream_exec, device_allocator)); - HloPassPipeline pipeline("nvptx post-layout_assignment"); - // Pad the dimensions of matrices in dot operations to multiples of 8. - if (IsVoltaOrLater(*stream_exec)) { - pipeline.AddPass(); - } + HloPassPipeline post_pipeline("nvptx post-layout_assignment part 2"); // Find the fastest algorithm for GEMMs. - pipeline.AddPass(stream_exec, device_allocator); - TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + post_pipeline.AddPass(stream_exec, device_allocator); + TF_RETURN_IF_ERROR(post_pipeline.Run(hlo_module).status()); return Status::OK(); } @@ -228,13 +233,14 @@ bool MaybeLoadPtxFromFile(const HloModule* module, std::string* ptx) { // and warn when a file is not used to ease catching typo in filename. 
std::string prefix = xla::FilenameFor(*module, "", *ptx); std::string matched_filename; - for (const string filename : + for (const string full_filename : module->config().debug_options().xla_gpu_ptx_file()) { // To ease comparing many PTX versions, accept different suffixes then // the original filename. + auto filename = tensorflow::io::Basename(full_filename); if (absl::StartsWith(filename, prefix)) { - matched_filename = filename; - VLOG(0) << "RunBackend() - Will load PTX from file: " << filename; + matched_filename = full_filename; + VLOG(0) << "RunBackend() - Will load PTX from file: " << full_filename; break; } } diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h index 3098d5af25f..e69be947522 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h @@ -62,8 +62,8 @@ class NVPTXCompiler : public GpuCompiler { // We cache the cuda_data_dir() and the result of our search, so that if the // next module we have to compile has the same cuda_data_dir(), we can skip // the search. - string cached_cuda_data_dir_ GUARDED_BY(mutex_); - string cached_libdevice_dir_ GUARDED_BY(mutex_); + string cached_cuda_data_dir_ TF_GUARDED_BY(mutex_); + string cached_libdevice_dir_ TF_GUARDED_BY(mutex_); // Tries to compile the given ptx string to cubin. Returns a vector with the // compiled cubin. If compilation was unsuccessful, returns an empty vector. @@ -116,7 +116,7 @@ class NVPTXCompiler : public GpuCompiler { // is critical here. absl::node_hash_map - compilation_cache_ GUARDED_BY(mutex_); + compilation_cache_ TF_GUARDED_BY(mutex_); TF_DISALLOW_COPY_AND_ASSIGN(NVPTXCompiler); }; diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc index a92a5783b67..d9a5463013d 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" +#include "tensorflow/core/platform/random.h" namespace xla { namespace gpu { @@ -72,13 +73,17 @@ int ComputeStreamToAssign( return kInvalidStreamNum; } - if (hlo.GetModule() - ->config() - .debug_options() - .xla_gpu_disable_multi_streaming()) { + const auto& debug_options = hlo.GetModule()->config().debug_options(); + if (debug_options.xla_gpu_disable_multi_streaming()) { return 0; } + if (debug_options.xla_gpu_use_random_streams()) { + // Debug feature: make random stream assignments to try to uncover + // concurrency bugs. + return tensorflow::random::New64() % 100; + } + if (!(IsCublasGemm(hlo) || IsMatrixMultiplication(hlo))) { // If `hlo` is not implemented as a GEMM, keep it close to its operands to // avoid excessive synchronization. diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment.h b/tensorflow/compiler/xla/service/gpu/stream_assignment.h index 52d38b6f20e..1bcbec06921 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment.h @@ -30,7 +30,7 @@ class StreamAssignment { int StreamNumberForHlo(const HloInstruction& hlo) const; bool HasStreamAssigned(const HloInstruction& hlo) const; // `hlo` needs to outlive this StreamAssignment object. 
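The stream_assignment.cc hunk above adds a debugging knob, xla_gpu_use_random_streams, that assigns HLOs to pseudo-random streams to help surface concurrency bugs. A sketch of enabling it through a module's debug options, assuming the proto-generated setter name (the helper function is illustrative):

#include "tensorflow/compiler/xla/service/hlo_module_config.h"

void EnableRandomStreams(xla::HloModuleConfig* config) {
  xla::DebugOptions debug_options = config->debug_options();
  debug_options.set_xla_gpu_use_random_streams(true);  // assumed proto setter
  config->set_debug_options(debug_options);
}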
- void AssignStreamToHlo(const HloInstruction* hlo, int stream_no); + void AssignStreamToHlo(const HloInstruction* hlo, int stream_num); private: int stream_count_ = 1; // At least the main stream. diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc index 23b27a5f67d..a573b621c88 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_dce.cc @@ -21,8 +21,10 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/status.h" @@ -35,7 +37,8 @@ limitations under the License. namespace xla { -StatusOr HloDCE::RunOnComputation(HloComputation* computation) { +StatusOr HloDCE::RunOnComputation( + HloComputation* computation, bool remove_cross_partition_collective_ops) { bool changed = false; VLOG(3) << "Before dce:"; XLA_VLOG_LINES(3, computation->ToString()); @@ -47,7 +50,12 @@ StatusOr HloDCE::RunOnComputation(HloComputation* computation) { if (instruction != computation->root_instruction() && instruction->user_count() == 0 && computation->IsSafelyRemovable(instruction) && - !instruction->HasSideEffect()) { + (!instruction->HasSideEffect() || + (remove_cross_partition_collective_ops && + ((instruction->opcode() == HloOpcode::kAllReduce && + !Cast(instruction)->constrain_layout()) || + instruction->opcode() == HloOpcode::kCollectivePermute || + instruction->opcode() == HloOpcode::kAllToAll)))) { dead_roots.push_back(instruction); } } @@ -74,8 +82,9 @@ StatusOr HloDCE::Run(HloModule* module) { // Run DCE on each computation. for (auto* computation : module->MakeComputationPostOrder()) { - TF_ASSIGN_OR_RETURN(bool changed_for_computation, - RunOnComputation(computation)); + TF_ASSIGN_OR_RETURN( + bool changed_for_computation, + RunOnComputation(computation, remove_cross_partition_collective_ops_)); changed |= changed_for_computation; } diff --git a/tensorflow/compiler/xla/service/hlo_dce.h b/tensorflow/compiler/xla/service/hlo_dce.h index f22f98868ab..49bb2e3f139 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.h +++ b/tensorflow/compiler/xla/service/hlo_dce.h @@ -35,15 +35,23 @@ namespace xla { // instructions cannot be deleted. class HloDCE : public HloModulePass { public: + HloDCE() : remove_cross_partition_collective_ops_(false) {} + explicit HloDCE(bool remove_cross_partition_collective_ops) + : remove_cross_partition_collective_ops_( + remove_cross_partition_collective_ops) {} ~HloDCE() override {} absl::string_view name() const override { return "dce"; } // Run DCE on a computation. - static StatusOr RunOnComputation(HloComputation* computation); + StatusOr RunOnComputation(HloComputation* computation, + bool remove_cross_partition_collective_ops); // Run the pass on the given module. Returns whether the module was changed // (instructions were removed). 
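The hlo_dce.h changes above give HloDCE an opt-in constructor argument for deleting unused cross-partition collectives (layout-unconstrained all-reduce, collective-permute, all-to-all), which plain DCE keeps because they are side-effecting. A usage sketch in a pass pipeline (the helper name is illustrative):

#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"

void AddPostPartitioningDce(xla::HloPassPipeline* pipeline) {
  // Also remove dead collective ops that an ordinary HloDCE() would retain.
  pipeline->AddPass<xla::HloDCE>(
      /*remove_cross_partition_collective_ops=*/true);
}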
StatusOr Run(HloModule* module) override; + + private: + bool remove_cross_partition_collective_ops_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 2e205606977..78e4d39d3fe 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -1557,7 +1557,7 @@ string WrapDotInHtml(absl::string_view dot) { tensorflow::mutex url_renderer_mu(tensorflow::LINKER_INITIALIZED); std::function(absl::string_view)>* url_renderer - GUARDED_BY(url_renderer_mu) = nullptr; + TF_GUARDED_BY(url_renderer_mu) = nullptr; // Precondition: url_renderer != nullptr. // @@ -1567,7 +1567,7 @@ std::function(absl::string_view)>* url_renderer // of producing dot for the graph.) StatusOr WrapDotInFormat(absl::string_view dot, RenderedGraphFormat format) - EXCLUSIVE_LOCKS_REQUIRED(url_renderer_mu) { + TF_EXCLUSIVE_LOCKS_REQUIRED(url_renderer_mu) { switch (format) { case RenderedGraphFormat::kUrl: CHECK(url_renderer != nullptr) diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 70cbdad9ca7..c8a68db25d4 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -225,7 +225,7 @@ string HloModule::ToString(const HloPrintOptions& options) const { } s << "\n\n"; const auto& computations = options.canonicalize_computations() - ? MakeComputationSortedByContent() + ? MakeComputationSorted() : MakeComputationPostOrder(); for (const HloComputation* computation : computations) { if (!options.print_computation(computation)) { @@ -602,16 +602,23 @@ std::vector HloModule::MakeComputationPostOrder() const { return post_order; } -std::vector HloModule::MakeComputationSortedByContent() const { - auto result = MakeComputationPostOrder(); - std::sort(result.begin(), result.end(), - [](HloComputation* a, HloComputation* b) { - if (a->instruction_count() != b->instruction_count()) { - return a->instruction_count() < b->instruction_count(); - } - return a->ToString(HloPrintOptions::Fingerprint()) < - b->ToString(HloPrintOptions::Fingerprint()); - }); +namespace { +bool CompareComputationsByContent(HloComputation* a, HloComputation* b) { + if (a->instruction_count() != b->instruction_count()) { + return a->instruction_count() < b->instruction_count(); + } + return a->ToString(HloPrintOptions::Fingerprint()) < + b->ToString(HloPrintOptions::Fingerprint()); +} +} // anonymous namespace + +std::vector HloModule::MakeComputationSorted() const { + std::vector result; + result.reserve(computations_.size()); + for (const auto& computation : computations_) { + result.push_back(computation.get()); + } + std::sort(result.begin(), result.end(), CompareComputationsByContent); return result; } @@ -629,10 +636,7 @@ std::vector HloModule::MakeNonfusionComputations() const { std::vector HloModule::MakeNonfusionComputationsSorted() const { auto result = MakeNonfusionComputations(); - std::sort(result.begin(), result.end(), - [](HloComputation* a, HloComputation* b) { - return a->name() < b->name(); - }); + std::sort(result.begin(), result.end(), CompareComputationsByContent); return result; } @@ -717,10 +721,10 @@ HloComputation* HloModule::GetComputationWithName(absl::string_view name) { uint64 HloModule::Hash() const { uint64 result = entry_computation_layout().Hash(); - // Use MakeComputationSortedByContent() instead of MakeComputationPostOrder() + // Use MakeComputationSorted() instead 
of MakeComputationPostOrder() // because naming may affect the order of MakeComputationPostOrder() but not - // MakeComputationSortedByContent(). - for (auto* computation : MakeComputationSortedByContent()) { + // MakeComputationSorted(). + for (auto* computation : MakeComputationSorted()) { for (auto* instruction : computation->MakeInstructionPostOrder()) { result = tensorflow::Hash64Combine(result, instruction->Hash()); } diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index e44e22ba954..38395f173e1 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -197,8 +197,8 @@ class HloModule { std::vector MakeComputationPostOrder() const; // Same as MakeComputationPostOrder() but sorting the computations by their - // contents. - std::vector MakeComputationSortedByContent() const; + // contents. The order is longer post order. + std::vector MakeComputationSorted() const; // Gets the computations in this module which aren't for fusion nodes. // @@ -211,7 +211,7 @@ class HloModule { // MakeNonfusionComputations(). std::vector MakeNonfusionComputations() const; - // Same as MakeNonfusionComputations() but sorting the computations by names. + // Same as MakeNonfusionComputations() but sorting computations by content. std::vector MakeNonfusionComputationsSorted() const; const HloModuleConfig& config() const { return config_; } diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index dee601d9e96..d90a1485441 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -183,6 +183,12 @@ class HloModuleConfig { return &fusion_config_; } + const std::vector>& dot_config() const { + return dot_config_; + } + + std::vector>* mutable_dot_config() { return &dot_config_; } + private: // If you add new members, be sure to update compilation_cache_key. @@ -213,7 +219,14 @@ class HloModuleConfig { FusionConfigCollection fusion_config_collection_ = FusionConfigCollection::kOff; + // Custom fusion configuration, where fusion_config_[c][v] control if node v + // in computation c must be fused to all its consumers (true) or not (false). std::vector> fusion_config_; + + // Custom dot canonicalization configuration, where dot_config_[v] control + // how to convert dot operation v (sorted topologically and by computation) to + // convolution. + std::vector> dot_config_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 21be4216469..bfc6769660a 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -1648,6 +1648,8 @@ StatusOr HloRematerialization::RematerializeComputation( } else { // Found a valid block. Reset to start looking for single instructions // again. + max_rematerialized_block_size_ = + std::max(max_rematerialized_block_size_, max_block_size); changed = true; min_block_size = 1; max_block_size = 1; diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index d1c4b8b5e7b..72221fa8a32 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -180,6 +180,10 @@ class HloRematerialization : public HloModulePass { // dead. 
Hence, no net instructions were added. int64 net_instructions_added_ = 0; + // Size of the largest block that has been rematerialized. This is actually an + // upper bound (within a factor of 2) on the block size. + int max_rematerialized_block_size_ = 0; + RematerializationMode mode_; }; diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 552c8eb1ae5..7a4eefc1ab6 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -71,11 +71,36 @@ cc_library( ), ) +cc_library( + name = "executable_base", + srcs = ["executable_base.cc"], + hdrs = ["executable_base.h"], + deps = [ + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_tree", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_proto_cc", + "//tensorflow/compiler/xla/service:dynamic_dimension_inference", + "//tensorflow/compiler/xla/service:executable", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_execution_profile", + "//tensorflow/compiler/xla/service:maybe_owning_device_memory", + "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/compiler/xla/service:transfer_manager", + "//tensorflow/stream_executor:event", + "//tensorflow/stream_executor:stream", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/types:optional", + ], +) + cc_library( name = "executable", srcs = ["executable.cc"], hdrs = ["executable.h"], deps = [ + ":executable_base", ":executor", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 725cb437f8c..cc7fdeaf0f6 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/interpreter/executable_base.h" #include "tensorflow/compiler/xla/service/interpreter/executor.h" #include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" @@ -41,8 +42,7 @@ InterpreterExecutable::InterpreterExecutable( std::unique_ptr hlo_module, std::unique_ptr evaluator, absl::optional dynamic_dymension_inference) - : Executable(std::move(hlo_module), /*hlo_profile_printer_data=*/nullptr, - /*hlo_profile_index_map=*/nullptr), + : InterpreterExecutableBase(std::move(hlo_module)), evaluator_(std::move(evaluator)), dynamic_dimension_inference_(std::move(dynamic_dymension_inference)) { if (dynamic_dimension_inference_.has_value()) { @@ -51,107 +51,12 @@ InterpreterExecutable::InterpreterExecutable( } } -InterpreterExecutable::~InterpreterExecutable() {} - -StatusOr InterpreterExecutable::ExecuteAsyncOnStream( - const ServiceExecutableRunOptions* run_options, - std::vector arguments, - HloExecutionProfile* hlo_execution_profile) { - se::Stream* stream = run_options->stream(); - se::StreamExecutor* executor = stream->parent(); - const se::Platform* platform = executor->platform(); - - // Convert the ShapeTree to a ShapedBuffer. We do this so we can call - // TransferManager methods below. 
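The body being removed above is not lost: the patch moves the stream setup, argument transfer, and profiling boilerplate into the new InterpreterExecutableBase added later in this change, and the derived executable now only supplies the evaluation step through a virtual Evaluate(). A minimal sketch of that template-method split, using simplified stand-in types rather than the real XLA classes (Literal here is just a vector of doubles):

#include <iostream>
#include <vector>

// Stand-in for xla::Literal; the real class carries shape and layout.
using Literal = std::vector<double>;

// The base class owns the argument marshalling and profiling; subclasses only
// override Evaluate(), mirroring how InterpreterExecutableBase::ExecuteAsyncOnStream
// now delegates to InterpreterExecutable::Evaluate().
class ExecutableBase {
 public:
  virtual ~ExecutableBase() = default;

  Literal ExecuteOnStream(const std::vector<Literal>& arguments) {
    // 1) Common pre-work (in XLA: device-to-host transfers, shape checks).
    std::cout << "transferring " << arguments.size() << " arguments\n";
    // 2) Backend-specific evaluation supplied by the subclass.
    Literal result = Evaluate(arguments);
    // 3) Common post-work (in XLA: host transfer of the result, profiling).
    std::cout << "result has " << result.size() << " elements\n";
    return result;
  }

 protected:
  virtual Literal Evaluate(const std::vector<Literal>& arguments) = 0;
};

class InterpreterLikeExecutable : public ExecutableBase {
 protected:
  Literal Evaluate(const std::vector<Literal>& arguments) override {
    // Toy evaluation: element-wise sum of all arguments.
    Literal sum(arguments.empty() ? 0 : arguments[0].size(), 0.0);
    for (const Literal& arg : arguments) {
      for (size_t i = 0; i < sum.size(); ++i) sum[i] += arg[i];
    }
    return sum;
  }
};

int main() {
  InterpreterLikeExecutable exec;
  Literal out = exec.ExecuteOnStream({{1, 2, 3}, {4, 5, 6}});
  for (double v : out) std::cout << v << ' ';
  std::cout << '\n';
}

The real base class additionally manages ShapedBuffer ownership and the execution profile, which the sketch omits.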
- std::vector argument_buffers; - argument_buffers.reserve(arguments.size()); - for (auto& argument : arguments) { - const ShapeTree& buffers = argument.Buffers(); - argument_buffers.push_back(ShapedBuffer(buffers.shape(), buffers.shape(), - /*platform=*/nullptr, - /*device_ordinal=*/0)); - auto in_it = buffers.begin(); - auto out_it = argument_buffers.back().buffers().begin(); - for (; in_it != buffers.end(); ++in_it, ++out_it) { - out_it->second = in_it->second.AsDeviceMemoryBase(); - } - } - - VLOG(1) << "Execute " << module().name(); - if (VLOG_IS_ON(2)) { - for (const auto& a : argument_buffers) { - VLOG(2) << "-- argument " << a; - } - } - - uint64 start_micros = tensorflow::Env::Default()->NowMicros(); - - const HloComputation* computation = module().entry_computation(); - if (computation->num_parameters() != arguments.size()) { - return tensorflow::errors::Internal( - "Mismatch between argument count and graph parameter count."); - } - - // Check that the args have the right shape. - for (int64 i = 0; i < computation->num_parameters(); ++i) { - const auto& expected_shape = computation->parameter_instruction(i)->shape(); - const auto& actual_shape = argument_buffers[i].on_device_shape(); - if (!Shape::Equal().MinorToMajorOnlyInLayout()(expected_shape, - actual_shape)) { - return InvalidArgument( - "Shape mismatch on parameter %d. Expected %s, but was %s.", i, - ShapeUtil::HumanStringWithLayout(expected_shape), - ShapeUtil::HumanStringWithLayout(actual_shape)); - } - } - - TF_ASSIGN_OR_RETURN(TransferManager * transfer_manager, - TransferManager::GetForPlatform(platform)); - - // Transform the ShapedBuffer arguments into literals which the evaluator - // consumes. - std::vector arg_literals; - for (int64 p = 0; p < computation->num_parameters(); ++p) { - TF_ASSIGN_OR_RETURN(Literal arg_literal, - transfer_manager->TransferLiteralFromDevice( - run_options->stream(), argument_buffers[p])); - arg_literals.push_back(std::move(arg_literal)); - } - +StatusOr InterpreterExecutable::Evaluate( + const HloComputation& computation, absl::Span arg_literals) { // Execute the graph using the HloEvaluator. - Literal result_literal; - { - tensorflow::mutex_lock lock(evaluator_lock_); - evaluator_->ResetVisitStates(); - TF_ASSIGN_OR_RETURN(result_literal, - evaluator_->Evaluate(*computation, arg_literals)); - } - - // Transform the result literal back into a ShapedBuffer. 
- TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result_buffers, - transfer_manager->AllocateScopedShapedBuffer( - result_literal.shape(), run_options->allocator(), - executor->device_ordinal())); - TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice( - run_options->stream(), result_literal, result_buffers)); - ExecutionOutput result(std::move(result_buffers)); - - uint64 end_micros = tensorflow::Env::Default()->NowMicros(); - - ExecutionProfile* profile = run_options->run_options().execution_profile(); - if (profile) { - const double nanoseconds = (end_micros - start_micros) * 1000.0; - profile->set_compute_time_ns(std::max(nanoseconds, 1.0)); - } - for (auto& argument : arguments) { - for (auto& index_buffer : *argument.MutableBuffers()) { - auto maybe_owning_buffer = index_buffer.second.Release(); - if (maybe_owning_buffer) { - result.AddToBeReleased(std::move(*maybe_owning_buffer)); - } - } - } - return std::move(result); + tensorflow::mutex_lock lock(evaluator_lock_); + evaluator_->ResetVisitStates(); + return evaluator_->Evaluate(computation, arg_literals); } /*static*/ int64 InterpreterExecutable::ShapeSizeBytes(const Shape& shape) { diff --git a/tensorflow/compiler/xla/service/interpreter/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h index 5b2f41a884c..ce68a8472f5 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.h +++ b/tensorflow/compiler/xla/service/interpreter/executable.h @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/interpreter/executable_base.h" #include "tensorflow/compiler/xla/service/service_executable_run_options.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/statusor.h" @@ -40,25 +41,22 @@ namespace interpreter { // Responsible for running a HLO graph through the HloEvaluator and output // buffer allocation. Refer to interpreter/README.md for more. -class InterpreterExecutable : public Executable { +class InterpreterExecutable : public InterpreterExecutableBase { public: InterpreterExecutable( std::unique_ptr hlo_module, std::unique_ptr evaluator, absl::optional dynamic_dymension_inference); - ~InterpreterExecutable() override; - - StatusOr ExecuteAsyncOnStream( - const ServiceExecutableRunOptions* run_options, - std::vector arguments, - HloExecutionProfile* hlo_execution_profile) override - LOCKS_EXCLUDED(evaluator_lock_); static int64 ShapeSizeBytes(const Shape& shape); protected: + StatusOr Evaluate(const HloComputation& computation, + absl::Span arg_literals) override + TF_LOCKS_EXCLUDED(evaluator_lock_); + // The interpreter interprets executables with an HloEvaluator. - std::unique_ptr evaluator_ PT_GUARDED_BY(evaluator_lock_); + std::unique_ptr evaluator_ TF_PT_GUARDED_BY(evaluator_lock_); mutable tensorflow::mutex evaluator_lock_; private: diff --git a/tensorflow/compiler/xla/service/interpreter/executable_base.cc b/tensorflow/compiler/xla/service/interpreter/executable_base.cc new file mode 100644 index 00000000000..5850cbf005b --- /dev/null +++ b/tensorflow/compiler/xla/service/interpreter/executable_base.cc @@ -0,0 +1,137 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/interpreter/executable_base.h" + +#include +#include + +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/service/transfer_manager.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/stream_executor/lib/statusor.h" +#include "tensorflow/stream_executor/platform.h" +#include "tensorflow/stream_executor/stream.h" +#include "tensorflow/stream_executor/stream_executor_pimpl.h" + +namespace xla { +namespace interpreter { + +InterpreterExecutableBase::InterpreterExecutableBase( + std::unique_ptr hlo_module) + : Executable(std::move(hlo_module), /*hlo_profile_printer_data=*/nullptr, + /*hlo_profile_index_map=*/nullptr) {} + +StatusOr InterpreterExecutableBase::ExecuteAsyncOnStream( + const ServiceExecutableRunOptions* run_options, + std::vector arguments, + HloExecutionProfile* hlo_execution_profile) { + se::Stream* stream = run_options->stream(); + se::StreamExecutor* executor = stream->parent(); + const se::Platform* platform = executor->platform(); + + // Convert the ShapeTree to a ShapedBuffer. We do this so we can call + // TransferManager methods below. + std::vector argument_buffers; + argument_buffers.reserve(arguments.size()); + for (auto& argument : arguments) { + const ShapeTree& buffers = argument.Buffers(); + argument_buffers.push_back(ShapedBuffer(buffers.shape(), buffers.shape(), + /*platform=*/nullptr, + /*device_ordinal=*/0)); + auto in_it = buffers.begin(); + auto out_it = argument_buffers.back().buffers().begin(); + for (; in_it != buffers.end(); ++in_it, ++out_it) { + out_it->second = in_it->second.AsDeviceMemoryBase(); + } + } + + VLOG(1) << "Execute " << module().name(); + if (VLOG_IS_ON(2)) { + for (const auto& a : argument_buffers) { + VLOG(2) << "-- argument " << a; + } + } + + uint64 start_micros = tensorflow::Env::Default()->NowMicros(); + + const HloComputation* computation = module().entry_computation(); + if (computation->num_parameters() != arguments.size()) { + return tensorflow::errors::Internal( + "Mismatch between argument count and graph parameter count."); + } + + // Check that the args have the right shape. + for (int64 i = 0; i < computation->num_parameters(); ++i) { + const auto& expected_shape = computation->parameter_instruction(i)->shape(); + const auto& actual_shape = argument_buffers[i].on_device_shape(); + if (!Shape::Equal().MinorToMajorOnlyInLayout()(expected_shape, + actual_shape)) { + return InvalidArgument( + "Shape mismatch on parameter %d. 
Expected %s, but was %s.", i, + ShapeUtil::HumanStringWithLayout(expected_shape), + ShapeUtil::HumanStringWithLayout(actual_shape)); + } + } + + TF_ASSIGN_OR_RETURN(TransferManager * transfer_manager, + TransferManager::GetForPlatform(platform)); + + // Transform the ShapedBuffer arguments into literals which the evaluator + // consumes. + std::vector arg_literals; + for (int64 p = 0; p < computation->num_parameters(); ++p) { + TF_ASSIGN_OR_RETURN(Literal arg_literal, + transfer_manager->TransferLiteralFromDevice( + run_options->stream(), argument_buffers[p])); + arg_literals.push_back(std::move(arg_literal)); + } + + TF_ASSIGN_OR_RETURN(Literal result_literal, + Evaluate(*computation, arg_literals)); + + // Transform the result literal back into a ShapedBuffer. + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result_buffers, + transfer_manager->AllocateScopedShapedBuffer( + result_literal.shape(), run_options->allocator(), + executor->device_ordinal())); + TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice( + run_options->stream(), result_literal, result_buffers)); + ExecutionOutput result(std::move(result_buffers)); + + uint64 end_micros = tensorflow::Env::Default()->NowMicros(); + + ExecutionProfile* profile = run_options->run_options().execution_profile(); + if (profile) { + const double nanoseconds = (end_micros - start_micros) * 1000.0; + profile->set_compute_time_ns(std::max(nanoseconds, 1.0)); + } + for (auto& argument : arguments) { + for (auto& index_buffer : *argument.MutableBuffers()) { + auto maybe_owning_buffer = index_buffer.second.Release(); + if (maybe_owning_buffer) { + result.AddToBeReleased(std::move(*maybe_owning_buffer)); + } + } + } + return std::move(result); +} + +} // namespace interpreter +} // namespace xla diff --git a/tensorflow/compiler/xla/service/interpreter/executable_base.h b/tensorflow/compiler/xla/service/interpreter/executable_base.h new file mode 100644 index 00000000000..a02ab7af8d0 --- /dev/null +++ b/tensorflow/compiler/xla/service/interpreter/executable_base.h @@ -0,0 +1,57 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTABLE_BASE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTABLE_BASE_H_ + +#include + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" +#include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo_execution_profile.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/service_executable_run_options.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/xla.pb.h" +namespace xla { +namespace interpreter { + +// Responsible for running a HLO graph through the HloEvaluator and output +// buffer allocation. Refer to interpreter/README.md for more. +class InterpreterExecutableBase : public Executable { + public: + explicit InterpreterExecutableBase(std::unique_ptr hlo_module); + + StatusOr ExecuteAsyncOnStream( + const ServiceExecutableRunOptions* run_options, + std::vector arguments, + HloExecutionProfile* hlo_execution_profile) override; + + protected: + virtual StatusOr Evaluate( + const HloComputation& computation, + absl::Span arg_literals) = 0; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(InterpreterExecutableBase); +}; + +} // namespace interpreter +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTABLE_BASE_H_ diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h index 2279be7d2e5..3c35fda55f1 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.h +++ b/tensorflow/compiler/xla/service/interpreter/executor.h @@ -130,19 +130,19 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { std::function callback) override; port::Status AllocateEvent(Event *event) override { - return port::Status{port::error::UNIMPLEMENTED, ""}; + return port::Status::OK(); } port::Status DeallocateEvent(Event *event) override { - return port::Status{port::error::UNIMPLEMENTED, ""}; + return port::Status::OK(); } port::Status RecordEvent(Stream *stream, Event *event) override { - return port::Status{port::error::UNIMPLEMENTED, ""}; + return port::Status{port::error::UNIMPLEMENTED, "RecordEvent"}; } port::Status WaitForEvent(Stream *stream, Event *event) override { - return port::Status{port::error::UNIMPLEMENTED, ""}; + return port::Status{port::error::UNIMPLEMENTED, "WaitForEvent"}; } Event::Status PollForEventStatus(Event *event) override { diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc index 2ebb6ae567e..d30c24616ff 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc @@ -424,10 +424,13 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { aliased_allocation->chunk(), definition_time, definition_time)); } + std::vector use_times(uses.size()); + for (int i = 0; i < uses.size(); ++i) { + use_times[i] = instruction_schedule.at(uses[i].instruction); + } // Iterate over the uses. 
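For context on the use_times loop just above: the schedule time of every use is now recorded up front (in schedule order), so later allocation requests can consult any use time rather than only the last one. A standalone sketch of that precomputation, with a plain map standing in for the real instruction schedule and string keys standing in for HloInstruction pointers:

#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in: maps an instruction name to its logical time in the
// flattened module schedule.
using InstructionSchedule = std::map<std::string, int64_t>;

std::vector<int64_t> CollectUseTimes(const InstructionSchedule& schedule,
                                     const std::vector<std::string>& uses) {
  std::vector<int64_t> use_times(uses.size());
  for (size_t i = 0; i < uses.size(); ++i) {
    use_times[i] = schedule.at(uses[i]);
  }
  // Uses are enumerated in schedule order, so use_times comes out sorted and
  // use_times.back() is the time of the last use.
  return use_times;
}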
for (HloUse use : uses) { int64 use_time = instruction_schedule.at(use.instruction); - int64 last_use_time = instruction_schedule.at(uses.back().instruction); int64 latest_prefetch_time = use_time; if (use.instruction->parent() != defining_computation) { @@ -457,7 +460,7 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { AllocationRequest request; request.start_time = definition_time; request.end_time = use_time; - request.last_use_time = last_use_time; + request.use_times = &use_times; request.latest_prefetch_time = latest_prefetch_time; request.use = use; request.buffer = value; @@ -692,7 +695,7 @@ bool AlternateMemoryBestFitHeap::FindAllocation( VLOG(2) << "Finding allocation for " << request.buffer->ToShortString() << " (" << request.start_time << ", " << request.end_time << ") latest prefetch = " << request.latest_prefetch_time - << " last use = " << request.last_use_time + << " last use = " << request.use_times->back() << " use = " << request.use.ToString() << ". Size = " << request.size << ", def pos = " << defining_position.ToString(); CHECK_LE(request.start_time, request.end_time); @@ -880,8 +883,8 @@ bool AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy( // the last use time, we try to find an allocation that is available for the // entire Producer to Use2 range. absl::optional chunk_candidate = - FindBestNoCopyChunkCandidate(request.end_time, request.last_use_time, - preferred_offset, &alternate_mem_interval); + FindBestChunkCandidate(request.end_time, *request.use_times, + preferred_offset, &alternate_mem_interval); // Check if the new heap size fits within limits. Also ensure if a // preferred offset was provided, that offset was used. if (chunk_candidate) { @@ -1027,39 +1030,39 @@ bool AlternateMemoryBestFitHeap::Prefetch( BufferInterval alternate_mem_interval; alternate_mem_interval.buffer = request.buffer; alternate_mem_interval.size = request.size; - alternate_mem_interval.end = request.end_time; while (!options_.prefetch_interval_picker->Done()) { alternate_mem_interval.start = options_.prefetch_interval_picker->Next(); VLOG(4) << "Trying alternate memory allocation (" - << alternate_mem_interval.start << ", " - << alternate_mem_interval.end << ")"; + << alternate_mem_interval.start << ", " << request.end_time << ")"; // If this additional asynchronous copy would violate the limit, try a // different interval. if (ViolatesMaximumOutstandingAsyncCopies(alternate_mem_interval.start, - alternate_mem_interval.end)) { + request.end_time)) { VLOG(4) << "This would violate the outstanding async copy limit."; continue; } if (ViolatesAsyncCopyOrdering(alternate_mem_interval.start, - alternate_mem_interval.end)) { + request.end_time)) { VLOG(4) << "This would violate asynchronous copy ordering."; continue; } - ChunkCandidate chunk_candidate = FindChunkCandidate(alternate_mem_interval); - // Check if the new heap size fits within limits. - if (chunk_candidate.heap_size <= available_heap_size()) { + auto chunk_candidate = FindBestChunkCandidate( + request.end_time, *request.use_times, + /*preferred_offset=*/absl::nullopt, &alternate_mem_interval); + // Check if we could find a suitable chunk. + if (chunk_candidate) { VLOG(3) << "Move the buffer to alternate memory at " << alternate_mem_interval.start - << ". Offset = " << chunk_candidate.chunk.offset - << ", size = " << chunk_candidate.chunk.size - << ", heap_size = " << chunk_candidate.heap_size + << ". 
Offset = " << chunk_candidate->chunk.offset + << ", size = " << chunk_candidate->chunk.size + << ", heap_size = " << chunk_candidate->heap_size << ", prefetch picker = " << options_.prefetch_interval_picker->ToDebugString(); - AddToPendingChunks(alternate_mem_interval, chunk_candidate); + AddToPendingChunks(alternate_mem_interval, *chunk_candidate); AddAsyncCopy(prev_allocation_in_default_mem, MemorySpace::kAlternate, - chunk_candidate.chunk, alternate_mem_interval.start, + chunk_candidate->chunk, alternate_mem_interval.start, request.end_time, request.latest_prefetch_time, request.allocations); @@ -1071,14 +1074,16 @@ bool AlternateMemoryBestFitHeap::Prefetch( } absl::optional -AlternateMemoryBestFitHeap::FindBestNoCopyChunkCandidate( - int64 end_time, int64 last_use_time, absl::optional preferred_offset, +AlternateMemoryBestFitHeap::FindBestChunkCandidate( + int64 end_time, const std::vector& use_times, + absl::optional preferred_offset, BufferInterval* alternate_mem_interval) const { if (!preferred_offset) { - // Find a chunk that's as long living as possible. - for (alternate_mem_interval->end = last_use_time; - alternate_mem_interval->end >= end_time; - --alternate_mem_interval->end) { + // Find a chunk that's as long living as possible iterating in reverse over + // the use times. + for (auto use_time = use_times.rbegin(); + use_time != use_times.rend() && *use_time >= end_time; ++use_time) { + alternate_mem_interval->end = *use_time; ChunkCandidate chunk_candidate = FindChunkCandidate(*alternate_mem_interval); if (chunk_candidate.heap_size <= available_heap_size()) { @@ -1086,6 +1091,7 @@ AlternateMemoryBestFitHeap::FindBestNoCopyChunkCandidate( return chunk_candidate; } } + alternate_mem_interval->end = end_time; return absl::nullopt; } // If a preferred offset is given, try to find an allocation at that offset diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h index b056204b15f..51ff5329482 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.h +++ b/tensorflow/compiler/xla/service/memory_space_assignment.h @@ -23,9 +23,10 @@ namespace xla { // This class contains pre-set assignments determined by memory space // assignment. It contains two data structures: (1) a chunks vector that maps a -// defining HloPosition to a Chunk (offset and size), and (2) a sizes vector -// that maps the memory space to its size. If there is only one alternate memory -// space like there is currently, there will be one entry in sizes. +// defining HloPosition to a Chunk (offset and size), and (2) an assignment_info +// vector that maps the memory space to information like its allocated size and +// heap memory trace. If there is only one alternate memory space like there is +// currently, there will be one entry in assignment_info. class PresetAssignments { public: // Contains per-memory-space information like the allocated size and heap @@ -639,13 +640,13 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { // Segment Segment Segment // // start_time and end_time are the start and end logical times of the segment. - // last_use_time is the time of the last use for this buffer (Use3 in the - // figure). latest_prefetch_time is the latest time we can schedule the - // CopyDone for a prefetch. + // use_times is a sorted sequence of the times of all uses. + // latest_prefetch_time is the latest time we can schedule the CopyDone for a + // prefetch. 
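To illustrate how a sorted use_times sequence is consumed by FindBestChunkCandidate above: the search walks the use times from the back so the longest-living interval is tried first, shrinking the candidate end time until a chunk fits. A simplified sketch, with a hypothetical fits() predicate standing in for the real FindChunkCandidate and heap-size check:

#include <cstdint>
#include <functional>
#include <optional>
#include <vector>

// Walks candidate end times from the latest use down toward end_time and
// returns the first (i.e. longest-living) end time accepted by `fits`.
std::optional<int64_t> LongestFittingEnd(
    int64_t end_time, const std::vector<int64_t>& use_times,
    const std::function<bool(int64_t)>& fits) {
  for (auto it = use_times.rbegin(); it != use_times.rend() && *it >= end_time;
       ++it) {
    if (fits(*it)) return *it;
  }
  return std::nullopt;
}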
struct AllocationRequest { int64 start_time; int64 end_time; - int64 last_use_time; + const std::vector* use_times; int64 latest_prefetch_time; int64 size; HloUse use; @@ -696,11 +697,11 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { const AllocationRequest& request, const MemorySpaceAssignment::Allocation& prev_allocation_in_default_mem); - // For a no-copy allocation, find the best possible chunk candidate, where it - // has the longest possible availability if no preferred offset is given, or - // at the preferred_offset if it is given. - absl::optional FindBestNoCopyChunkCandidate( - int64 end_time, int64 last_use_time, + // Find the best possible chunk candidate, where it has the longest possible + // availability if no preferred offset is given, or at the preferred_offset if + // it is given. + absl::optional FindBestChunkCandidate( + int64 end_time, const std::vector& use_times, absl::optional preferred_offset, BufferInterval* alternate_mem_interval) const; diff --git a/tensorflow/compiler/xla/service/mlir_gpu/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/BUILD index afceefdeae6..abeeb866e8c 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/BUILD @@ -67,13 +67,13 @@ cc_library( ":lhlo_dialect_emitter", "@com_google_absl//absl/container:flat_hash_map", "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:LLVMTransforms", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", "@llvm-project//mlir:TargetNVVMIR", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:compiler", @@ -148,6 +148,7 @@ cc_library( "//tensorflow/compiler/mlir/xla:hlo", "//tensorflow/compiler/mlir/xla:hlo_legalize_to_lhlo", "//tensorflow/compiler/mlir/xla:lhlo", + "//tensorflow/compiler/mlir/xla:lhlo_copy_removal", "//tensorflow/compiler/mlir/xla:lhlo_fuse_linalg", "//tensorflow/compiler/mlir/xla:lhlo_legalize_to_affine", "//tensorflow/compiler/mlir/xla:lhlo_legalize_to_gpu", @@ -188,6 +189,7 @@ cc_library( ":failover_compiler", ":inject_errors_pass", ":mlir_compiler", + "//tensorflow/compiler/mlir/xla:hlo_utils", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/tests:codegen_test_base", diff --git a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc index 184d8d202c3..0914e5ef820 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc @@ -82,6 +82,8 @@ StatusOr InsertMlirOp(HloOpcode opcode, OpBuilder func_builder, return {func_builder.create(loc, rets, args, attrs)}; case HloOpcode::kSign: return {func_builder.create(loc, rets, args, attrs)}; + case HloOpcode::kSqrt: + return {func_builder.create(loc, rets, args, attrs)}; case HloOpcode::kSubtract: return {func_builder.create(loc, rets, args, attrs)}; case HloOpcode::kTanh: diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc index ca26ae4e756..151d82fd2a1 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc +++ 
b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc @@ -259,8 +259,8 @@ void EnableIRPrinting(mlir::PassManager* passManager) { auto enable_if_vlog_is_on = [](mlir::Pass* pass, mlir::Operation* op) { return VLOG_IS_ON(1); }; - passManager->enableIRPrinting(/*shouldPrintBeforePass=*/{}, - /*shouldPrintAfterPass=*/enable_if_vlog_is_on, + passManager->enableIRPrinting(/*shouldPrintBeforePass=*/enable_if_vlog_is_on, + /*shouldPrintAfterPass=*/{}, /*printModuleScope=*/false, /*printAfterOnlyOnChange=*/true, llvm::dbgs()); passManager->disableMultithreading(); @@ -277,7 +277,7 @@ Status LowerLHLOToGPU(mlir::ModuleOp module) { // Next, we can strip the outer fusion operation. pm.addPass(absl::make_unique()); // Remove unnecessary Lhlo copies. - pm.addPass(::mlir::xla_hlo::createLhloCopyRemovalPass()); + pm.addPass(::mlir::xla_lhlo::createLhloCopyRemovalPass()); // Transform lhlo operations to LinAlg. pm.addPass(::mlir::xla_lhlo::createLegalizeLhloToLinalgPass()); // Fuse linalg operations. This will yield a single tiled loop nest where diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc index 75c7c284881..1f681bfab00 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc @@ -113,6 +113,9 @@ Status InsertMlirOp(HloOpcode opcode, OpBuilder func_builder, Location loc, case HloOpcode::kSign: func_builder.create(loc, rets, args, attrs); break; + case HloOpcode::kSqrt: + func_builder.create(loc, rets, args, attrs); + break; case HloOpcode::kSubtract: func_builder.create(loc, rets, args, attrs); break; diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD index aeaaf0b16c4..e2523d82b91 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD @@ -39,6 +39,7 @@ tf_cc_test( "compare.hlo", "const.hlo", "copy.hlo", + "copy_transpose.hlo", "cos.hlo", "exp.hlo", "fused_reduce.hlo", @@ -50,6 +51,7 @@ tf_cc_test( "rsqrt.hlo", "select.hlo", "sign.hlo", + "sqrt.hlo", "tanh.hlo", ], tags = tf_cuda_tests_tags() + ["no_rocm"], diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_in_gpu_dialect.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_in_gpu_dialect.hlo index ec7df87af64..208ca2799b2 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_in_gpu_dialect.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_in_gpu_dialect.hlo @@ -10,9 +10,9 @@ ENTRY %Add (x: f32[2,2], y: f32[2,2]) -> f32[2,2] { // CHECK: "gpu.launch_func"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[ARG0]], %[[ARG1]], %[[ARG2]] // CHECK: } // CHECK: func @add_kernel(%[[ARG0]]: [[TYPE]], %[[ARG1]]: [[TYPE]], %[[ARG2]]: [[TYPE]] -// CHECK-DAG: std.subview %[[ARG0]]{{\[}}[[INDEX:.*]]] -// CHECK-DAG: std.subview %[[ARG1]]{{\[}}[[INDEX]]] -// CHECK-DAG: std.subview %[[ARG2]]{{\[}}[[INDEX]]] +// CHECK-DAG: subview %[[ARG0]]{{\[}}[[INDEX:.*]]] +// CHECK-DAG: subview %[[ARG1]]{{\[}}[[INDEX]]] +// CHECK-DAG: subview %[[ARG2]]{{\[}}[[INDEX]]] // CHECK: %[[VAL1:.*]] = load %{{.*\[}}[[INDEX:.*]]] // CHECK: %[[VAL2:.*]] = load %{{.*\[}}[[INDEX]]] // CHECK: %[[RES:.*]] = addf %[[VAL1]], %[[VAL2]] diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply_gpu.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply_gpu.hlo index e9000956c23..fe871c1feb6 100644 --- 
a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply_gpu.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply_gpu.hlo @@ -9,10 +9,10 @@ ENTRY %AddMultiply (x: f32[2,2], y: f32[2,2], z: f32[2,2]) -> f32[2,2] { } // CHECK: func @fusion_kernel(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]], %[[RESULT:.*]]: [[TYPE]]) -// CHECK-DAG: std.subview %[[ARG0]]{{\[}}[[INDEX:.*]]] -// CHECK-DAG: std.subview %[[ARG1]]{{\[}}[[INDEX]]] -// CHECK-DAG: std.subview %[[ARG2]]{{\[}}[[INDEX]]] -// CHECK-DAG: std.subview %[[RESULT]]{{\[}}[[INDEX]]] +// CHECK-DAG: subview %[[ARG0]]{{\[}}[[INDEX:.*]]] +// CHECK-DAG: subview %[[ARG1]]{{\[}}[[INDEX]]] +// CHECK-DAG: subview %[[ARG2]]{{\[}}[[INDEX]]] +// CHECK-DAG: subview %[[RESULT]]{{\[}}[[INDEX]]] // CHECK: %[[V0:.*]] = load %{{.*\[}}[[CSTIDX:.*]]] // CHECK: %[[V1:.*]] = load %{{.*\[}}[[CSTIDX:.*]]] // CHECK: %[[ADD:.*]] = addf %[[V0]], %[[V1]] diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/copy_transpose.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/copy_transpose.hlo new file mode 100644 index 00000000000..2ad8c1b49e3 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/copy_transpose.hlo @@ -0,0 +1,12 @@ +HloModule CopyTranspose + +ENTRY %CopyTranspose (x: f32[2,4]) -> f32[2,4]{0,1} { + %x = f32[2,4] parameter(0) + ROOT %copy = f32[2,4]{0,1} copy(f32[2,4] %x) +} + +// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)> +// CHECK: func @copy(%[[OPERAND:.*]]: memref<2x4xf32>, +// CHECK-SAME: %[[RESULT:.*]]: memref<2x4xf32, #[[MAP0]]>) +// CHECK: "xla_lhlo.copy"(%[[OPERAND]], %[[RESULT]]) +// CHECK-SAME: : (memref<2x4xf32>, memref<2x4xf32, #[[MAP0]]>) diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc b/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc index 7afb7e9281d..206d46debdf 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc @@ -58,6 +58,13 @@ TEST_F(LhloGenTest, Copy) { "copy.hlo")); } +TEST_F(LhloGenTest, CopyTranspose) { + CompileAndVerifyIr( + /*hlo_text_filename=*/tensorflow::io::JoinPath( + "tensorflow", "compiler", "xla", "service", "mlir_gpu", "tests", + "copy_transpose.hlo")); +} + TEST_F(LhloGenTest, Select) { CompileAndVerifyIr( /*hlo_text_filename=*/tensorflow::io::JoinPath( @@ -186,6 +193,13 @@ TEST_F(LhloGenTest, Sign) { "rsqrt.hlo")); } +TEST_F(LhloGenTest, Sqrt) { + CompileAndVerifyIr( + /*hlo_text_filename=*/tensorflow::io::JoinPath( + "tensorflow", "compiler", "xla", "service", "mlir_gpu", "tests", + "sqrt.hlo")); +} + TEST_F(LhloGenTest, Tanh) { CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla", "service", "mlir_gpu", "tests", diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/sqrt.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/sqrt.hlo new file mode 100644 index 00000000000..95461b912a3 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/sqrt.hlo @@ -0,0 +1,11 @@ +HloModule Sqrt + +ENTRY %Sqrt (x: f32[2,2]) -> f32[2,2] { + %x = f32[2,2]{1,0} parameter(0) + ROOT %sqrt = f32[2,2]{1,0} sqrt(f32[2,2]{1,0} %x) +} + +// CHECK: func @sqrt(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { +// CHECK: "xla_lhlo.sqrt"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK: } + diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index db977aaa32b..febbf9294b0 
100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -1187,15 +1187,19 @@ class HloInstructionIsImpl { bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { if (inst != inst_) { - EXPLAIN << "HloInstruction " << inst << " is not " << inst_ << " (" - << InstToString(inst_) << ")"; + EXPLAIN << "HloInstruction " << std::hex << std::nouppercase + << std::showbase << reinterpret_cast(inst) << " is not " + << reinterpret_cast(inst_) << " (" << InstToString(inst_) + << ")"; return false; } return true; } void DescribeTo(std::ostream* os, int64 indent = 0) const { - *os << "which is " << inst_ << " (" << InstToString(inst_) << ")"; + *os << "which is " << std::hex << std::nouppercase << std::showbase + << reinterpret_cast(inst_) << " (" << InstToString(inst_) + << ")"; } private: diff --git a/tensorflow/compiler/xla/service/stream_pool.h b/tensorflow/compiler/xla/service/stream_pool.h index 7221d323a61..9cc5b7c9cea 100644 --- a/tensorflow/compiler/xla/service/stream_pool.h +++ b/tensorflow/compiler/xla/service/stream_pool.h @@ -56,7 +56,7 @@ class StreamPool { void ReturnStream(se::Stream* stream); tensorflow::mutex mu_; - std::vector> streams_ GUARDED_BY(mu_); + std::vector> streams_ TF_GUARDED_BY(mu_); }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/transpose_folding.h b/tensorflow/compiler/xla/service/transpose_folding.h index f95f982eb89..ac5e1b80651 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.h +++ b/tensorflow/compiler/xla/service/transpose_folding.h @@ -39,6 +39,13 @@ class TransposeFolding : public HloModulePass { const OperandIndices&) { return {}; } + + // Helper function to always fold transposes. + static OperandIndices AlwaysFoldTranspose(const HloInstruction&, + const OperandIndices& ids) { + return ids; + } + // transposable_gemm_operands returns the set of operands it wants to fold if // the instruction argument is implemented as a GEMM kernel that supports // transposing its arguments. @@ -47,8 +54,10 @@ class TransposeFolding : public HloModulePass { // the instruction argument is implemented as a convolution that supports // transposing its arguments. explicit TransposeFolding( - TransposableGemmOperandsFn transposable_gemm_operands, - TransposableConvOperandsFn transposable_conv_operands); + TransposableGemmOperandsFn transposable_gemm_operands = + AlwaysFoldTranspose, + TransposableConvOperandsFn transposable_conv_operands = + AlwaysFoldTranspose); absl::string_view name() const override { return "transpose-folding"; } StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/shape.cc b/tensorflow/compiler/xla/shape.cc index de243431e2c..d1d5dc17083 100644 --- a/tensorflow/compiler/xla/shape.cc +++ b/tensorflow/compiler/xla/shape.cc @@ -48,7 +48,7 @@ Shape::Shape(const ShapeProto& shape_proto) { } tuple_shapes_.reserve(shape_proto.tuple_shapes_size()); for (const ShapeProto& element_shape : shape_proto.tuple_shapes()) { - *add_tuple_shapes() = Shape(element_shape); + tuple_shapes_.emplace_back(element_shape); } if (shape_proto.has_layout()) { *mutable_layout() = Layout::CreateFromProto(shape_proto.layout()); diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 23010d6ce70..e42af57e19b 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -939,6 +939,8 @@ xla_test( tags = [ "no_rocm", "optonly", + # TODO(b/151340488): Timed out on 2020-03-12. 
+ "nozapfhahn", ], deps = [ ":client_library_test_base", @@ -1096,6 +1098,10 @@ xla_test( name = "convolution_test", timeout = "long", srcs = ["convolution_test.cc"], + backend_tags = { + # TODO(b/151340488): Timed out on 2020-03-12. + "interpreter": ["nozapfhahn"], + }, shard_count = 40, tags = [ "no_rocm", @@ -1134,7 +1140,11 @@ xla_test( backend_args = {"gpu": ["--xla_backend_extra_options=xla_gpu_experimental_conv_disable_layout_heuristic"]}, backends = ["gpu"], shard_count = 25, - tags = ["no_rocm"], + tags = [ + "no_rocm", + # TODO(b/151340488): Timed out on 2020-03-12. + "nozapfhahn", + ], deps = CONVOLUTION_TEST_DEPS + [ "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -1489,6 +1499,10 @@ xla_test( name = "select_and_scatter_test", timeout = "long", srcs = ["select_and_scatter_test.cc"], + backend_tags = { + # TODO(b/151340488): Timed out on 2020-03-12. + "interpreter": ["nozapfhahn"], + }, tags = [ "no_rocm", "optonly", diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc index c35f05ebf45..53c0d84854e 100644 --- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "llvm/ADT/Triple.h" +#include "llvm/Support/Host.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h index 8908a855847..ea457024618 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.h +++ b/tensorflow/compiler/xla/tests/local_client_test_base.h @@ -66,12 +66,12 @@ class TestAllocator : public se::StreamExecutorMemoryAllocator { mutable tensorflow::mutex count_mutex_; // Global counts of allocations and deallocations. - int64 allocation_count_ GUARDED_BY(count_mutex_) = 0; - int64 deallocation_count_ GUARDED_BY(count_mutex_) = 0; + int64 allocation_count_ TF_GUARDED_BY(count_mutex_) = 0; + int64 deallocation_count_ TF_GUARDED_BY(count_mutex_) = 0; // Per-device counts of allocations and deallocations. - std::map device_allocation_count_ GUARDED_BY(count_mutex_); - std::map device_deallocation_count_ GUARDED_BY(count_mutex_); + std::map device_allocation_count_ TF_GUARDED_BY(count_mutex_); + std::map device_deallocation_count_ TF_GUARDED_BY(count_mutex_); }; // A base class for tests which exercise the LocalClient interface. diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index d83ba25c345..3468c12d8c9 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -69,6 +69,10 @@ message DebugOptions { // Disable multi-streaming in the GPU backend. bool xla_gpu_disable_multi_streaming = 63; + // Debugging feature: if enabled, the GPU backend will assign HLO operators to + // randomly chosen streams. This is intended to trigger concurrency bugs. + bool xla_gpu_use_random_streams = 134; + // If true, in LLVM-based backends, emit !alias.scope metadata in // generated IR. bool xla_llvm_enable_alias_scope_metadata = 70; @@ -260,7 +264,8 @@ message DebugOptions { // Guarantee run-to-run determinism from reductions on XLA:GPU. 
bool xla_gpu_deterministic_reductions = 130; - // Next id: 134 + + // Next id: 135 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. diff --git a/tensorflow/compiler/xrt/xrt.proto b/tensorflow/compiler/xrt/xrt.proto index b885a7593f5..47b7cda2760 100644 --- a/tensorflow/compiler/xrt/xrt.proto +++ b/tensorflow/compiler/xrt/xrt.proto @@ -10,11 +10,12 @@ import "tensorflow/compiler/xla/xla_data.proto"; message DeviceAssignment { message ComputationDevice { message DeviceMeshCoordinates { - // The mesh coordinates for the device. Usually (X, Y, Core), in the order - // in which they are returned in the TopologyProto. + // The mesh coordinates for the device. Usually (X, Y, Z, Core), in the + // order in which they are returned in the TopologyProto. // X = value(0) // Y = value(1) - // Core = value(2) + // Z = value(2) + // Core = value(3) repeated int32 value = 1; } // As many replicas as there are in the replicated computation. diff --git a/tensorflow/compiler/xrt/xrt_compilation_cache.h b/tensorflow/compiler/xrt/xrt_compilation_cache.h index 02cb25ea35c..3c2577e620b 100644 --- a/tensorflow/compiler/xrt/xrt_compilation_cache.h +++ b/tensorflow/compiler/xrt/xrt_compilation_cache.h @@ -173,11 +173,11 @@ class XRTCompilationCache : public ResourceBase { // last reference to entry is released, entry is removed from cache_. void DiscardEntryRef(CompiledSubgraph* entry); void DiscardEntryRefLocked(CompiledSubgraph* entry) - EXCLUSIVE_LOCKS_REQUIRED(mu_); + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Marks the oldest unmarked entry for eviction. Requires that there is at // least one such entry. - void MarkOldestEntryForEviction() EXCLUSIVE_LOCKS_REQUIRED(mu_); + void MarkOldestEntryForEviction() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Updates datastructures to indicate that entry, which had been marked for // eviction, has been looked up. This is called by CompileIfKeyAbsent when an @@ -195,7 +195,7 @@ class XRTCompilationCache : public ResourceBase { // is never marked for eviction, so an entry larger than the max cache entries // will remain in the cache until it is replaced by something else. void LookupEntryMarkedForEviction(CompiledSubgraph* entry) - EXCLUSIVE_LOCKS_REQUIRED(mu_); + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Creates a new entry by running initialize_program and places it in the // cache to be looked up by key. The new entry is in the 'marked for eviction' @@ -206,7 +206,7 @@ class XRTCompilationCache : public ResourceBase { CompiledSubgraph* InitializeEntry( const string& key, const std::function*)>& - initialize_program) EXCLUSIVE_LOCKS_REQUIRED(mu_); + initialize_program) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // The maximum number of entries that are stored in the cache before entries // are marked for eviction. @@ -214,23 +214,24 @@ class XRTCompilationCache : public ResourceBase { mutable absl::Mutex mu_; // The total number of entries that are stored and not marked for eviction. - int cache_entries_ GUARDED_BY(mu_) = 0; + int cache_entries_ TF_GUARDED_BY(mu_) = 0; // The total number of entries that are marked for eviction. - int marked_for_eviction_entries_ GUARDED_BY(mu_) = 0; + int marked_for_eviction_entries_ TF_GUARDED_BY(mu_) = 0; // The value to assign to the last_use field of the next entry that is looked // up. - int64 use_counter_ GUARDED_BY(mu_) = 0; + int64 use_counter_ TF_GUARDED_BY(mu_) = 0; // All the executables that can be looked up in the cache index by key. 
An // entry is marked for eviction iff it is present in cache_ and not in // entries_by_last_use_. - std::unordered_map cache_ GUARDED_BY(mu_); + std::unordered_map cache_ TF_GUARDED_BY(mu_); // All the executable entries that can be looked up in the cache indexed by // uid. - std::unordered_map entries_by_uid_ GUARDED_BY(mu_); + std::unordered_map entries_by_uid_ + TF_GUARDED_BY(mu_); // Map from last_use to entry, used to mark entries for eviction in LRU // order. If an entry's last_use counter is not present as a key in // entries_by_last_use_ then the entry has been marked for eviction. - std::map entries_by_last_use_ GUARDED_BY(mu_); + std::map entries_by_last_use_ TF_GUARDED_BY(mu_); }; // Looks up or create an XRTCompilationCache object within the given resource diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index b02eb89ebfc..188988d92c4 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -175,6 +175,9 @@ exports_files([ package_group(name = "experimental_access") +# Authorized users go here. +package_group(name = "friends") + # ----------------------------------------------------------------------------- # Public targets @@ -469,6 +472,7 @@ tf_cuda_library( "//tensorflow/core/framework:graph_to_functiondef.h", "//tensorflow/core/framework:kernel_def_builder.h", "//tensorflow/core/framework:kernel_def_util.h", + "//tensorflow/core/framework:kernel_shape_util.h", "//tensorflow/core/framework:log_memory.h", "//tensorflow/core/framework:logging.h", "//tensorflow/core/framework:lookup_interface.h", @@ -1020,7 +1024,6 @@ cc_library( "//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/kernels:ctc_ops", - "//tensorflow/core/kernels:cudnn_rnn_kernels", "//tensorflow/core/kernels:data_flow", "//tensorflow/core/kernels:decode_proto_op", "//tensorflow/core/kernels:encode_proto_op", @@ -1095,6 +1098,7 @@ cc_library( "//tensorflow/core/kernels:mkl_tfconv_op", "//tensorflow/core/kernels:mkl_tmp_bf16_ops", ]) + if_cuda([ + "//tensorflow/core/kernels:cudnn_rnn_kernels", "//tensorflow/core/grappler/optimizers:gpu_swapping_kernels", "//tensorflow/core/grappler/optimizers:gpu_swapping_ops", ]) + if_nccl([ @@ -2252,7 +2256,6 @@ filegroup( "//tensorflow/core/framework:shared_ptr_variant.h", "//tensorflow/core/framework:tensor_reference.h", "//tensorflow/core/framework:tracking_allocator.h", # only needed for tests - "//tensorflow/core/framework:unique_tensor_references.h", "//tensorflow/core/framework:variant.h", "//tensorflow/core/util:framework_internal_public_hdrs", ], @@ -2351,6 +2354,7 @@ tf_cuda_library( "//tensorflow/core/framework:attr_value_util", "//tensorflow/core/framework:bfloat16", "//tensorflow/core/framework:common_shape_fns", + "//tensorflow/core/framework:kernel_shape_util", "//tensorflow/core/framework:node_def_util", "//tensorflow/core/framework:node_properties", "//tensorflow/core/framework:numeric_types", @@ -3213,6 +3217,49 @@ test_suite( ], ) +tf_cc_test( + name = "common_runtime_placer_test", + size = "small", + srcs = [ + "common_runtime/placer_test.cc", + ], + linkopts = select({ + "//tensorflow:macos": ["-headerpad_max_install_names"], + "//conditions:default": [], + }), + linkstatic = tf_kernel_tests_linkstatic(), + tags = ["no_windows"], + deps = [ + ":core", + ":core_cpu", + ":core_cpu_internal", + ":direct_session_internal", + ":framework", + ":framework_internal", + ":lib", + ":lib_internal", + ":ops", + ":protos_all_cc", + ":test", + ":test_main", + ":testlib", + "//tensorflow/cc:cc_ops", + 
"//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", + "//tensorflow/cc:sendrecv_ops", + "//tensorflow/cc:while_loop", + "//tensorflow/core/kernels:ops_util", + "//tensorflow/core/platform:regexp", + "//tensorflow/core/util:protos_test_cc", + "//third_party/eigen3", + "@com_google_absl//absl/base", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], +) + tf_cc_tests( name = "core_higher_level_tests", size = "small", @@ -3232,7 +3279,6 @@ tf_cc_tests( "common_runtime/optimization_registry_test.cc", "common_runtime/pending_counts_test.cc", "common_runtime/placer_inspection_required_ops_utils_test.cc", - "common_runtime/placer_test.cc", "common_runtime/session_test.cc", "common_runtime/threadpool_device_test.cc", "//tensorflow/core/example:feature_util_test.cc", diff --git a/tensorflow/core/api_def/base_api/api_def_BeginEpoch.pbtxt b/tensorflow/core/api_def/base_api/api_def_BeginEpoch.pbtxt new file mode 100644 index 00000000000..d5fd0d609c8 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_BeginEpoch.pbtxt @@ -0,0 +1,5 @@ +op { + graph_op_name: "BeginEpoch" + visibility: HIDDEN + summary: "Begins a tf.data service dataset epoch." +} diff --git a/tensorflow/core/api_def/base_api/api_def_DecodeGif.pbtxt b/tensorflow/core/api_def/base_api/api_def_DecodeGif.pbtxt index 75278f3c806..68438bc8114 100644 --- a/tensorflow/core/api_def/base_api/api_def_DecodeGif.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_DecodeGif.pbtxt @@ -21,6 +21,6 @@ uncompressed by running: convert $src.gif -coalesce $dst.gif This op also supports decoding JPEGs and PNGs, though it is cleaner to use -`tf.image.decode_image`. +`tf.io.decode_image`. END } diff --git a/tensorflow/core/api_def/base_api/api_def_DecodeJpeg.pbtxt b/tensorflow/core/api_def/base_api/api_def_DecodeJpeg.pbtxt index b9521370d35..e6147a00412 100644 --- a/tensorflow/core/api_def/base_api/api_def_DecodeJpeg.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_DecodeJpeg.pbtxt @@ -75,6 +75,6 @@ downscaling the image later. This op also supports decoding PNGs and non-animated GIFs since the interface is -the same, though it is cleaner to use `tf.image.decode_image`. +the same, though it is cleaner to use `tf.io.decode_image`. END } diff --git a/tensorflow/core/api_def/base_api/api_def_DecodePng.pbtxt b/tensorflow/core/api_def/base_api/api_def_DecodePng.pbtxt index 63404db8009..450de43751f 100644 --- a/tensorflow/core/api_def/base_api/api_def_DecodePng.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_DecodePng.pbtxt @@ -34,6 +34,6 @@ If needed, the PNG-encoded image is transformed to match the requested number of color channels. This op also supports decoding JPEGs and non-animated GIFs since the interface -is the same, though it is cleaner to use `tf.image.decode_image`. +is the same, though it is cleaner to use `tf.io.decode_image`. END } diff --git a/tensorflow/core/api_def/base_api/api_def_DistributeDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_DistributeDataset.pbtxt new file mode 100644 index 00000000000..a04f1e830c4 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_DistributeDataset.pbtxt @@ -0,0 +1,5 @@ +op { + graph_op_name: "DataServiceDataset" + visibility: HIDDEN + summary: "Creates a dataset that reads data from the tf.data service." 
+} diff --git a/tensorflow/core/api_def/base_api/api_def_ImageProjectiveTransformV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_ImageProjectiveTransformV2.pbtxt index 73d548b226d..a9d5b981576 100644 --- a/tensorflow/core/api_def/base_api/api_def_ImageProjectiveTransformV2.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_ImageProjectiveTransformV2.pbtxt @@ -38,6 +38,12 @@ END name: "interpolation" description: <