diff --git a/RELEASE.md b/RELEASE.md index 6340db05345..0e3eb0e0271 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -61,6 +61,7 @@ * Added support for saved model's session initializer through `TFLiteConverter.from_saved_model`. * Added dynamic range quantization support for the BatchMatMul op. + * Added DEPTH_TO_SPACE support in Post training quantization. * Add `RFFT2D` as builtin op. (`RFFT2D` also supports `RFFTD`.) Currently only supports float32 input. * TFLite Supports SingatureDef: diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index 99a278a14a4..4bfc79d9939 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -145,6 +145,8 @@ if _running_from_pip_package(): _plugin_dir = _os.path.join(_s, 'tensorflow-plugins') if _os.path.exists(_plugin_dir): _ll.load_library(_plugin_dir) + # Load Pluggable Device Library + _ll.load_pluggable_device_library(_plugin_dir) # Add module aliases if hasattr(_current_module, 'keras'): diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py index ae82f7b4792..e69287afd46 100644 --- a/tensorflow/api_template_v1.__init__.py +++ b/tensorflow/api_template_v1.__init__.py @@ -155,6 +155,8 @@ if _running_from_pip_package(): _plugin_dir = _os.path.join(_s, 'tensorflow-plugins') if _os.path.exists(_plugin_dir): _ll.load_library(_plugin_dir) + # Load Pluggable Device Library + _ll.load_pluggable_device_library(_plugin_dir) # Delete modules that should be hidden from dir(). # Don't fail if these modules are not available. diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 2ce9e9aeb87..b082dd81d64 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -684,7 +684,10 @@ tf_cc_test( name = "c_api_experimental_test", size = "medium", srcs = ["c_api_experimental_test.cc"], - data = ["testdata/tf_record"], + data = [ + "testdata/tf_record", + "//tensorflow/c/experimental/stream_executor/test:test_pluggable_device.so", + ], linkopts = select({ "//tensorflow:macos": ["-headerpad_max_install_names"], "//conditions:default": [], @@ -704,6 +707,7 @@ tf_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/platform:resource_loader", "@com_google_absl//absl/types:optional", ], ) diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 0d188aa5ee0..e9734427bb0 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -37,7 +37,9 @@ limitations under the License. #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/platform/blocking_counter.h" #include "tensorflow/core/platform/casts.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/net.h" #include "tensorflow/core/platform/platform.h" #include "tensorflow/core/platform/strcat.h" @@ -630,6 +632,9 @@ void TF_DeleteShapeAndTypeListArray(TF_ShapeAndTypeList** shape_list_array, namespace tensorflow { Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); + +// Helpers for loadding a TensorFlow PluggableDevice plugin (a .so file). 
+Status LoadPluggableDeviceLibrary(const char* library_filename, void** result); } // namespace tensorflow void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes, @@ -743,3 +748,45 @@ void TF_ImportGraphDefOptionsSetValidateColocationConstraints( TF_ImportGraphDefOptions* opts, unsigned char enable) { opts->opts.validate_colocation_constraints = enable; } + +// Load a Pluggable Device library. +// On success, returns the handle to library in result and return OK from the +// function. Otherwise return nullptr in result and error Status from the +// function. +// +// If `library_filename` has already been loaded, we return a cached handle. +// Device and Kernels/Ops are registered as globals when a library is loaded +// for the first time. +TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename, + TF_Status* status) { +#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) + status->status = tensorflow::errors::Unimplemented( + "PluggableDevice plugin functionality is not supported on mobile"); + return nullptr; +#else + TF_Library* lib_handle = new TF_Library; + static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); + static std::unordered_map* loaded_libs = + new std::unordered_map(); + tensorflow::Env* env = tensorflow::Env::Default(); + { + tensorflow::mutex_lock lock(mu); + auto it = loaded_libs->find(library_filename); + if (it != loaded_libs->end()) { + lib_handle->lib_handle = it->second; + } else { + status->status = + env->LoadDynamicLibrary(library_filename, &lib_handle->lib_handle); + if (!status->status.ok()) { + delete lib_handle; + return nullptr; + } + } + return lib_handle; + } +#endif +} + +void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) { + delete lib_handle; +} diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index e877c775387..d4132153641 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -304,6 +304,27 @@ TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetValidateColocationConstraints( TF_ImportGraphDefOptions* opts, unsigned char enable); +// Load the library specified by library_filename and register the pluggable +// device and related kernels present in that library. This function is not +// supported on embedded on mobile and embedded platforms and will fail if +// called. +// +// Pass "library_filename" to a platform-specific mechanism for dynamically +// loading a library. The rules for determining the exact location of the +// library are platform-specific and are not documented here. +// +// On success, returns the newly created library handle and places OK in status. +// The caller owns the library handle. +// +// On failure, returns nullptr and places an error status in status. +TF_CAPI_EXPORT extern TF_Library* TF_LoadPluggableDeviceLibrary( + const char* library_filename, TF_Status* status); + +// Frees the memory associated with the library handle. +// Does NOT unload the library. +TF_CAPI_EXPORT extern void TF_DeletePluggableDeviceLibraryHandle( + TF_Library* lib_handle); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc index cfeba345f81..4c319d0a798 100644 --- a/tensorflow/c/c_api_experimental_test.cc +++ b/tensorflow/c/c_api_experimental_test.cc @@ -23,6 +23,7 @@ limitations under the License. 
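// A minimal end-to-end usage sketch for the C API introduced above (not part
// of this change; the wrapper function name and plugin path are hypothetical):
void LoadMyPlugin() {
  TF_Status* status = TF_NewStatus();
  TF_Library* lib =
      TF_LoadPluggableDeviceLibrary("/tmp/my_pluggable_device.so", status);
  bool ok = TF_GetCode(status) == TF_OK;  // On failure `lib` is nullptr and
                                          // TF_Message(status) has details.
  TF_DeleteStatus(status);
  if (ok) {
    // Frees only the handle; the library (and the devices/kernels it
    // registered) stays loaded for the lifetime of the process.
    TF_DeletePluggableDeviceLibraryHandle(lib);
  }
}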
#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/resource_loader.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h" @@ -234,5 +235,22 @@ TEST_F(ShapeInferenceTest, InfersShapesFromInputTensors) { TF_DeleteTensor(tensor_1X6); } +TEST(CAPI_EXPERIMENTAL, LibraryPluggableDeviceLoadFunctions) { +#if !defined(TENSORFLOW_NO_SHARED_OBJECTS) + // Load the library. + TF_Status* status = TF_NewStatus(); + string lib_path = + tensorflow::GetDataDependencyFilepath(tensorflow::io::JoinPath( + "tensorflow", "c", "experimental", "stream_executor", "test", + "test_pluggable_device.so")); + TF_Library* lib = TF_LoadPluggableDeviceLibrary(lib_path.c_str(), status); + TF_Code code = TF_GetCode(status); + string status_msg(TF_Message(status)); + TF_DeleteStatus(status); + ASSERT_EQ(TF_OK, code) << status_msg; + TF_DeletePluggableDeviceLibraryHandle(lib); +#endif // !defined(TENSORFLOW_NO_SHARED_OBJECTS) +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index 3bae78dde5f..d37cb963dfb 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -213,7 +213,11 @@ TF_Function* TF_GraphToFunctionWithControlOutputs( TF_DeleteFunction(tf_function); return nullptr; } - tf_function->graph_with_debug_info = &fn_body->graph; + + for (const Node* n : fn_body->graph.nodes()) { + tf_function->stack_traces[n->name()] = n->GetStackTrace(); + } + return tf_function; } diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index d45aa9a6356..b5ab775fef0 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -157,9 +157,7 @@ struct TF_DeviceList { struct TF_Function { tensorflow::FunctionDef fdef; - - // Graph with nodes with debug stack traces. - const tensorflow::Graph* graph_with_debug_info = nullptr; + tensorflow::StackTracesMap stack_traces; }; struct TF_ApiDefMap { diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 7738558c43f..c768c776a49 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -749,8 +749,8 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx, void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, TF_Status* status) { - status->status = tensorflow::unwrap(ctx)->AddFunctionDefWithDebugInfo( - function->fdef, function->graph_with_debug_info); + status->status = tensorflow::unwrap(ctx)->AddFunctionDefWithStackTraces( + function->fdef, function->stack_traces); } void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name, diff --git a/tensorflow/c/eager/immediate_execution_context.h b/tensorflow/c/eager/immediate_execution_context.h index 88696fc7eec..e557753c49d 100644 --- a/tensorflow/c/eager/immediate_execution_context.h +++ b/tensorflow/c/eager/immediate_execution_context.h @@ -111,11 +111,11 @@ class ImmediateExecutionContext : public AbstractContext { // already exists. virtual Status AddFunctionDef(const FunctionDef& fdef) = 0; - // Same as `AddFunctionDef`, and additionally saves a pointer to the Graph - // which has nodes containing stack traces for the nodes in `fdef`. Assumes - // `graph` is alive while the function is alive. 
- virtual Status AddFunctionDefWithDebugInfo(const FunctionDef& fdef, - const Graph* graph) = 0; + // Same as `AddFunctionDef`, but additionally saves the `stack_traces` under + // the key of the function definition name (to be retrieved during function + // instantiation). + virtual Status AddFunctionDefWithStackTraces( + const FunctionDef& fdef, const StackTracesMap& stack_traces) = 0; // Find and return a added function by its name. virtual const FunctionDef* FindFunctionDef(const string& name) const = 0; diff --git a/tensorflow/c/experimental/stream_executor/test/BUILD b/tensorflow/c/experimental/stream_executor/test/BUILD new file mode 100644 index 00000000000..ca8bdaf641d --- /dev/null +++ b/tensorflow/c/experimental/stream_executor/test/BUILD @@ -0,0 +1,17 @@ +# Description: +# test for stream_executor +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_shared_object", +) + +package( + licenses = ["notice"], # Apache 2.0 +) + +tf_cc_shared_object( + name = "test_pluggable_device.so", + srcs = ["test_pluggable_device.cc"], + visibility = ["//tensorflow/c:__subpackages__"], + deps = ["//tensorflow/c/experimental/stream_executor:stream_executor_hdrs"], +) diff --git a/tensorflow/compiler/xla/python/bfloat16.h b/tensorflow/c/experimental/stream_executor/test/test_pluggable_device.cc similarity index 61% rename from tensorflow/compiler/xla/python/bfloat16.h rename to tensorflow/c/experimental/stream_executor/test/test_pluggable_device.cc index 9e52d086919..d985f3ca83b 100644 --- a/tensorflow/compiler/xla/python/bfloat16.h +++ b/tensorflow/c/experimental/stream_executor/test/test_pluggable_device.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,16 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_BFLOAT16_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_BFLOAT16_H_ +#include "tensorflow/c/experimental/stream_executor/stream_executor.h" -#include "pybind11/pybind11.h" -#include "tensorflow/compiler/xla/statusor.h" - -namespace xla { - -xla::StatusOr Bfloat16Dtype(); - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_BFLOAT16_H_ +void SE_InitPlugin(SE_PlatformRegistrationParams* const params, + TF_Status* const status) { + params->platform->struct_size = SP_PLATFORM_STRUCT_SIZE; + params->platform->name = "GPU"; + params->platform->type = "XGPU"; +} diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc index 27f98be14ad..d33a91b5898 100644 --- a/tensorflow/c/kernels.cc +++ b/tensorflow/c/kernels.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/stream_executor/stream.h" #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) +using tensorflow::errors::InvalidArgument; // This file forms the basis of a stable ABI for third-party kernel // implementations. It is crucial that changes to this file are made cautiously // and with a focus on maintaining both source and binary compatibility. 
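// Illustrative sketch of how a kernel written against this C ABI might consume
// the attribute getters added in the hunks below (not part of this change;
// `MyKernel` and the "strides" attribute are hypothetical, error handling and
// includes are elided):
void* MyKernelCreate(TF_OpKernelConstruction* ctx) {
  TF_Status* s = TF_NewStatus();
  int32_t list_size = 0;
  int32_t total_size = 0;
  // Query how many elements the list attribute holds ...
  TF_OpKernelConstruction_GetAttrSize(ctx, "strides", &list_size, &total_size,
                                      s);
  // ... then read them into an appropriately sized buffer.
  std::vector<int32_t> strides(list_size > 0 ? list_size : 0);
  TF_OpKernelConstruction_GetAttrInt32List(ctx, "strides", strides.data(),
                                           list_size, s);
  TF_DeleteStatus(s);
  return new MyKernel(std::move(strides));
}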
@@ -87,9 +88,25 @@ void AddTypeConstraint(TF_KernelBuilder* kernel_builder, const char* attr_name, TF_SetStatus(status, TF_OK, ""); } #undef CASE + } // namespace } // namespace tensorflow +namespace { +const tensorflow::AttrValue* GetAttrValue(TF_OpKernelConstruction* ctx, + const char* attr_name, + TF_Status* status) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); + const tensorflow::AttrValue* attr = + ::tensorflow::AttrSlice(cc_ctx->def()).Find(attr_name); + if (attr == nullptr) { + status->status = InvalidArgument("Operation '", cc_ctx->def().name(), + "' has no attr named '", attr_name, "'."); + } + return attr; +} +} // namespace + void TF_KernelBuilder_TypeConstraint(TF_KernelBuilder* kernel_builder, const char* attr_name, const TF_DataType type, @@ -257,7 +274,81 @@ void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) { cc_ctx->CtxFailure(s); } -#define DEFINE_TF_GETATTR(func, c_type, cc_type) \ +void TF_OpKernelConstruction_GetAttrSize(TF_OpKernelConstruction* ctx, + const char* attr_name, + int32_t* list_size, + int32_t* total_size, + TF_Status* status) { + const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status); + if (!status->status.ok()) { + *list_size = -1; + *total_size = -1; + return; + } + switch (attr->value_case()) { +#define SINGLE_CASE(kK, attr_type, size_expr) \ + case tensorflow::AttrValue::kK: \ + *list_size = -1; \ + *total_size = size_expr; \ + break; + + SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length()); + SINGLE_CASE(kI, TF_ATTR_INT, -1); + SINGLE_CASE(kF, TF_ATTR_FLOAT, -1); + SINGLE_CASE(kB, TF_ATTR_BOOL, -1); + SINGLE_CASE(kType, TF_ATTR_TYPE, -1); + SINGLE_CASE(kShape, TF_ATTR_SHAPE, + attr->shape().unknown_rank() ? -1 : attr->shape().dim_size()); + SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1); +#undef SINGLE_CASE + + case tensorflow::AttrValue::kList: + *list_size = 0; + *total_size = -1; +#define LIST_CASE(field, attr_type, ...) \ + if (attr->list().field##_size() > 0) { \ + *list_size = attr->list().field##_size(); \ + __VA_ARGS__; \ + break; \ + } + + LIST_CASE( + s, TF_ATTR_STRING, *total_size = 0; + for (int i = 0; i < attr->list().s_size(); + ++i) { *total_size += attr->list().s(i).size(); }); + LIST_CASE(i, TF_ATTR_INT); + LIST_CASE(f, TF_ATTR_FLOAT); + LIST_CASE(b, TF_ATTR_BOOL); + LIST_CASE(type, TF_ATTR_TYPE); + LIST_CASE( + shape, TF_ATTR_SHAPE, *total_size = 0; + for (int i = 0; i < attr->list().shape_size(); ++i) { + const auto& s = attr->list().shape(i); + *total_size += s.unknown_rank() ? 
0 : s.dim_size(); + }); + LIST_CASE(tensor, TF_ATTR_TENSOR); + LIST_CASE(tensor, TF_ATTR_FUNC); +#undef LIST_CASE + break; + + case tensorflow::AttrValue::kPlaceholder: + *list_size = -1; + *total_size = -1; + break; + + case tensorflow::AttrValue::kFunc: + *list_size = -1; + *total_size = -1; + break; + + case tensorflow::AttrValue::VALUE_NOT_SET: + status->status = + InvalidArgument("Attribute '", attr_name, "' has no value set"); + break; + } +} + +#define DEFINE_TF_GETATTR(func, c_type, cc_type, attr_type, list_field) \ void TF_OpKernelConstruction_GetAttr##func(TF_OpKernelConstruction* ctx, \ const char* attr_name, \ c_type* val, TF_Status* status) { \ @@ -269,10 +360,84 @@ void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) { if (s.ok()) { \ *val = static_cast(v); \ } \ + } \ + void TF_OpKernelConstruction_GetAttr##func##List( \ + TF_OpKernelConstruction* ctx, const char* attr_name, c_type* vals, \ + int max_vals, TF_Status* status) { \ + TF_SetStatus(status, TF_OK, ""); \ + const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status); \ + if (!status->status.ok()) return; \ + if (attr->value_case() != tensorflow::AttrValue::kList) { \ + status->status = \ + InvalidArgument("Value for '", attr_name, "' is not a list."); \ + return; \ + } \ + status->status = \ + tensorflow::AttrValueHasType(*attr, "list(" attr_type ")"); \ + if (!status->status.ok()) return; \ + const auto len = std::min(max_vals, attr->list().list_field##_size()); \ + for (int i = 0; i < len; ++i) { \ + vals[i] = static_cast(attr->list().list_field(i)); \ + } \ } -DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType) -DEFINE_TF_GETATTR(Int32, tensorflow::int32, int32_t) +DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType, "type", type) +DEFINE_TF_GETATTR(Int32, int32_t, tensorflow::int32, "int", i) +DEFINE_TF_GETATTR(Int64, int64_t, tensorflow::int64, "int", i) +DEFINE_TF_GETATTR(Float, float, float, "float", f) +DEFINE_TF_GETATTR(Bool, TF_Bool, bool, "bool", b) + +void TF_OpKernelConstruction_GetAttrString(TF_OpKernelConstruction* ctx, + const char* attr_name, char* value, + size_t max_length, + TF_Status* status) { + std::string v; + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); + ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v); + ::tensorflow::Set_TF_Status_from_Status(status, s); + + if (!status->status.ok()) return; + + if (max_length <= 0) { + return; + } + std::memcpy(value, v.data(), std::min(v.length(), max_length)); +} + +void TF_OpKernelConstruction_GetAttrStringList(TF_OpKernelConstruction* ctx, + const char* attr_name, + char** values, size_t* lengths, + int max_values, void* storage, + size_t storage_size, + TF_Status* status) { + std::vector v; + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); + ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v); + ::tensorflow::Set_TF_Status_from_Status(status, s); + + if (!status->status.ok()) return; + + const auto len = std::min(max_values, static_cast(v.size())); + char* p = static_cast(storage); + for (int i = 0; i < len; ++i) { + const std::string& s = v[i]; + values[i] = p; + lengths[i] = s.size(); + if ((p + s.size()) > (static_cast(storage) + storage_size)) { + status->status = InvalidArgument( + "Not enough storage to hold the requested list of strings"); + return; + } + memcpy(values[i], s.data(), s.size()); + p += s.size(); + } +} + +bool TF_OpKernelConstruction_HasAttr(TF_OpKernelConstruction* ctx, + const char* attr_name, TF_Status* 
status) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); + return cc_ctx->HasAttr(attr_name); +} TF_StringView TF_OpKernelConstruction_GetName(TF_OpKernelConstruction* ctx) { auto* cc_ctx = reinterpret_cast(ctx); diff --git a/tensorflow/c/kernels.h b/tensorflow/c/kernels.h index 34848a1c92a..508d59b1223 100644 --- a/tensorflow/c/kernels.h +++ b/tensorflow/c/kernels.h @@ -184,6 +184,24 @@ TF_CAPI_EXPORT extern TF_DataType TF_ExpectedOutputDataType( // Returns the step ID of the given context. TF_CAPI_EXPORT extern int64_t TF_StepId(TF_OpKernelContext* ctx); +// Get the list_size and total_size of the attribute `attr_name` of `oper`. +// list_size - the length of the list. +// total_size - total size of the list. +// (1) If attr_type == TF_ATTR_STRING +// then total_size is the cumulative byte size +// of all the strings in the list. +// (3) If attr_type == TF_ATTR_SHAPE +// then total_size is the number of dimensions +// of the shape valued attribute, or -1 +// if its rank is unknown. +// (4) If attr_type == TF_ATTR_SHAPE +// then total_size is the cumulative number +// of dimensions of all shapes in the list. +// (5) Otherwise, total_size is undefined. +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrSize( + TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* list_size, + int32_t* total_size, TF_Status* status); + // Interprets the named kernel construction attribute as a TF_DataType and // places it into *val. *status is set to TF_OK. // @@ -202,6 +220,112 @@ TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt32( TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* val, TF_Status* status); +// Interprets the named kernel construction attribute as int64_t and +// places it into *val. *status is set to TF_OK. +// +// If the attribute could not be found or could not be interpreted as +// int64, *status is populated with an error. +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt64( + TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* val, + TF_Status* status); + +// Interprets the named kernel construction attribute as float and +// places it into *val. *status is set to TF_OK. +// +// If the attribute could not be found or could not be interpreted as +// float, *status is populated with an error. +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrFloat( + TF_OpKernelConstruction* ctx, const char* attr_name, float* val, + TF_Status* status); + +// Interprets the named kernel construction attribute as bool and +// places it into *val. *status is set to TF_OK. +// +// If the attribute could not be found or could not be interpreted as +// bool, *status is populated with an error. +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrBool( + TF_OpKernelConstruction* ctx, const char* attr_name, TF_Bool* val, + TF_Status* status); + +// Interprets the named kernel construction attribute as string and +// places it into *val. `val` must +// point to an array of length at least `max_length` (ideally set to +// total_size from TF_OpKernelConstruction_GetAttrSize(ctx, +// attr_name, list_size, total_size)). *status is set to TF_OK. +// +// If the attribute could not be found or could not be interpreted as +// string, *status is populated with an error. 
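// For example (a sketch; the "padding" attribute name is illustrative):
//
//   int32_t list_size, total_size;
//   TF_OpKernelConstruction_GetAttrSize(ctx, "padding", &list_size,
//                                       &total_size, status);
//   std::vector<char> val(total_size);
//   TF_OpKernelConstruction_GetAttrString(ctx, "padding", val.data(),
//                                         val.size(), status);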
+TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrString( + TF_OpKernelConstruction* ctx, const char* attr_name, char* val, + size_t max_length, TF_Status* status); + +// Interprets the named kernel construction attribute as a TF_DataType array and +// places it into *vals. *status is set to TF_OK. +// `vals` must point to an array of length at least `max_values` (ideally set +// to list_size from +// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size, +// total_size)). +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrTypeList( + TF_OpKernelConstruction* ctx, const char* attr_name, TF_DataType* vals, + int max_vals, TF_Status* status); + +// Interprets the named kernel construction attribute as int32_t array and +// places it into *vals. *status is set to TF_OK. +// `vals` must point to an array of length at least `max_values` (ideally set +// to list_size from +// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size, +// total_size)). +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt32List( + TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* vals, + int max_vals, TF_Status* status); + +// Interprets the named kernel construction attribute as int64_t array and +// places it into *vals. *status is set to TF_OK. +// `vals` must point to an array of length at least `max_values` (ideally set +// to list_size from +// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size, +// total_size)). +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt64List( + TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* vals, + int max_vals, TF_Status* status); + +// Interprets the named kernel construction attribute as float array and +// places it into *vals. *status is set to TF_OK. +// `vals` must point to an array of length at least `max_values` (ideally set +// to list_size from +// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size, +// total_size)). +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrFloatList( + TF_OpKernelConstruction* ctx, const char* attr_name, float* vals, + int max_vals, TF_Status* status); + +// Interprets the named kernel construction attribute as bool array and +// places it into *vals. *status is set to TF_OK. +// `vals` must point to an array of length at least `max_values` (ideally set +// to list_size from +// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size, +// total_size)). +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrBoolList( + TF_OpKernelConstruction* ctx, const char* attr_name, TF_Bool* vals, + int max_vals, TF_Status* status); + +// Interprets the named kernel construction attribute as string array and fills +// in `vals` and `lengths`, each of which must point to an array of length at +// least `max_values`. *status is set to TF_OK. The elements of values will +// point to addresses in `storage` which must be at least `storage_size` bytes +// in length. Ideally, max_values would be set to list_size and `storage` would +// be at least total_size, obtained from +// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size, +// total_size). 
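// For example, sizing both the pointer table and the backing storage from the
// size query (a sketch; the "components" attribute name is illustrative):
//
//   int32_t list_size, total_size;
//   TF_OpKernelConstruction_GetAttrSize(ctx, "components", &list_size,
//                                       &total_size, status);
//   std::vector<char*> values(list_size);
//   std::vector<size_t> lengths(list_size);
//   std::vector<char> storage(total_size);
//   TF_OpKernelConstruction_GetAttrStringList(ctx, "components",
//                                             values.data(), lengths.data(),
//                                             list_size, storage.data(),
//                                             storage.size(), status);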
+TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrStringList( + TF_OpKernelConstruction* ctx, const char* attr_name, char** vals, + size_t* lengths, int max_values, void* storage, size_t storage_size, + TF_Status* status); + +// Return true if the kernel construction has the attr_name +TF_CAPI_EXPORT extern bool TF_OpKernelConstruction_HasAttr( + TF_OpKernelConstruction* ctx, const char* attr_name, TF_Status* status); + // Returns the unique operation name for this OpKernel. TF_CAPI_EXPORT extern TF_StringView TF_OpKernelConstruction_GetName( TF_OpKernelConstruction* ctx); diff --git a/tensorflow/c/kernels_test.cc b/tensorflow/c/kernels_test.cc index 49a168af076..4716bc3da55 100644 --- a/tensorflow/c/kernels_test.cc +++ b/tensorflow/c/kernels_test.cc @@ -161,6 +161,337 @@ TEST(TestKernel, TestRegisterKernelBuilder) { ASSERT_TRUE(delete_called); } +// REGISTER_OP for TF_OpKernelConstruction_GetAttr* test cases. +// Registers two ops, each with a single attribute called 'Attr'. +// The attribute in one op will have a type 'type', the other +// will have list(type). +#define ATTR_TEST_REGISTER_OP(name, type) \ + REGISTER_OP("TestKernelAttr" #name) \ + .Attr("Attr: " #type) \ + .SetShapeFn(tensorflow::shape_inference::UnknownShape); \ + REGISTER_OP("TestKernelAttr" #name "List") \ + .Attr("Attr: list(" #type ")") \ + .SetShapeFn(tensorflow::shape_inference::UnknownShape) +ATTR_TEST_REGISTER_OP(String, string); +ATTR_TEST_REGISTER_OP(Int, int); +ATTR_TEST_REGISTER_OP(Float, float); +ATTR_TEST_REGISTER_OP(Bool, bool); +ATTR_TEST_REGISTER_OP(Type, type); +#undef ATTR_TEST_REGISTER_OP + +// Helper macros for the TF_OpKernelConstruction_GetAttr* tests. +#define EXPECT_TF_SIZE(attr_name, expected_list_size, expected_total_size) \ + do { \ + int32_t list_size, total_size; \ + TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, &list_size, \ + &total_size, status); \ + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); \ + EXPECT_EQ(expected_list_size, list_size); \ + EXPECT_EQ(expected_total_size, total_size); \ + } while (0) + +typedef void* (*MyCreateFuncWithAttr)(TF_OpKernelConstruction*); +class TestKernelAttr : public ::testing::Test { + public: + TestKernelAttr() {} + ~TestKernelAttr() {} + + std::unique_ptr GetFakeKernelWithAttr(const char* op_name, + AttrValue v, Status* status) { + NodeDef def; + def.set_op(op_name); + def.set_name("FakeNode"); + def.set_device("FakeDevice"); + (*def.mutable_attr())["Attr"] = v; + return CreateOpKernel(DeviceType("FakeDevice"), nullptr, nullptr, def, 1, + status); + } + + void SetAttr(MyCreateFuncWithAttr MyCreateFuncAttr, const char* op_name, + AttrValue& v) { + TF_KernelBuilder* builder = TF_NewKernelBuilder( + op_name, "FakeDevice", MyCreateFuncAttr, &MyComputeFunc, &MyDeleteFunc); + { + TF_Status* status = TF_NewStatus(); + TF_RegisterKernelBuilder("FakeNode", builder, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)); + TF_DeleteStatus(status); + } + Status status; + std::unique_ptr kernel = + GetFakeKernelWithAttr(op_name, v, &status); + TF_EXPECT_OK(status); + ASSERT_NE(nullptr, kernel.get()); + kernel->Compute(nullptr); + + ASSERT_TRUE(delete_called); + } +}; + +TEST_F(TestKernelAttr, String) { + auto my_create_func = [](TF_OpKernelConstruction* ctx) { + struct MyCustomKernel* s = new struct MyCustomKernel; + s->created = true; + s->compute_called = false; + + std::unique_ptr val(new char[5]); + TF_Status* status = TF_NewStatus(); + EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1, + /*expected_total_size*/ 5); + 
TF_OpKernelConstruction_GetAttrString(ctx, "Attr", val.get(), + /*max_length*/ 5, status); + + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ("bunny", string(static_cast(val.get()), 5)); + TF_DeleteStatus(status); + return static_cast(s); + }; + + AttrValue v; + v.set_s("bunny"); + SetAttr(my_create_func, "TestKernelAttrString", v); +} + +TEST_F(TestKernelAttr, StringList) { + auto my_create_func = [](TF_OpKernelConstruction* ctx) { + struct MyCustomKernel* s = new struct MyCustomKernel; + s->created = true; + s->compute_called = false; + + std::vector list = {"bugs", "bunny", "duck"}; + int list_total_size = 0; + for (const auto& s : list) { + list_total_size += s.size(); + } + + TF_Status* status = TF_NewStatus(); + std::unique_ptr values(new char*[list.size()]); + std::unique_ptr lens(new size_t[list.size()]); + std::unique_ptr storage(new char[list_total_size]); + EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list.size(), + /*expected_total_size*/ list_total_size); + TF_OpKernelConstruction_GetAttrStringList( + ctx, "Attr", values.get(), lens.get(), list.size(), storage.get(), + list_total_size, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + for (size_t i = 0; i < list.size(); ++i) { + EXPECT_EQ(list[i].size(), lens[i]) << i; + EXPECT_EQ(list[i], string(static_cast(values[i]), lens[i])) + << i; + } + TF_DeleteStatus(status); + return static_cast(s); + }; + + AttrValue v; + auto attr_in = gtl::ArraySlice({"bugs", "bunny", "duck"}); + SetAttrValue(attr_in, &v); + SetAttr(my_create_func, "TestKernelAttrStringList", v); +} + +TEST_F(TestKernelAttr, Int) { + auto my_create_func = [](TF_OpKernelConstruction* ctx) { + struct MyCustomKernel* s = new struct MyCustomKernel; + s->created = true; + s->compute_called = false; + + int64_t val; + TF_Status* status = TF_NewStatus(); + EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1, + /*expected_total_size*/ -1); + TF_OpKernelConstruction_GetAttrInt64(ctx, "Attr", &val, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ(1234, val); + TF_DeleteStatus(status); + return static_cast(s); + }; + + AttrValue v; + v.set_i(1234); + SetAttr(my_create_func, "TestKernelAttrInt", v); +} + +TEST_F(TestKernelAttr, IntList) { + auto my_create_func = [](TF_OpKernelConstruction* ctx) { + struct MyCustomKernel* s = new struct MyCustomKernel; + s->created = true; + s->compute_called = false; + + const int64_t list[] = {1, 2, 3, 4}; + const size_t list_size = TF_ARRAYSIZE(list); + int64_t values[list_size]; + + TF_Status* status = TF_NewStatus(); + EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size, + /*expected_total_size*/ -1); + TF_OpKernelConstruction_GetAttrInt64List(ctx, "Attr", values, list_size, + status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_TRUE( + std::equal(std::begin(list), std::end(list), std::begin(values))); + TF_DeleteStatus(status); + return static_cast(s); + }; + + AttrValue v; + auto attr_in = gtl::ArraySlice({1, 2, 3, 4}); + SetAttrValue(attr_in, &v); + SetAttr(my_create_func, "TestKernelAttrIntList", v); +} + +TEST_F(TestKernelAttr, Float) { + auto my_create_func = [](TF_OpKernelConstruction* ctx) { + struct MyCustomKernel* s = new struct MyCustomKernel; + s->created = true; + s->compute_called = false; + + float val; + TF_Status* status = TF_NewStatus(); + EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1, + /*expected_total_size*/ -1); + 
TF_OpKernelConstruction_GetAttrFloat(ctx, "Attr", &val, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_FLOAT_EQ(2.718, val); + TF_DeleteStatus(status); + return static_cast(s); + }; + + AttrValue v; + v.set_f(2.718); + SetAttr(my_create_func, "TestKernelAttrFloat", v); +} + +TEST_F(TestKernelAttr, FloatList) { + auto my_create_func = [](TF_OpKernelConstruction* ctx) { + struct MyCustomKernel* s = new struct MyCustomKernel; + s->created = true; + s->compute_called = false; + + const float list[] = {1.414, 2.718, 3.1415}; + const size_t list_size = TF_ARRAYSIZE(list); + float values[list_size]; + + TF_Status* status = TF_NewStatus(); + EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size, + /*expected_total_size*/ -1); + TF_OpKernelConstruction_GetAttrFloatList(ctx, "Attr", values, list_size, + status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_TRUE( + std::equal(std::begin(list), std::end(list), std::begin(values))); + TF_DeleteStatus(status); + return static_cast(s); + }; + + AttrValue v; + auto attr_in = gtl::ArraySlice({1.414, 2.718, 3.1415}); + SetAttrValue(attr_in, &v); + SetAttr(my_create_func, "TestKernelAttrFloatList", v); +} + +TEST_F(TestKernelAttr, Bool) { + auto my_create_func = [](TF_OpKernelConstruction* ctx) { + struct MyCustomKernel* s = new struct MyCustomKernel; + s->created = true; + s->compute_called = false; + + unsigned char val; + TF_Status* status = TF_NewStatus(); + EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1, + /*expected_total_size*/ -1); + TF_OpKernelConstruction_GetAttrBool(ctx, "Attr", &val, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ(1, val); + TF_DeleteStatus(status); + return static_cast(s); + }; + + AttrValue v; + v.set_b(1); + SetAttr(my_create_func, "TestKernelAttrBool", v); +} + +TEST_F(TestKernelAttr, BoolList) { + auto my_create_func = [](TF_OpKernelConstruction* ctx) { + struct MyCustomKernel* s = new struct MyCustomKernel; + s->created = true; + s->compute_called = false; + + const unsigned char list[] = {1, 0, 1, 0}; + const size_t list_size = TF_ARRAYSIZE(list); + unsigned char values[list_size]; + + TF_Status* status = TF_NewStatus(); + EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size, + /*expected_total_size*/ -1); + TF_OpKernelConstruction_GetAttrBoolList(ctx, "Attr", values, list_size, + status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_TRUE( + std::equal(std::begin(list), std::end(list), std::begin(values))); + TF_DeleteStatus(status); + return static_cast(s); + }; + + AttrValue v; + auto attr_in = gtl::ArraySlice({1, 0, 1, 0}); + SetAttrValue(attr_in, &v); + SetAttr(my_create_func, "TestKernelAttrBoolList", v); +} + +TEST_F(TestKernelAttr, Type) { + auto my_create_func = [](TF_OpKernelConstruction* ctx) { + struct MyCustomKernel* s = new struct MyCustomKernel; + s->created = true; + s->compute_called = false; + + TF_DataType val; + TF_Status* status = TF_NewStatus(); + EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1, + /*expected_total_size*/ -1); + TF_OpKernelConstruction_GetAttrType(ctx, "Attr", &val, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ(TF_FLOAT, val); + TF_DeleteStatus(status); + return static_cast(s); + }; + + AttrValue v; + v.set_type(DT_FLOAT); + SetAttr(my_create_func, "TestKernelAttrType", v); +} + +TEST_F(TestKernelAttr, TypeList) { + auto my_create_func = [](TF_OpKernelConstruction* ctx) 
{ + struct MyCustomKernel* s = new struct MyCustomKernel; + s->created = true; + s->compute_called = false; + + const TF_DataType list[] = {TF_FLOAT, TF_DOUBLE, TF_HALF, TF_COMPLEX128}; + const size_t list_size = TF_ARRAYSIZE(list); + TF_DataType values[list_size]; + + TF_Status* status = TF_NewStatus(); + EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size, + /*expected_total_size*/ -1); + TF_OpKernelConstruction_GetAttrTypeList(ctx, "Attr", values, list_size, + status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_TRUE( + std::equal(std::begin(list), std::end(list), std::begin(values))); + TF_DeleteStatus(status); + return static_cast(s); + }; + + AttrValue v; + auto attr_in = + gtl::ArraySlice({DT_FLOAT, DT_DOUBLE, DT_HALF, DT_COMPLEX128}); + SetAttrValue(attr_in, &v); + SetAttr(my_create_func, "TestKernelAttrTypeList", v); +} +#undef EXPECT_TF_SIZE + class DummyDevice : public DeviceBase { public: explicit DummyDevice(Env* env) : DeviceBase(env) {} diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h index 26f8afd51a3..60fff052175 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h @@ -65,6 +65,7 @@ MAP_HLO_TO_LHLO(MinOp); MAP_HLO_TO_LHLO(MulOp); MAP_HLO_TO_LHLO(NegOp); MAP_HLO_TO_LHLO(NotOp); +MAP_HLO_TO_LHLO(OrOp); MAP_HLO_TO_LHLO(RealOp); MAP_HLO_TO_LHLO(ReduceOp); MAP_HLO_TO_LHLO(ReshapeOp); @@ -81,6 +82,7 @@ MAP_HLO_TO_LHLO(SqrtOp); MAP_HLO_TO_LHLO(SubOp); MAP_HLO_TO_LHLO(TanhOp); MAP_HLO_TO_LHLO(TransposeOp); +MAP_HLO_TO_LHLO(XorOp); #undef MAP_HLO_TO_LHLO diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index cf008059810..92032e4815e 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -481,6 +481,15 @@ inline Value MapLhloOpToStdScalarOp(Location loc, return nullptr; } +template <> +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); +} + template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, @@ -580,6 +589,15 @@ inline Value MapLhloOpToStdScalarOp(Location loc, loc, result_types, args, b); } +template <> +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); +} + } // namespace impl struct HloOpToStdScalarOp { diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index ec3a192295b..22242122eb2 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -629,6 +629,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, + 
HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, @@ -644,6 +645,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, + HloToLhloOpConverter, HloToLhloReduceOpConverter, HloToLhloReturnOpConverter, HloToLhloTensorLoadOpConverter, diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 4b15de39f00..69e729c124d 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -927,12 +927,14 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, @@ -945,7 +947,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, - PointwiseToLinalgConverter, + PointwiseToLinalgConverter, ReduceConverter, ReshapeOpConverter, ReverseConverter, @@ -1042,12 +1044,14 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, @@ -1055,11 +1059,12 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, - PointwiseToLinalgConverter, + PointwiseToLinalgConverter, ReshapeOpConverter, ReverseConverter, TransposeConverter>(context); diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc index 1788b2819c7..5ba5fd9b22c 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc @@ -42,11 +42,11 @@ namespace { sep fn(SqrtOp) sep fn(TanhOp) // TODO(herhut): Generate these out of op definitions. 
-#define MAP_XLA_OPERATION_CWISE_BINARY(fn, sep) \ - fn(AddOp) sep fn(Atan2Op) sep fn(ComplexOp) sep fn(DivOp) sep fn(MaxOp) \ - sep fn(MinOp) sep fn(MulOp) sep fn(PowOp) sep fn(RemOp) \ - sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp) \ - sep fn(ShiftRightLogicalOp) sep fn(SubOp) +#define MAP_XLA_OPERATION_CWISE_BINARY(fn, sep) \ + fn(AddOp) sep fn(AndOp) sep fn(Atan2Op) sep fn(ComplexOp) sep fn(DivOp) \ + sep fn(MaxOp) sep fn(MinOp) sep fn(MulOp) sep fn(OrOp) sep fn(PowOp) \ + sep fn(RemOp) sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp) \ + sep fn(ShiftRightLogicalOp) sep fn(SubOp) sep fn(XorOp) // TODO(herhut): Generate these out of op definitions. #define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \ diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir index d7adf6ced5a..0c1ee243a04 100644 --- a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir @@ -316,6 +316,20 @@ func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { // ----- +// CHECK-LABEL: func @and +func @and(%operand0: memref<2x2xi32>, %operand1: memref<2x2xi32>, + %result: memref<2x2xi32>) { + %tensor_operand0 = tensor_load %operand0 : memref<2x2xi32> + %tensor_operand1 = tensor_load %operand1 : memref<2x2xi32> + %tensor_result = "mhlo.and"(%tensor_operand0, %tensor_operand1) + : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> + // CHECK: "lmhlo.and"(%{{.*}}, %{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref<2x2xi32> + return +} + +// ----- + // CHECK-LABEL: func @ceil func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> @@ -389,6 +403,20 @@ func @not(%operand: memref<2x2xi32>, %result: memref<2x2xi32>) { // ----- +// CHECK-LABEL: func @or +func @or(%operand0: memref<2x2xi32>, %operand1: memref<2x2xi32>, + %result: memref<2x2xi32>) { + %tensor_operand0 = tensor_load %operand0 : memref<2x2xi32> + %tensor_operand1 = tensor_load %operand1 : memref<2x2xi32> + %tensor_result = "mhlo.or"(%tensor_operand0, %tensor_operand1) + : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> + // CHECK: "lmhlo.or"(%{{.*}}, %{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref<2x2xi32> + return +} + +// ----- + // CHECK-LABEL: func @rsqrt func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> @@ -480,7 +508,8 @@ func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { // ----- // CHECK-LABEL: func @remainder -func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) { +func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, + %result: memref<2x2xf32>) { %tensor_lhs = tensor_load %lhs : memref<2x2xf32> %tensor_rhs = tensor_load %rhs : memref<2x2xf32> %tensor_result = "mhlo.remainder"(%tensor_lhs, %tensor_rhs) @@ -492,6 +521,20 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x // ----- +// CHECK-LABEL: func @xor +func @xor(%operand0: memref<2x2xi32>, %operand1: memref<2x2xi32>, + %result: memref<2x2xi32>) { + %tensor_operand0 = tensor_load %operand0 : memref<2x2xi32> + %tensor_operand1 = tensor_load %operand1 : memref<2x2xi32> + %tensor_result = "mhlo.xor"(%tensor_operand0, %tensor_operand1) + : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> + // CHECK: "lmhlo.xor"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result 
: memref<2x2xi32> + return +} + +// ----- + // Dynamic shape binary element-wise operation. // CHECK-LABEL: func @add_dyn func @add_dyn(%lhs: tensor, %rhs: tensor) { diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir index 074934576a4..53fd205e9af 100644 --- a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir @@ -194,6 +194,30 @@ func @integer_and(%lhs: tensor<2x2xi32>, // ----- +// CHECK-LABEL: func @integer_or +func @integer_or(%lhs: tensor<2x2xi32>, + %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { + // CHECK: linalg.generic + // CHECK: or + %0 = "mhlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, + tensor<2x2xi32>) -> tensor<2x2xi32> + return %0 : tensor<2x2xi32> +} + +// ----- + +// CHECK-LABEL: func @integer_xor +func @integer_xor(%lhs: tensor<2x2xi32>, + %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { + // CHECK: linalg.generic + // CHECK: xor + %0 = "mhlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, + tensor<2x2xi32>) -> tensor<2x2xi32> + return %0 : tensor<2x2xi32> +} + +// ----- + // CHECK-LABEL: func @float_cmp func @float_cmp(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> (tensor<2x2xi1>) { diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index a4339a1359e..0e0e2164014 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -509,7 +509,7 @@ Operation* BuildVariableOp(const tflite::TensorT& tensor, return op.getOperation(); } auto op = builder.create(loc, value); - if (!tensor.quantization->min.empty()) { + if (tensor.quantization && !tensor.quantization->min.empty()) { if (auto stats_op = ConvertMinMaxToStatsOp(tensor, builder, op.getResult())) { return stats_op; diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 61702c9ff47..a2c40453c91 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -1977,6 +1977,7 @@ cc_library( hdrs = ["utils/bridge_logger.h"], deps = [ ":dump_mlir_util", + "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 91f8d0ebd9c..90ff30c6653 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -5028,6 +5028,26 @@ Table initializer that takes two tensors for keys and values respectively. TF_DerivedOperandTypeAttr Tkey = TF_DerivedOperandTypeAttr<1>; } +def TF_InplaceAddOp : TF_Op<"InplaceAdd", [AllTypesMatch<["x", "y"]>, NoSideEffect]> { + let summary = "Adds v into specified rows of x."; + + let description = [{ +Computes y = x; y[i, :] += v; return y. + }]; + + let arguments = (ins + TF_Tensor:$x, + TF_Int32Tensor:$i, + TF_Tensor:$v + ); + + let results = (outs + TF_Tensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_InplaceUpdateOp : TF_Op<"InplaceUpdate", [NoSideEffect]> { let summary = "Updates specified rows 'i' with values 'v'."; @@ -5374,6 +5394,37 @@ def TF_IteratorV2Op : TF_Op<"IteratorV2", []> { ); } +def TF_KthOrderStatisticOp : TF_Op<"KthOrderStatistic", [NoSideEffect]> { + let summary = "Computes the Kth order statistic of a data set. 
The current"; + + let description = [{ +implementation uses a binary search requiring exactly 32 passes over +the input data. The running time is linear with respect to input +size. The median-of-medians algorithm is probably faster, but is +difficult to implement efficiently in XLA. The implementation imposes +a total ordering on floats. The ordering is consistent with the usual +partial order. Positive NaNs are greater than positive +infinity. Negative NaNs are less than negative infinity. NaNs with +distinct payloads are treated as distinct. Subnormal numbers are +preserved (not flushed to zero). Positive infinity is greater than all +numbers. Negative infinity is less than all numbers. Positive is +greater than negative zero. There are less than k values greater than +the kth order statistic. There are at least k values greater than or +equal to the Kth order statistic. The semantics are not the same as +top_k_unique. + }]; + + let arguments = (ins + TF_Float32Tensor:$input, + + I64Attr:$k + ); + + let results = (outs + TF_Float32Tensor:$output + ); +} + def TF_L2LossOp : TF_Op<"L2Loss", [NoSideEffect]> { let summary = "L2 Loss."; @@ -6505,6 +6556,27 @@ iterator in `iterator` to the first element of `dataset`. let results = (outs); } +def TF_MakeUniqueOp : TF_Op<"MakeUnique", [NoSideEffect]> { + let summary = [{ +Make all elements in the non-Batch dimension unique, but \"close\" to + }]; + + let description = [{ +their initial value. Never returns a sub-normal number. Never returns +zero. The sign of each input element is always identical to the sign +of the corresponding output element. Behavior for infinite elements is +undefined. Behavior for subnormal elements is undefined. + }]; + + let arguments = (ins + TF_Float32Tensor:$input + ); + + let results = (outs + TF_Float32Tensor:$output + ); +} + def TF_MatMulOp : TF_Op<"MatMul", [NoSideEffect, TF_SameOperandsAndResultElementTypeResolveRef]> { let summary = [{ Multiply the matrix "a" by the matrix "b". @@ -15234,6 +15306,36 @@ array([[1, 2, 3, 1, 2, 3], let hasFolder = 1; } +def TF_TopKUniqueOp : TF_Op<"TopKUnique", [NoSideEffect]> { + let summary = "Returns the TopK unique values in the array in sorted order."; + + let description = [{ +The running time is proportional to the product of K and the input +size. Sorting the whole array is more efficient for sufficiently large +values of K. The median-of-medians algorithm is probably faster, but +difficult to implement efficiently in XLA. If there are fewer than K +unique numbers (not NANs), the results are padded with negative +infinity. NaNs are never returned. Subnormal numbers are flushed to +zero. If an element appears at multiple indices, the highest index is +returned. If a TopK element never appears in the input due to padding +values, the indices are padded with negative one. If a padding value +appears in the input and padding is needed, the highest index of the +padding value will be returned. The semantics are not the same as +kth_order_statistic. + }]; + + let arguments = (ins + TF_Float32Tensor:$input, + + I64Attr:$k + ); + + let results = (outs + TF_Float32Tensor:$topk, + TF_Int32Tensor:$topk_indices + ); +} + def TF_TopKV2Op : TF_Op<"TopKV2", [NoSideEffect]> { let summary = [{ Finds values and indices of the `k` largest elements for the last dimension. @@ -15269,6 +15371,29 @@ If two elements are equal, the lower-index element appears first. 
let verifier = [{ return Verify(*this); }]; } +def TF_TopKWithUniqueOp : TF_Op<"TopKWithUnique", [NoSideEffect]> { + let summary = "Returns the TopK values in the array in sorted order."; + + let description = [{ +This is a combination of MakeUnique and TopKUnique. The returned top-K will +have its lower bits replaced by iota, thus it will be close to the original +value but not exactly the same. The running time is proportional to the product +of K and the input size. NaNs are never returned. Subnormal numbers are flushed +to zero. + }]; + + let arguments = (ins + TF_Float32Tensor:$input, + + I64Attr:$k + ); + + let results = (outs + TF_Float32Tensor:$topk, + TF_Int32Tensor:$topk_indices + ); +} + def TF_TransposeOp : TF_Op<"Transpose", [NoSideEffect]> { let summary = "Shuffle dimensions of x according to a permutation."; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc index 95858b366fd..096b822afbc 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc @@ -532,7 +532,11 @@ OpFoldResult RankOp::fold(ArrayRef operands) { auto ranked_type = type.dyn_cast(); if (!ranked_type) return {}; - auto output_type = getType().cast(); + // DenseIntElementsAttr::get requires the output type be ranked with static + // shape. + auto output_type = getType().dyn_cast(); + if (!output_type || !output_type.hasStaticShape()) return {}; + int32_t rank = ranked_type.getRank(); return DenseIntElementsAttr::get(output_type, rank); } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index 5f3e1e1759b..9b7993d97d7 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -904,6 +904,20 @@ func @testRankOfRankedTensor(%arg0 : tensor<4x3x2xf32>) -> tensor { return %0 : tensor } +// CHECK-LABEL: testRankOfRankedTensorUnrankedOutput +func @testRankOfRankedTensorUnrankedOutput(%arg0 : tensor<4x3x2xf32>) -> tensor<*xi32> { + // Regression test to make sure we don't crash in this case. + %0 = "tf.Rank"(%arg0) : (tensor<4x3x2xf32>) -> tensor<*xi32> + return %0 : tensor<*xi32> +} + +// CHECK-LABEL: testRankOfRankedTensorDynamicShapeOutput +func @testRankOfRankedTensorDynamicShapeOutput(%arg0 : tensor<4x3x2xf32>) -> tensor { + // Regression test to make sure we don't crash in this case. + %0 = "tf.Rank"(%arg0) : (tensor<4x3x2xf32>) -> tensor + return %0 : tensor +} + // CHECK-LABEL: @foldFill func @foldFill() -> (tensor<3x2x1xf32>, tensor<*xf32>, tensor<*xcomplex>) { %0 = "tf.Const"() {value = dense<[3, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32> diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index 96c2927c6cc..fbdf0dbf9b7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -1627,8 +1627,10 @@ LogicalResult InferModuleShape(ModuleOp module, int64_t max_iterations) { return success(); } int64_t producer = producer_or.ValueOrDie(); + // TODO(jpienaar): Clean up propagate_caller_callee_constants if it is no + // longer needed. 
ShapeInference context(producer, module.getContext(), - /*propagate_caller_callee_constants=*/true); + /*propagate_caller_callee_constants=*/false); if (auto main = module.lookupSymbol("main")) context.enqueue(main); for (auto func : module.getOps()) context.enqueue(func); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc index d7b511094d3..3565f58ea17 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h" +#include + +#include "absl/strings/str_split.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/IR/Operation.h" // from @llvm-project @@ -23,17 +26,30 @@ limitations under the License. namespace tensorflow { +// Counter is used as a prefix for filenames. +static std::atomic log_counter(0); + BridgeLoggerConfig::BridgeLoggerConfig(bool print_module_scope, bool print_after_only_on_change) : mlir::PassManager::IRPrinterConfig(print_module_scope, - print_after_only_on_change) {} + print_after_only_on_change) { + const char* log_pass_patterns = getenv("MLIR_BRIDGE_LOG_PASS_PATTERNS"); + if (log_pass_patterns) { + log_pass_patterns_ = + absl::StrSplit(log_pass_patterns, ',', absl::SkipWhitespace()); + } +} -// Logs op to file with name of format `mlir_bridge-pass_name-file_suffix.mlir`. +// Logs op to file with name of format +// `_mlir_bridge__.mlir`. inline static void Log(BridgeLoggerConfig::PrintCallbackFn print_callback, mlir::Pass* pass, mlir::Operation* op, llvm::StringRef file_suffix) { - std::string name = - llvm::formatv("mlir_bridge_{0}_{1}", pass->getName(), file_suffix).str(); + std::string pass_name = pass->getName().str(); + + // Add 4-digit counter as prefix so the order of the passes is obvious. 
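// For example, with the counter at 12 and a pass named "ShapeInference"
// (illustrative), the "before" dump would be named roughly
// `0012_mlir_bridge_ShapeInference_before.mlir`.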
+ std::string name = llvm::formatv("{0,0+4}_mlir_bridge_{1}_{2}", log_counter++, + pass_name, file_suffix); std::unique_ptr os; std::string filepath; @@ -44,13 +60,30 @@ inline static void Log(BridgeLoggerConfig::PrintCallbackFn print_callback, void BridgeLoggerConfig::printBeforeIfEnabled(mlir::Pass* pass, mlir::Operation* operation, PrintCallbackFn print_callback) { - Log(print_callback, pass, operation, "before"); + if (should_print(pass)) Log(print_callback, pass, operation, "before"); } void BridgeLoggerConfig::printAfterIfEnabled(mlir::Pass* pass, mlir::Operation* operation, PrintCallbackFn print_callback) { - Log(print_callback, pass, operation, "after"); + if (should_print(pass)) Log(print_callback, pass, operation, "after"); +} + +bool BridgeLoggerConfig::should_print(mlir::Pass* pass) { + if (log_pass_patterns_.empty()) return true; + + std::string pass_name = pass->getName().str(); + for (const auto& pattern : log_pass_patterns_) { + if (pass_name.find(pattern) != std::string::npos) { + // pattern matches pass + return true; + } + } + // no pattern matches pass + VLOG(2) << "Not logging pass " << pass_name + << " because it does not match any pattern in " + "MLIR_BRIDGE_LOG_PASS_PATTERNS"; + return false; } void BridgeTimingConfig::printTiming(PrintCallbackFn printCallback) { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h index eaf3a7c2598..c7cd22bd479 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h @@ -23,7 +23,11 @@ limitations under the License. namespace tensorflow { // Logger for logging/dumping MLIR modules before and after passes in bridge -// targeting TPUs. +// targeting TPUs. The passes being logged can be restricted via environment +// variable `MLIR_BRIDGE_LOG_PASS_PATTERNS` which is interpreted as a comma- +// separated list of strings, and only passes whose name contains any of those +// strings as a substring are logged (no regex support). If +// `MLIR_BRIDGE_LOG_PASS_PATTERNS` is not defined, then all passes are logged. class BridgeLoggerConfig : public mlir::PassManager::IRPrinterConfig { public: explicit BridgeLoggerConfig(bool print_module_scope = false, @@ -42,6 +46,14 @@ class BridgeLoggerConfig : public mlir::PassManager::IRPrinterConfig { // with the stream to dump into. void printAfterIfEnabled(mlir::Pass *pass, mlir::Operation *operation, PrintCallbackFn print_callback) override; + + private: + bool should_print(mlir::Pass *pass); + + // Only print passes that match any of these patterns. A pass matches a + // pattern if its name contains the pattern as a substring. If + // `log_pass_patterns_` is empty, print all passes. + std::vector log_pass_patterns_; }; // Logger for logging/dumping pass pipeline timings after completion. diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index 36aa31b0d34..9572e78efee 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -516,6 +516,11 @@ class ConvertToHloModule { // // TODO(hinsu): Check for dynamic shapes and exit instead of crashing. 
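The pass filtering added above is driven entirely by the MLIR_BRIDGE_LOG_PASS_PATTERNS environment variable. A minimal usage sketch in C++, assuming the config is attached through mlir::PassManager::enableIRPrinting and using illustrative pattern strings:

#include <cstdlib>
#include <memory>

#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"

// Restrict bridge dumps to passes whose names contain one of the given
// substrings (comma-separated; whitespace-only entries are skipped).
void EnableFilteredBridgeLogging(mlir::PassManager& pm) {
  setenv("MLIR_BRIDGE_LOG_PASS_PATTERNS", "TPURewrite,ShapeInference",
         /*overwrite=*/1);
  // BridgeLoggerConfig reads the variable in its constructor, so it must be
  // created after the variable has been set.
  pm.enableIRPrinting(std::make_unique<tensorflow::BridgeLoggerConfig>(
      /*print_module_scope=*/true, /*print_after_only_on_change=*/false));
  // With the counter prefix added above, dump files are named like
  // 0007_mlir_bridge_<pass_name>_before.mlir.
}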
LogicalResult Run() { + auto main = module_.lookupSymbol("main"); + if (!main) + return module_.emitError( + "conversion requires module with `main` function"); + for (auto func : module_.getOps()) { if (func.empty()) continue; if (failed(RunOnFunction(func))) return failure(); @@ -539,8 +544,11 @@ class ConvertToHloModule { xla::XlaComputation* result); ::xla::HloModuleProto ConsumeMainProto() { - return lowered_computation_[module_.lookupSymbol("main")] - .proto(); + auto main = module_.lookupSymbol("main"); + // This is an invariant check as Run returns failure if there is no main + // function and so the main proto shouldn't be consumed in that case. + CHECK(main) << "requires module to have main function"; // Crash Ok. + return lowered_computation_[main].proto(); } // Lower function call to HLO call instruction diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir index 04dc3c88cb2..16f5c96e5e3 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir @@ -174,6 +174,13 @@ func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { return %0: tensor<4xi32> } +// CHECK-LABEL: func @bitwise_xor +func @bitwise_xor(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + // CHECK-NEXT: mhlo.xor + %0 = "tf.BitwiseXor"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %0: tensor<4xi32> +} + // CHECK-LABEL: func @bitwise_and func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { // CHECK-NEXT: mhlo.and diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index c44715bb4db..940def69da8 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -835,7 +835,14 @@ func @floordiv_dynamic(%arg0: tensor, %arg1: tensor) -> tensor, %arg1: tensor<*xi32>) -> tensor<*xi32> { +func @floordiv_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { + // CHECK-NOT: tf.FloorDiv + %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + return %0: tensor<*xf32> +} + +// CHECK-LABEL: func @floordiv_int +func @floordiv_int(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> { // CHECK: tf.FloorDiv %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> return %0: tensor<*xi32> @@ -894,7 +901,7 @@ func @floormod_dynamic(%arg0: tensor, %arg1: tensor) -> tensor, %arg1: tensor<*xi32>) -> tensor<*xi32> { - // CHECK: tf.FloorMod + // CHECK-NOT: tf.FloorMod %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> return %0: tensor<*xi32> } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/missing_main.mlir b/tensorflow/compiler/mlir/xla/tests/translate/missing_main.mlir new file mode 100644 index 00000000000..a2647d2c29f --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/missing_main.mlir @@ -0,0 +1,7 @@ +// RUN: not tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s 2>&1 | FileCheck %s + +// CHECK: conversion requires module with `main` +func @non_main() { + %0 = "mhlo.constant"() {value = opaque<"mhlo", "0x0123456789ABCDEF"> : tensor<4xf32>} : () -> tensor<4xf32> + return +} diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td 
b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index b363e2a8ea6..f897351e6c2 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -114,7 +114,7 @@ def : Pat<(TF_ComplexOp $r, $i), (HLO_ComplexOp $r, $i)>; // Performs a substitution of FloorDiv, pseudo code below: // // return floor(div(x, y)) -def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r), +def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r), (HLO_FloorOp (HLOClient_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r))), [(IEEEFloatTensor $l)]>; @@ -166,7 +166,7 @@ def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r), // return trunc_mod != 0 && (y < 0 != trunc_mod < 0) ? trunc_mod + y // Requires static shaped inputs to create constant splats and computation of // broadcast attributes. -def : Pat<(TF_FloorModOp AnyRankedTensor:$l, AnyRankedTensor:$r), +def : Pat<(TF_FloorModOp AnyTensor:$l, AnyTensor:$r), (HLO_SelectOp (HLOClient_BroadcastAndOp (HLOClient_BroadcastCompareOp @@ -193,14 +193,15 @@ def : Pat<(TF_FloorModOp AnyRankedTensor:$l, AnyRankedTensor:$r), //===----------------------------------------------------------------------===// class DirectLogicalBinaryPat - : Pat<(FromOp AnyRankedTensor:$l, AnyRankedTensor:$r), + : Pat<(FromOp AnyTensor:$l, AnyTensor:$r), (ToOp $l, $r, (BinBroadcastDimensions $l, $r)), [(SignedIntTensor $l)]>; foreach fromToBinPair = [[TF_LogicalAndOp, HLOClient_BroadcastAndOp], [TF_LogicalOrOp, HLOClient_BroadcastOrOp], + [TF_BitwiseAndOp, HLOClient_BroadcastAndOp], [TF_BitwiseOrOp, HLOClient_BroadcastOrOp], - [TF_BitwiseAndOp, HLOClient_BroadcastAndOp]] in + [TF_BitwiseXorOp, HLOClient_BroadcastXorOp]] in def : DirectLogicalBinaryPat; //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc index 6cd4c277705..3f3f961595d 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc @@ -154,9 +154,11 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -170,6 +172,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -248,6 +251,8 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc index fc61d4b52f6..baa7ab0db3c 100644 --- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc @@ -121,7 +121,7 @@ Status ConvertModule(std::unique_ptr hlo_module, ModuleOp module, // Run all HLO passes to produce an optimized module. 
auto result_or = backend->compiler()->RunHloPassesAndBufferAssignement( std::move(hlo_module), backend->default_stream_executor(), - backend->memory_allocator(), optimize_xla_hlo); + optimize_xla_hlo, {backend->memory_allocator()}); TF_RETURN_WITH_CONTEXT_IF_ERROR(result_or.status(), "running XLA pass pipeline"); std::unique_ptr optimized_hlo_module = diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h index d3f5dd3e662..c55f5750da7 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.h +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -115,6 +115,16 @@ class ExecutableBuildOptions { return *this; } + // Thread pool for parallel compilation. + tensorflow::thread::ThreadPool* compile_thread_pool() const { + return compile_thread_pool_; + } + ExecutableBuildOptions& set_run_backend_only( + tensorflow::thread::ThreadPool* compile_thread_pool) { + compile_thread_pool_ = compile_thread_pool; + return *this; + } + private: int device_ordinal_ = -1; Shape result_layout_; @@ -128,6 +138,7 @@ class ExecutableBuildOptions { absl::optional device_assignment_; bool alias_passthrough_params_ = false; bool run_backend_only_ = false; + tensorflow::thread::ThreadPool* compile_thread_pool_ = nullptr; }; } // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index d3b885643fc..c23b40ab6cd 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -3347,6 +3347,8 @@ StatusOr XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) { // contant False if dimension is static. // - Reduce: Convert to reduce or. // - Constant: Convert to constant False. + // - Reshape, slice, transpose, pad: + // Convert into predicate type with same opcode. // - Other ops: Not supported. // Create the instruction for the new handle. 
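The ExecutableBuildOptions hunk above adds a compile_thread_pool() accessor for parallel compilation (the accompanying setter in that hunk is declared under the name set_run_backend_only, so no setter is assumed here). A small sketch around the getter only:

#include <memory>

#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/env.h"

// Create a pool the compiler may use for parallel compilation.
std::unique_ptr<tensorflow::thread::ThreadPool> MakeCompilePool(int num_threads) {
  return std::make_unique<tensorflow::thread::ThreadPool>(
      tensorflow::Env::Default(), "xla_compile", num_threads);
}

// compile_thread_pool() is the new accessor; nullptr (the default) keeps
// compilation single-threaded.
bool UsesParallelCompilation(const xla::ExecutableBuildOptions& options) {
  return options.compile_thread_pool() != nullptr;
}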
TF_ASSIGN_OR_RETURN(HloOpcode opcode, @@ -3449,6 +3451,7 @@ StatusOr XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) { case HloOpcode::kBroadcast: case HloOpcode::kConcatenate: case HloOpcode::kReshape: + case HloOpcode::kPad: break; case HloOpcode::kGetDimensionSize: { int64 dimension = instr_proto->dimensions(0); diff --git a/tensorflow/compiler/xla/pjrt/cpu_device.cc b/tensorflow/compiler/xla/pjrt/cpu_device.cc index 9b0f060f392..5241efbd2de 100644 --- a/tensorflow/compiler/xla/pjrt/cpu_device.cc +++ b/tensorflow/compiler/xla/pjrt/cpu_device.cc @@ -26,8 +26,8 @@ static const char kCpuPlatformName[] = "cpu"; CpuDevice::CpuDevice(int id, std::unique_ptr local_device_state) - : PjRtDevice(id, std::move(local_device_state), - /*device_kind=*/kCpuPlatformName) {} + : PjRtStreamExecutorDevice(id, std::move(local_device_state), + /*device_kind=*/kCpuPlatformName) {} StatusOr> GetCpuClient(bool asynchronous) { TF_ASSIGN_OR_RETURN(se::Platform * platform, @@ -40,7 +40,7 @@ StatusOr> GetCpuClient(bool asynchronous) { TF_ASSIGN_OR_RETURN(LocalClient * client, ClientLibrary::GetOrCreateLocalClient(options)); - std::vector> devices; + std::vector> devices; for (int i = 0; i < client->device_count(); ++i) { se::StreamExecutorConfig config; config.ordinal = i; @@ -57,11 +57,11 @@ StatusOr> GetCpuClient(bool asynchronous) { devices.push_back(std::move(device)); } - return std::make_unique( + return std::unique_ptr(std::make_unique( kCpuName, client, std::move(devices), /*host_id=*/0, /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr, /*should_stage_host_to_device_transfers=*/false, - /*gpu_run_options=*/nullptr); + /*gpu_run_options=*/nullptr)); } } // namespace xla diff --git a/tensorflow/compiler/xla/pjrt/cpu_device.h b/tensorflow/compiler/xla/pjrt/cpu_device.h index 1036d8fedbb..0aab55e6493 100644 --- a/tensorflow/compiler/xla/pjrt/cpu_device.h +++ b/tensorflow/compiler/xla/pjrt/cpu_device.h @@ -23,7 +23,7 @@ limitations under the License. namespace xla { -class CpuDevice : public PjRtDevice { +class CpuDevice : public PjRtStreamExecutorDevice { public: CpuDevice(int id, std::unique_ptr local_device_state); }; diff --git a/tensorflow/compiler/xla/pjrt/gpu_device.cc b/tensorflow/compiler/xla/pjrt/gpu_device.cc index 26f38c2cbc4..302c7734d73 100644 --- a/tensorflow/compiler/xla/pjrt/gpu_device.cc +++ b/tensorflow/compiler/xla/pjrt/gpu_device.cc @@ -35,9 +35,9 @@ namespace xla { namespace { // A custom PjRtClient that overrides the device assignment method. -class GpuClient : public xla::PjRtClient { +class GpuClient : public xla::PjRtStreamExecutorClient { public: - using xla::PjRtClient::PjRtClient; + using xla::PjRtStreamExecutorClient::PjRtStreamExecutorClient; xla::StatusOr GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const override; @@ -55,7 +55,8 @@ xla::StatusOr GpuClient::GetDefaultDeviceAssignment( return assignment; } // Fallback to default global device assignment if we can't run locally. - return PjRtClient::GetDefaultDeviceAssignment(num_replicas, num_partitions); + return PjRtStreamExecutorClient::GetDefaultDeviceAssignment(num_replicas, + num_partitions); } // Builds an xla::LocalClient for the GPU platform. 
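With CpuDevice, GpuDevice and InterpreterDevice all re-parented onto PjRtStreamExecutorDevice, a new stream-executor-backed platform only needs a thin device wrapper. A minimal sketch following the CpuDevice pattern above; MyDevice and the "my-device" kind string are hypothetical:

#include <memory>

#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"

namespace xla {

// A platform-specific device only forwards its constructor arguments; all
// behavior lives in PjRtStreamExecutorDevice.
class MyDevice : public PjRtStreamExecutorDevice {
 public:
  MyDevice(int id, std::unique_ptr<LocalDeviceState> local_device_state)
      : PjRtStreamExecutorDevice(id, std::move(local_device_state),
                                 /*device_kind=*/"my-device") {}
};

}  // namespace xla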
@@ -225,9 +226,9 @@ StatusOr NcclIdStore::GetNcclUniqueId( return result.first->second; } -std::vector> BuildLocalDevices( +std::vector> BuildLocalDevices( std::vector> local_device_states) { - std::vector> devices; + std::vector> devices; for (auto& local_device : local_device_states) { int device_ordinal = local_device->device_ordinal(); const se::DeviceDescription& description = @@ -243,7 +244,7 @@ std::vector> BuildLocalDevices( Status BuildDistributedDevices( std::vector> local_device_states, std::shared_ptr distributed_client, int node_id, - std::vector>* devices, + std::vector>* devices, gpu::GpuExecutableRunOptions* gpu_executable_run_options) { LocalTopologyProto local_topology; local_topology.set_node_id(node_id); @@ -306,8 +307,8 @@ Status BuildDistributedDevices( GpuDevice::GpuDevice(int id, std::unique_ptr local_device_state, std::string device_kind, int node_id) - : PjRtDevice(id, std::move(local_device_state), std::move(device_kind), - node_id) {} + : PjRtStreamExecutorDevice(id, std::move(local_device_state), + std::move(device_kind), node_id) {} StatusOr> GetGpuClient( bool asynchronous, const GpuAllocatorConfig& allocator_config, @@ -322,7 +323,7 @@ StatusOr> GetGpuClient( auto host_memory_allocator = GetGpuHostAllocator(local_device_states.front()->executor()); - std::vector> devices; + std::vector> devices; auto gpu_run_options = absl::make_unique(); if (distributed_client) { TF_RETURN_IF_ERROR(BuildDistributedDevices( diff --git a/tensorflow/compiler/xla/pjrt/gpu_device.h b/tensorflow/compiler/xla/pjrt/gpu_device.h index 7ea85db0401..142a263d959 100644 --- a/tensorflow/compiler/xla/pjrt/gpu_device.h +++ b/tensorflow/compiler/xla/pjrt/gpu_device.h @@ -25,7 +25,7 @@ limitations under the License. namespace xla { -class GpuDevice : public PjRtDevice { +class GpuDevice : public PjRtStreamExecutorDevice { public: GpuDevice(int id, std::unique_ptr local_device_state, std::string device_kind, int node_id); diff --git a/tensorflow/compiler/xla/pjrt/interpreter_device.cc b/tensorflow/compiler/xla/pjrt/interpreter_device.cc index 2819cabf258..3b3daba5906 100644 --- a/tensorflow/compiler/xla/pjrt/interpreter_device.cc +++ b/tensorflow/compiler/xla/pjrt/interpreter_device.cc @@ -26,8 +26,8 @@ static const char kInterpreterPlatformName[] = "interpreter"; InterpreterDevice::InterpreterDevice( int id, std::unique_ptr local_device_state) - : PjRtDevice(id, std::move(local_device_state), - /*device_kind=*/kInterpreterPlatformName) {} + : PjRtStreamExecutorDevice(id, std::move(local_device_state), + /*device_kind=*/kInterpreterPlatformName) {} StatusOr> GetInterpreterClient() { TF_ASSIGN_OR_RETURN(se::Platform * platform, @@ -41,7 +41,7 @@ StatusOr> GetInterpreterClient() { TF_ASSIGN_OR_RETURN(LocalClient * client, ClientLibrary::GetOrCreateLocalClient(options)); - std::vector> devices; + std::vector> devices; se::StreamExecutor* executor = client->backend().stream_executor(0).ValueOrDie(); auto device_state = absl::make_unique( @@ -51,11 +51,11 @@ StatusOr> GetInterpreterClient() { absl::make_unique(0, std::move(device_state)); devices.push_back(std::move(device)); - return std::make_unique( + return std::unique_ptr(std::make_unique( "interpreter", client, std::move(devices), /*host_id=*/0, /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr, /*should_stage_host_to_device_transfers=*/false, - /*gpu_run_options=*/nullptr); + /*gpu_run_options=*/nullptr)); } } // namespace xla diff --git a/tensorflow/compiler/xla/pjrt/interpreter_device.h 
b/tensorflow/compiler/xla/pjrt/interpreter_device.h index 4038d8dbf11..a23ddcb5bb9 100644 --- a/tensorflow/compiler/xla/pjrt/interpreter_device.h +++ b/tensorflow/compiler/xla/pjrt/interpreter_device.h @@ -23,7 +23,7 @@ limitations under the License. namespace xla { -class InterpreterDevice : public PjRtDevice { +class InterpreterDevice : public PjRtStreamExecutorDevice { public: InterpreterDevice(int id, std::unique_ptr local_device_state); diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_client.cc index 191b3467de6..92e8d522010 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.cc @@ -114,21 +114,22 @@ limitations under the License. namespace xla { -PjRtPlatformId PjRtDevice::platform_id() const { +PjRtPlatformId PjRtStreamExecutorDevice::platform_id() const { return client_->platform_id(); } -const std::string& PjRtDevice::platform_name() const { +const std::string& PjRtStreamExecutorDevice::platform_name() const { return client_->platform_name(); } -StatusOr PjRtDevice::GetLocalDeviceState() const { +StatusOr PjRtStreamExecutorDevice::GetLocalDeviceState() + const { if (local_device_state_) { return local_device_state_.get(); } return InvalidArgument("Device %s is not a local device.", DebugString()); } -std::string PjRtDevice::DebugString() const { +std::string PjRtStreamExecutorDevice::DebugString() const { return absl::StrCat(platform_name(), ":", id()); } @@ -153,14 +154,15 @@ StatusOr DevicesToDeviceAssignment( devices[replica].size(), replica, devices[0].size()); } for (int partition = 0; partition < devices[replica].size(); ++partition) { - if (devices[0][0]->platform_id() != - devices[replica][partition]->platform_id()) { + if (devices[0][0]->client()->platform_id() != + devices[replica][partition]->client()->platform_id()) { return InvalidArgument( "Device assignment passed to Compile() must have devices of a " "single kind, got %s for replica 0 partition 0 and %s for replica " "%d partition %d.", - devices[0][0]->platform_name(), - devices[replica][partition]->platform_name(), replica, partition); + devices[0][0]->client()->platform_name(), + devices[replica][partition]->client()->platform_name(), replica, + partition); } xla_assignment(replica, partition) = devices[replica][partition]->id(); } @@ -182,9 +184,9 @@ class CpuAllocator : public tensorflow::Allocator { } }; -PjRtClient::PjRtClient( +PjRtStreamExecutorClient::PjRtStreamExecutorClient( std::string platform_name, LocalClient* client, - std::vector> devices, int host_id, + std::vector> devices, int host_id, std::unique_ptr allocator, std::unique_ptr host_memory_allocator, bool should_stage_host_to_device_transfers, @@ -193,7 +195,7 @@ PjRtClient::PjRtClient( platform_name_(std::move(platform_name)), client_(client), host_memory_allocator_(std::move(host_memory_allocator)), - devices_(std::move(devices)), + owned_devices_(std::move(devices)), host_id_(host_id), owned_allocator_(std::move(allocator)), should_stage_host_to_device_transfers_( @@ -211,12 +213,14 @@ PjRtClient::PjRtClient( host_memory_allocator_ = std::make_unique(); } - for (const std::unique_ptr& device : devices_) { + for (const std::unique_ptr& device : + owned_devices_) { + devices_.push_back(device.get()); CHECK(id_to_device_.insert({device->id(), device.get()}).second) << "Duplicate device id: " << device->id(); - if (device->IsLocalDevice()) { - int idx = device->local_device_id(); + if (device->IsAddressable()) { + int idx = 
device->local_hardware_id(); if (idx >= local_devices_.size()) { local_devices_.resize(idx + 1); } @@ -230,13 +234,14 @@ PjRtClient::PjRtClient( } } -StatusOr PjRtClient::GetDefaultDeviceAssignment( +StatusOr PjRtStreamExecutorClient::GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const { return client_->backend().computation_placer()->AssignDevices(num_replicas, num_partitions); } -std::unique_ptr PjRtClient::GetHloCostAnalysis() { +std::unique_ptr +PjRtStreamExecutorClient::GetHloCostAnalysis() { return absl::make_unique( client_->backend().compiler()->ShapeSizeBytesFunction()); } @@ -346,12 +351,13 @@ StatusOr> AllocateDestinationBuffer( return InvalidArgument("Can't make a buffer from an empty tuple"); } + auto* se_client = tensorflow::down_cast(client); TransferManager* transfer_manager = - client->client()->backend().transfer_manager(); - TF_ASSIGN_OR_RETURN( - ScopedShapedBuffer dst_buffer, - transfer_manager->AllocateScopedShapedBuffer( - on_host_shape, client->allocator(), local_device->device_ordinal())); + se_client->client()->backend().transfer_manager(); + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer dst_buffer, + transfer_manager->AllocateScopedShapedBuffer( + on_host_shape, se_client->allocator(), + local_device->device_ordinal())); if (local_device->allocation_model() == LocalDeviceState::kComputeSynchronized) { if (copy_stream == nullptr) { @@ -543,18 +549,21 @@ void PjRtBuffer::ScopedHold::AddToInput( bool PjRtBuffer::IsOnCpu() const { return client()->platform_id() == kCpuId; } -StatusOr> PjRtClient::BufferFromHostBuffer( +StatusOr> +PjRtStreamExecutorClient::BufferFromHostBuffer( const void* data, const Shape& shape, HostBufferSemantics host_buffer_semantics, std::shared_ptr buffer_reference, PjRtDevice* device) { - tensorflow::profiler::TraceMe traceme("PjRtClient::BufferFromHostBuffer"); - VLOG(2) << "PjRtClient::BufferFromHostBuffer: shape: " << shape.ToString() - << " device: " << device->DebugString(); + tensorflow::profiler::TraceMe traceme( + "PjRtStreamExecutorClient::BufferFromHostBuffer"); + VLOG(2) << "PjRtStreamExecutorClient::BufferFromHostBuffer: shape: " + << shape.ToString() << " device: " << device->DebugString(); if (shape.IsTuple()) { return InvalidArgument("Use BufferFromHostLiteral to transfer a tuple"); } TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, - device->GetLocalDeviceState()); + tensorflow::down_cast(device) + ->GetLocalDeviceState()); int64 size = ShapeUtil::ByteSizeOf(shape); TransferManager* transfer_manager = client()->backend().transfer_manager(); @@ -708,20 +717,23 @@ StatusOr> PjRtClient::BufferFromHostBuffer( return py_buffer; } -StatusOr> PjRtClient::CreateUninitializedBuffer( - const Shape& shape, PjRtDevice* device) { +StatusOr> +PjRtStreamExecutorClient::CreateUninitializedBuffer(const Shape& shape, + PjRtDevice* device) { return CreateUninitializedBuffer(shape, device, nullptr); } -StatusOr> PjRtClient::CreateUninitializedBuffer( +StatusOr> +PjRtStreamExecutorClient::CreateUninitializedBuffer( const Shape& shape, PjRtDevice* device, std::shared_ptr definition_event) { tensorflow::profiler::TraceMe traceme( - "PjRtClient::CreateUninitializedBuffer"); - VLOG(2) << "PjRtClient::CreateUninitializedBuffer: shape: " + "PjRtStreamExecutorClient::CreateUninitializedBuffer"); + VLOG(2) << "PjRtStreamExecutorClient::CreateUninitializedBuffer: shape: " << shape.ToString() << " device: " << device->DebugString(); TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, - device->GetLocalDeviceState()); + 
tensorflow::down_cast(device) + ->GetLocalDeviceState()); TransferManager* transfer_manager = client()->backend().transfer_manager(); TF_ASSIGN_OR_RETURN(Shape compact_shape, @@ -733,13 +745,16 @@ StatusOr> PjRtClient::CreateUninitializedBuffer( definition_event); } -StatusOr> PjRtClient::BufferFromHostLiteral( - const LiteralSlice& literal, PjRtDevice* device) { - tensorflow::profiler::TraceMe traceme("PjRtClient::BufferFromHostLiteral"); - VLOG(2) << "PjRtClient::BufferFromHostLiteral: shape: " +StatusOr> +PjRtStreamExecutorClient::BufferFromHostLiteral(const LiteralSlice& literal, + PjRtDevice* device) { + tensorflow::profiler::TraceMe traceme( + "PjRtStreamExecutorClient::BufferFromHostLiteral"); + VLOG(2) << "PjRtStreamExecutorClient::BufferFromHostLiteral: shape: " << literal.shape().ToString() << " device: " << device->DebugString(); TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, - device->GetLocalDeviceState()); + tensorflow::down_cast(device) + ->GetLocalDeviceState()); TransferManager* transfer_manager = client()->backend().transfer_manager(); TF_ASSIGN_OR_RETURN( @@ -792,7 +807,7 @@ StatusOr> PjRtClient::BufferFromHostLiteral( return py_buffer; } -void PjRtClient::MakeCrossHostReceiveBuffers( +void PjRtStreamExecutorClient::MakeCrossHostReceiveBuffers( absl::Span shapes, PjRtDevice* device, PjRtCrossHostRecvNotifier&& notifier) { if (shapes.empty()) { @@ -801,7 +816,9 @@ void PjRtClient::MakeCrossHostReceiveBuffers( return; } - auto local_device_or = device->GetLocalDeviceState(); + auto local_device_or = + tensorflow::down_cast(device) + ->GetLocalDeviceState(); if (!local_device_or.ok()) { notifier(local_device_or.status()); return; @@ -828,27 +845,30 @@ void PjRtClient::MakeCrossHostReceiveBuffers( } // Transfer the given literal to the infeed queue of the given local device. -Status PjRtDevice::TransferToInfeed(const LiteralSlice& literal) const { +Status PjRtStreamExecutorDevice::TransferToInfeed( + const LiteralSlice& literal) const { // Only support infeed to local device. TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState()); return local_device->client()->TransferToInfeedLocal( literal, local_device->device_ordinal()); } -StatusOr PjRtDevice::TransferFromOutfeed(const Shape& shape) const { +StatusOr PjRtStreamExecutorDevice::TransferFromOutfeed( + const Shape& shape) const { TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState()); return local_device->client()->TransferFromOutfeedLocal( shape, local_device->device_ordinal()); } -StatusOr PjRtClient::LookupLocalDevice(int local_device_id) const { +StatusOr PjRtStreamExecutorClient::LookupAddressableDevice( + int local_hardware_id) const { for (auto* device : local_devices_) { - if (local_device_id == device->local_device_id()) { + if (local_hardware_id == device->local_hardware_id()) { return device; } } - return InvalidArgument("No matching device found for local_device_id %d", - local_device_id); + return InvalidArgument("No matching device found for local_hardware_id %d", + local_hardware_id); } PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape, @@ -873,7 +893,8 @@ PjRtBuffer::~PjRtBuffer() { } int64 PjRtBuffer::OnDeviceSizeInBytes() const { - return client_->client() + return tensorflow::down_cast(client_) + ->client() ->backend() .transfer_manager() ->GetByteSizeRequirement(on_device_shape_); @@ -919,7 +940,9 @@ StatusOr> PjRtBuffer::Release( // the final set of usage events. 
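Because PjRtDevice and PjRtClient are now abstract interfaces, the stream-executor implementation reaches its LocalDeviceState by down-casting at each call site, as in the edits above. A sketch of that pattern as a hypothetical helper (it assumes the caller already knows the device is stream-executor backed):

#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/core/platform/casts.h"

namespace xla {

// Hypothetical convenience wrapper, not part of this change: collapse the
// repeated down_cast-then-local_device_state() sequence into one call.
inline LocalDeviceState* LocalDeviceStateFor(PjRtDevice* device) {
  return tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
      ->local_device_state();
}

}  // namespace xla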
events = device_buffer->LockUseAndTransferUsageEvents(); } - LocalDeviceState* local_device_state = device_->local_device_state(); + LocalDeviceState* local_device_state = + tensorflow::down_cast(device_) + ->local_device_state(); if (wait_for_operations_to_complete) { // Block the host until all usage events have completed. Usage events // dominate definition events, so this also waits for the buffer to be @@ -1080,7 +1103,9 @@ PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy, } ScopedHold device_buffer(this, ScopedHold::kUsage); std::shared_ptr host_value; - LocalDeviceState* local_device = device_->local_device_state(); + LocalDeviceState* local_device = + tensorflow::down_cast(device_) + ->local_device_state(); se::Stream* stream = local_device->GetDeviceToHostStream(); const xla::Layout& host_layout = layout.has_value() ? layout.value() : on_host_shape_.layout(); @@ -1122,12 +1147,16 @@ PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy, host_value->value = std::make_shared(host_shape); ShapedBuffer shaped_buffer = device_buffer->AsShapedBuffer(host_shape, on_device_shape_); - client_->client()->backend().transfer_manager()->TransferLiteralFromDevice( - stream, shaped_buffer, host_value->value.get(), - [host_value](Status done_status) { - host_value->status = done_status; - host_value->ready.Notify(); - }); + tensorflow::down_cast(client_) + ->client() + ->backend() + .transfer_manager() + ->TransferLiteralFromDevice(stream, shaped_buffer, + host_value->value.get(), + [host_value](Status done_status) { + host_value->status = done_status; + host_value->ready.Notify(); + }); auto usage_event = std::make_shared(); StatusOr event_or = @@ -1156,7 +1185,7 @@ PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy, StatusOr> PjRtBuffer::ToLiteral( const bool discard_cached_copy, absl::optional layout) { - tensorflow::profiler::TraceMe traceme("PjRtClient::ToLiteral"); + tensorflow::profiler::TraceMe traceme("PjRtStreamExecutorClient::ToLiteral"); TF_ASSIGN_OR_RETURN(std::shared_ptr host_value, CopyToHostAsyncInternal(discard_cached_copy, layout)); if (host_value == nullptr) { @@ -1241,8 +1270,9 @@ PjRtBuffer::CopyToDeviceHelper( // StallStreamOnError only makes sure the destination device is ok, so // make sure that the src buffer remains valid until after any transfers // have completed. - device_->local_device_state()->ThenRelease(transfer_stream, - src_device_buffer); + tensorflow::down_cast(device_) + ->local_device_state() + ->ThenRelease(transfer_stream, src_device_buffer); } return copy_event_or.status(); } @@ -1265,14 +1295,20 @@ StatusOr> PjRtBuffer::CopyToDevice( TF_ASSIGN_OR_RETURN(std::shared_ptr literal, ToLiteral()); return dst_device->client()->BufferFromHostBuffer( literal->untyped_data(), literal->shape(), - PjRtClient::HostBufferSemantics::kZeroCopy, nullptr, dst_device); + PjRtStreamExecutorClient::HostBufferSemantics::kZeroCopy, nullptr, + dst_device); } - TF_ASSIGN_OR_RETURN(LocalDeviceState * dst_local_device, - dst_device->GetLocalDeviceState()); + TF_ASSIGN_OR_RETURN( + LocalDeviceState * dst_local_device, + tensorflow::down_cast(dst_device) + ->GetLocalDeviceState()); LocalDeviceState* transfer_local_device = - client_->EnqueueD2DTransfersOnSrcStream() ? device_->local_device_state() - : dst_local_device; + tensorflow::down_cast(client_) + ->EnqueueD2DTransfersOnSrcStream() + ? 
tensorflow::down_cast(device_) + ->local_device_state() + : dst_local_device; CHECK_EQ(dst_local_device->allocation_model(), transfer_local_device->allocation_model()); @@ -1310,7 +1346,9 @@ StatusOr> PjRtBuffer::CopyToDevice( // alternative is to ensure, before freeing the buffer, that the compute // stream is synchronized past the transfer, but it seems better to hold onto // the buffer too long than to stall the compute stream. - RecordUsage(std::move(src_device_buffer), device_->local_device_state(), + RecordUsage(std::move(src_device_buffer), + tensorflow::down_cast(device_) + ->local_device_state(), transfer_local_device, event, transfer_stream, /*prefer_to_retain_reference=*/true); @@ -1318,7 +1356,8 @@ StatusOr> PjRtBuffer::CopyToDevice( } Status PjRtBuffer::CopyToRemoteDevice(absl::string_view serialized_descriptor) { - return client_->CopyToRemoteDevice(this, serialized_descriptor); + return tensorflow::down_cast(client_) + ->CopyToRemoteDevice(this, serialized_descriptor); } Status PjRtBuffer::BlockHostUntilReady() { @@ -1332,7 +1371,9 @@ Status PjRtBuffer::BlockHostUntilReady() { } device_buffer = device_buffer_; } - LocalDeviceState* local_device_state = device_->local_device_state(); + LocalDeviceState* local_device_state = + tensorflow::down_cast(device_) + ->local_device_state(); std::unique_ptr stream; for (auto& event : device_buffer->definition_events()) { if (!event->IsComplete()) { @@ -1378,9 +1419,13 @@ StatusOr MakeTupleHelper( Shape on_host_shape = ShapeUtil::MakeTupleShape(host_shapes); Shape on_device_shape = ShapeUtil::MakeTupleShape(device_shapes); - se::DeviceMemoryAllocator* allocator = client->allocator(); + se::DeviceMemoryAllocator* allocator = + tensorflow::down_cast(client)->allocator(); TransferManager* transfer_manager = - client->client()->backend().transfer_manager(); + tensorflow::down_cast(client) + ->client() + ->backend() + .transfer_manager(); se::Stream* stream = local_device->host_to_device_stream(); TF_ASSIGN_OR_RETURN( se::OwningDeviceMemory root_table_memory, @@ -1444,14 +1489,6 @@ std::unique_ptr OutputBufferHelper( /*prefer_to_retain_reference=*/false); return pjrt_buffer; } - -static PjRtDevice* LookupDevice(const PjRtClient& client, int device_id) { - auto it = client.id_to_device().find(device_id); - CHECK(it != client.id_to_device().end()) - << "Unknown device id: " << device_id; - return it->second; -} - } // namespace PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable( @@ -1459,7 +1496,8 @@ PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable( bool parameter_is_tupled_arguments, std::shared_ptr device_assignment, std::vector addressable_device_logical_ids, - std::vector addressable_devices, PjRtClient* client) + std::vector addressable_devices, + PjRtStreamExecutorClient* client) : client_(client), device_assignment_(std::move(device_assignment)), parameter_is_tupled_arguments_(parameter_is_tupled_arguments), @@ -1482,7 +1520,7 @@ PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable( VLOG(1) << "PjRtStreamExecutorExecutable device_assignment:\n" << device_assignment_->ToString(); CHECK_GE(addressable_devices_.size(), 1) << device_assignment_->ToString(); - CHECK_LE(addressable_devices_.size(), client_->local_device_count()) + CHECK_LE(addressable_devices_.size(), client_->addressable_device_count()) << "Inconsistent local device count."; num_partitions = device_assignment_->computation_count(); } @@ -1584,7 +1622,7 @@ PjRtStreamExecutorExecutable::MakeExecutionInputsAndWaitForEvents( absl::Span device_buffers, 
absl::flat_hash_set& events) const { std::vector execution_inputs; - LocalDeviceState* device_state = &client_->device_state(device_ordinal); + LocalDeviceState* device_state = &(client_->device_state(device_ordinal)); // Lift tuple_handle outside the conditional so that the event it returns is // not destroyed until after the loop below that waits on events. absl::optional tuple_handle; @@ -1607,8 +1645,10 @@ PjRtStreamExecutorExecutable::MakeExecutionInputsAndWaitForEvents( execution_input.MutableBuffers()->begin(); ShapeTree::iterator iterator_end = execution_input.MutableBuffers()->end(); - device_buffers[i].AddToInput(&input_iterator, iterator_end, - &execution_input, client_->allocator()); + device_buffers[i].AddToInput( + &input_iterator, iterator_end, &execution_input, + tensorflow::down_cast(client_) + ->allocator()); CHECK(input_iterator == iterator_end); } } @@ -1628,8 +1668,10 @@ StatusOr PjRtStreamExecutorExecutable::EnqueueExecution( int executable_idx, const RunId& run_id, const ExecuteOptions& options, PjRtDevice* device, std::vector* device_buffers, std::shared_ptr device_assignment) const { - int device_ordinal = device->local_device_state()->device_ordinal(); - LocalDeviceState* device_state = &client_->device_state(device_ordinal); + int device_ordinal = tensorflow::down_cast(device) + ->local_device_state() + ->device_ordinal(); + LocalDeviceState* device_state = &(client_->device_state(device_ordinal)); tensorflow::profiler::TraceMeConsumer activity( "LocalExecutable::Execute", tensorflow::profiler::ContextType::kPjRt, run_id.ToInt()); @@ -1765,7 +1807,7 @@ PjRtStreamExecutorExecutable::MakeOutputBuffers( std::shared_ptr definition_event, PjRtDevice* device) const { std::vector> outputs; - LocalDeviceState* device_state = &client_->device_state(device_ordinal); + LocalDeviceState* device_state = &(client_->device_state(device_ordinal)); if (options.untuple_result && result_buffer.on_host_shape().IsTuple()) { int tuple_count = result_buffer.on_host_shape().tuple_shapes_size(); outputs.reserve(tuple_count); @@ -1802,7 +1844,7 @@ PjRtStreamExecutorExecutable::ExecuteHelper( if (device == nullptr) { CHECK(device_assignment_ != nullptr); const int device_id = (*device_assignment_)(replica, partition); - device = LookupDevice(*client_, device_id); + TF_ASSIGN_OR_RETURN(device, client_->LookupDevice(device_id)); device_assignment = device_assignment_; } else { CHECK(device_assignment_ == nullptr); @@ -1814,7 +1856,9 @@ PjRtStreamExecutorExecutable::ExecuteHelper( } CHECK_EQ(device->host_id(), client_->host_id()); - int device_ordinal = device->local_device_state()->device_ordinal(); + int device_ordinal = tensorflow::down_cast(device) + ->local_device_state() + ->device_ordinal(); tensorflow::profiler::TraceMe traceme("LocalExecutable::Execute"); VLOG(3) << "Replica " << replica << ", partition " << partition << " mapped to device ordinal for execution: " << device_ordinal; @@ -1836,7 +1880,7 @@ PjRtStreamExecutorExecutable::ExecuteHelper( ScopedShapedBuffer result_buffer = result_buffer_or_status.ConsumeValueOrDie(); - LocalDeviceState* device_state = &client_->device_state(device_ordinal); + LocalDeviceState* device_state = &(client_->device_state(device_ordinal)); se::Stream* stream = device_state->compute_stream(); StatusOr event_or = device_state->event_pool().ThenAllocateAndRecordEvent(stream); @@ -1922,7 +1966,9 @@ PjRtStreamExecutorExecutable::Execute( const int replica = addressable_device_logical_ids_[i].replica; const int partition = 
addressable_device_logical_ids_[i].partition; PjRtDevice* device = addressable_devices_[i]; - const LocalDeviceState& device_state = *device->local_device_state(); + const LocalDeviceState& device_state = + *tensorflow::down_cast(device) + ->local_device_state(); device_state.execute_thread()->Schedule([&, replica, partition, i] { results[i] = ExecuteHelper(argument_handles[i], replica, partition, run_id, options); @@ -2131,9 +2177,9 @@ StatusOr, Shape>> GetShardedProgramShapes( } // namespace -StatusOr> PjRtClient::Compile( +StatusOr> PjRtStreamExecutorClient::Compile( const XlaComputation& computation, CompileOptions options) { - tensorflow::profiler::TraceMe traceme("PjRtClient::Compile"); + tensorflow::profiler::TraceMe traceme("PjRtStreamExecutorClient::Compile"); ExecutableBuildOptions& build_options = options.executable_build_options; if (!build_options.device_allocator()) { @@ -2153,14 +2199,15 @@ StatusOr> PjRtClient::Compile( num_partitions = 1; } else { if (!build_options.has_device_assignment()) { - VLOG(2) << "PjRtClient::Compile using default device_assignment."; + VLOG(2) << "PjRtStreamExecutorClient::Compile using default " + "device_assignment."; TF_ASSIGN_OR_RETURN( DeviceAssignment device_assignment, GetDefaultDeviceAssignment(build_options.num_replicas(), build_options.num_partitions())); build_options.set_device_assignment(device_assignment); } - VLOG(2) << "PjRtClient::Compile device_assignment:\n" + VLOG(2) << "PjRtStreamExecutorClient::Compile device_assignment:\n" << build_options.device_assignment().ToString(); num_replicas = build_options.device_assignment().replica_count(); num_partitions = build_options.device_assignment().computation_count(); @@ -2234,7 +2281,7 @@ StatusOr> PjRtClient::Compile( for (int replica = 0; replica < num_replicas; ++replica) { for (int partition = 0; partition < num_partitions; ++partition) { int device_id = (*device_assignment)(replica, partition); - PjRtDevice* device = LookupDevice(*this, device_id); + TF_ASSIGN_OR_RETURN(PjRtDevice * device, LookupDevice(device_id)); if (device->host_id() != host_id()) { VLOG(3) << "Non-local device: " << device_id; continue; @@ -2254,7 +2301,7 @@ StatusOr> PjRtClient::Compile( if (build_options.device_ordinal() < 0) { build_options.set_device_ordinal( - addressable_devices.front()->local_device_state()->device_ordinal()); + addressable_devices.front()->local_hardware_id()); } } diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h index b32d2889fc8..cc608c0d8ff 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h @@ -45,6 +45,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/fingerprint.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" @@ -67,16 +68,56 @@ class PjRtClient; class PjRtDevice { public: - explicit PjRtDevice(int id, - std::unique_ptr local_device_state, - std::string device_kind, int host_id = 0) + virtual ~PjRtDevice() {} + + // Return the client that owns this device. + virtual PjRtClient* client() const = 0; + + // Whether client can issue command to this device. + virtual bool IsAddressable() const = 0; + + // The ID of this device. IDs are unique among devices of this type + // (e.g. CPUs, GPUs). 
On multi-host platforms, this will be unique across all + // hosts' devices. This is the ID that should be used in a DeviceAssignment. + virtual int id() const = 0; + + // The task ID of this device according to TpuTopology. This is not the same + // as PjRtClient::host_id() in a multi-task setting, where each client can see + // devices from all tasks, but only a subset of them are addressable and have + // the same task_id as the client. + virtual int host_id() const = 0; + + // Opaque hardware ID, e.g., the CUDA device number, useful for identifying + // which GPU when interacting with non-JAX code. In general, not guaranteed to + // be dense, and -1 if undefined. + virtual int local_hardware_id() const = 0; + + // A vendor-dependent string that uniquely identifies the kind of device, + // e.g., "Tesla V100-SXM2-16GB". May be used to determine whether two GPUs are + // compatible compilation. + virtual const std::string& device_kind() const = 0; + + virtual std::string DebugString() const = 0; + + // Transfer the given literal to the infeed queue. + virtual Status TransferToInfeed(const LiteralSlice& literal) const = 0; + + // Transfer and return a value of the given shape from the outfeed queue. + virtual StatusOr TransferFromOutfeed(const Shape& shape) const = 0; +}; + +class PjRtStreamExecutorDevice : public PjRtDevice { + public: + explicit PjRtStreamExecutorDevice( + int id, std::unique_ptr local_device_state, + std::string device_kind, int host_id = 0) : id_(id), - local_device_id_( + device_ordinal_( local_device_state ? local_device_state->device_ordinal() : -1), local_device_state_(std::move(local_device_state)), host_id_(host_id), device_kind_(std::move(device_kind)) {} - virtual ~PjRtDevice() {} + ~PjRtStreamExecutorDevice() override {} // Must set client exactly once. void SetClient(PjRtClient* client) { @@ -84,14 +125,25 @@ class PjRtDevice { client_ = client; } + // Task ID. This is always 0 on single-task setup. + int host_id() const override { return host_id_; } + + // Return `platform_id` from client. + PjRtPlatformId platform_id() const; + + // Return `platform_name` from client. + const std::string& platform_name() const; + + PjRtClient* client() const override { return client_; } + // The ID of this device. IDs are unique among devices of this type // (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all // hosts' devices. This is the ID that should be used in a DeviceAssignment. - int id() const { return id_; } + int id() const override { return id_; } - bool IsLocalDevice() const { return local_device_id_ != -1; } + bool IsAddressable() const override { return device_ordinal_ != -1; } - int local_device_id() const { return local_device_id_; } + int local_hardware_id() const override { return device_ordinal_; } // If this is a device local to this host, returns a LocalDeviceState object // that can be used to manipulate the device. Returns nullptr if the device is @@ -105,32 +157,21 @@ class PjRtDevice { // is not local to this host. StatusOr GetLocalDeviceState() const; - // The ID of this device's host. This is always 0 on single-host platforms. - int host_id() const { return host_id_; } - - // Return `platform_id` from client. - PjRtPlatformId platform_id() const; - - // Return `platform_name` from client. - const std::string& platform_name() const; - // A vendor-dependent string that uniquely identifies the kind of device. 
- const std::string& device_kind() const { return device_kind_; } + const std::string& device_kind() const override { return device_kind_; } - virtual std::string DebugString() const; - - PjRtClient* client() const { return client_; } + std::string DebugString() const override; // Transfer the given literal to the infeed queue of the given localdevice. - virtual Status TransferToInfeed(const LiteralSlice& literal) const; + Status TransferToInfeed(const LiteralSlice& literal) const override; // Transfer and return a value of the given shape from the outfeed of the // given device. - virtual StatusOr TransferFromOutfeed(const Shape& shape) const; + StatusOr TransferFromOutfeed(const Shape& shape) const override; private: const int id_; - const int local_device_id_; // -1 means not local. + const int device_ordinal_; // -1 means not local. const std::unique_ptr local_device_state_; const int host_id_; const std::string device_kind_; @@ -178,86 +219,62 @@ class PjRtExecutable; // alive as long as any of the other runtime objects are alive. class PjRtClient { public: - // `allocator` may null, in which case the platform default allocator is used. - explicit PjRtClient( - std::string platform_name, LocalClient* client, - std::vector> devices, int host_id, - std::unique_ptr allocator, - std::unique_ptr host_memory_allocator, - bool should_stage_host_to_device_transfers, - std::unique_ptr gpu_run_options); virtual ~PjRtClient() = default; + // TODO(zhangqiaorjc): Rename to task_id. + // Return the task id of this client. In single-task setting, always 0. + virtual int host_id() const = 0; + + // Return the number of devices in the entire computation. In multi-headed + // client setting, some are addressable by this client, some are not. In a + // single-client setting, this is equal to the number of addressable devices. + virtual int device_count() const = 0; + + // Return number of addressable devices. Addressable devices are those that + // the client can issue commands to. + virtual int addressable_device_count() const = 0; + + // Return all devices in the entire computation, including addressable and + // non-addressable devices. + virtual absl::Span devices() const = 0; + + // TODO(zhangqiaorjc): Rename to addressable_devices. + // Return only addressable devices. + virtual absl::Span local_devices() const = 0; + + // Lookup any PjRtDevice for a given PjRtDevice::id(). + virtual StatusOr LookupDevice(int device_id) const = 0; + + // Return an addressable PjRtDevice for a given + // PjRtDevice::local_hardware_id(). + virtual StatusOr LookupAddressableDevice( + int local_hardware_id) const = 0; + + // Return an ID that identifies the platform (CPU/GPU/TPU). + virtual PjRtPlatformId platform_id() const = 0; + + // Returns a string that identifies the platform (CPU/GPU/TPU). + virtual const std::string& platform_name() const = 0; + + // Return a device-specific default device assignment, e.g., GPU and TPU may + // be different. 
virtual StatusOr GetDefaultDeviceAssignment( - int num_replicas, int num_partitions) const; - - int device_count() const { return devices_.size(); } - int local_device_count() const { return local_devices_.size(); } - const std::vector>& devices() const { - return devices_; - } - const std::vector& local_devices() const { - return local_devices_; - } - const std::map& id_to_device() const { - return id_to_device_; - } - int host_id() const { return host_id_; } - PjRtPlatformId platform_id() const { return platform_id_; } - const std::string& platform_name() const { return platform_name_; } - - LocalDeviceState& device_state(int device_ordinal) const { - return *local_devices_.at(device_ordinal)->local_device_state(); - } - - // Return a local PjRtDevice for a given `local_device_id`. - virtual StatusOr LookupLocalDevice(int local_device_id) const; - - LocalClient* client() const { return client_; } - se::DeviceMemoryAllocator* allocator() const { return allocator_; } - tensorflow::Allocator* host_memory_allocator() const { - return host_memory_allocator_.get(); - } - bool should_stage_host_to_device_transfers() const { - return should_stage_host_to_device_transfers_; - } - - gpu::GpuExecutableRunOptions* gpu_run_options() const { - return gpu_run_options_.get(); - } - - tensorflow::thread::ThreadPool* h2d_transfer_pool() { - return &h2d_transfer_pool_; - } - - // Most platforms expect device-to-device transfers to be enqueued on the - // source d2d stream, but some platforms use the destination d2d stream. This - // function specifies which one the platform expects. - virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; } - - // Generates a unique fingerprint for `executable`. - virtual StatusOr> ExecutableFingerprint( - const PjRtExecutable& executable) const { - return absl::optional(); - } + int num_replicas, int num_partitions) const = 0; // Returns a backend-specific HLO cost analysis visitor. - virtual std::unique_ptr GetHloCostAnalysis(); + virtual std::unique_ptr GetHloCostAnalysis() = 0; + // Compile `computation` with given `options`. virtual StatusOr> Compile( - const XlaComputation& computation, CompileOptions options); + const XlaComputation& computation, CompileOptions options) = 0; + + // Generates a unique fingerprint for `executable`, may be absl::nullopt. + virtual StatusOr> ExecutableFingerprint( + const PjRtExecutable& executable) const = 0; // Creates a buffer on the device without initializing or copying any data. - // An optional `definition_event` may be speficied that can be used to - // ensure the buffer isn't referenced until some external mechanism has - // initialized the data. - // NOTE: The sequencing mechanism is not guaranteed to be supported by all - // future backends and so callers should avoid wherever possible. virtual StatusOr> CreateUninitializedBuffer( - const Shape& shape, PjRtDevice* device); - virtual StatusOr> CreateUninitializedBuffer( - const Shape& shape, PjRtDevice* device, - std::shared_ptr definition_event); + const Shape& shape, PjRtDevice* device) = 0; // Describes the semantics the caller to BufferFromHostBuffer expects from the // runtime, in a total order from most restrictive to least restrictive. 
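A short consumer-side sketch against the virtual PjRtClient/PjRtDevice interface above, using only methods declared in this header; the logging itself is illustrative:

#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/core/platform/logging.h"

// List every device in the computation and mark the ones this client can
// actually issue commands to.
void DumpDeviceTopology(const xla::PjRtClient& client) {
  for (xla::PjRtDevice* device : client.devices()) {
    LOG(INFO) << device->DebugString() << " kind=" << device->device_kind()
              << " task=" << device->host_id()
              << (device->IsAddressable() ? " (addressable)" : " (remote)");
  }
}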
@@ -289,13 +306,13 @@ class PjRtClient { virtual StatusOr> BufferFromHostBuffer( const void* data, const Shape& shape, HostBufferSemantics host_buffer_semantics, - std::shared_ptr buffer_reference, PjRtDevice* device); + std::shared_ptr buffer_reference, PjRtDevice* device) = 0; // Note that literal must remain in scope until the transfer has completed, so // the caller should, for example, wait for BlockHostUntilReady() completes on // the return value before letting literal go out of scope. virtual StatusOr> BufferFromHostLiteral( - const LiteralSlice& literal, PjRtDevice* device); + const LiteralSlice& literal, PjRtDevice* device) = 0; // Asynchronously makes a vector of PjRtBuffers that can be used to receive // cross host transfers using `client` on `device'. `shapes` must be the exact @@ -308,18 +325,140 @@ class PjRtClient { // buffers will become ready until *all* of the sends have completed. virtual void MakeCrossHostReceiveBuffers( absl::Span shapes, PjRtDevice* device, - PjRtCrossHostRecvNotifier&& notifier); + PjRtCrossHostRecvNotifier&& notifier) = 0; - virtual StatusOr CreateChannelHandle() { + // Create ChannelHandles for XLA send/recv. + virtual StatusOr CreateChannelHandle() = 0; + virtual StatusOr CreateDeviceToHostChannelHandle() = 0; + virtual StatusOr CreateHostToDeviceChannelHandle() = 0; +}; + +class PjRtStreamExecutorClient : public PjRtClient { + public: + // `allocator` may null, in which case the platform default allocator is used. + explicit PjRtStreamExecutorClient( + std::string platform_name, LocalClient* client, + std::vector> devices, + int host_id, std::unique_ptr allocator, + std::unique_ptr host_memory_allocator, + bool should_stage_host_to_device_transfers, + std::unique_ptr gpu_run_options); + ~PjRtStreamExecutorClient() override = default; + + int host_id() const override { return host_id_; } + + int device_count() const override { return devices_.size(); } + int addressable_device_count() const override { + return local_devices_.size(); + } + absl::Span devices() const override { return devices_; } + absl::Span local_devices() const override { + return local_devices_; + } + + StatusOr LookupDevice(int device_id) const override { + auto it = id_to_device_.find(device_id); + if (it != id_to_device_.end()) { + return it->second; + } + return InvalidArgument("No matching device found for device_id %d", + device_id); + } + + StatusOr LookupAddressableDevice( + int local_hardware_id) const override; + + PjRtPlatformId platform_id() const override { return platform_id_; } + const std::string& platform_name() const override { return platform_name_; } + + // Most platforms expect device-to-device transfers to be enqueued on the + // source d2d stream, but some platforms use the destination d2d stream. This + // function specifies which one the platform expects. + virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; } + + StatusOr GetDefaultDeviceAssignment( + int num_replicas, int num_partitions) const override; + + StatusOr> Compile( + const XlaComputation& computation, CompileOptions options) override; + + // Generates a unique fingerprint for `executable`. + StatusOr> ExecutableFingerprint( + const PjRtExecutable& executable) const override { + return absl::optional(); + } + + // Returns a backend-specific HLO cost analysis visitor. + std::unique_ptr GetHloCostAnalysis() override; + + // Creates a buffer on the device without initializing or copying any data. 
+ // An optional `definition_event` may be speficied that can be used to + // ensure the buffer isn't referenced until some external mechanism has + // initialized the data. + // NOTE: The sequencing mechanism is not guaranteed to be supported by all + // future backends and so callers should avoid wherever possible. + StatusOr> CreateUninitializedBuffer( + const Shape& shape, PjRtDevice* device) override; + virtual StatusOr> CreateUninitializedBuffer( + const Shape& shape, PjRtDevice* device, + std::shared_ptr definition_event); + + StatusOr> BufferFromHostBuffer( + const void* data, const Shape& shape, + HostBufferSemantics host_buffer_semantics, + std::shared_ptr buffer_reference, PjRtDevice* device) override; + + // Note that literal must remain in scope until the transfer has completed, so + // the caller should, for example, wait for BlockHostUntilReady() completes on + // the return value before letting literal go out of scope. + StatusOr> BufferFromHostLiteral( + const LiteralSlice& literal, PjRtDevice* device) override; + + // Asynchronously makes a vector of PjRtBuffers that can be used to receive + // cross host transfers using `client` on `device'. `shapes` must be the exact + // shapes, with identical layouts, corresponding to the buffers that will be + // sent. When resources for the transfer are available, notifier will be + // called with a vector of PjRtCrossHostRecvBuffer structs, one for each + // shape in `shapes`. Each struct contains a buffer that will contain the + // received value, and an opaque string that should be transmitted to the + // sending host and used in a call to CopyToRemoteDevice. None of the recv + // buffers will become ready until *all* of the sends have completed. + void MakeCrossHostReceiveBuffers( + absl::Span shapes, PjRtDevice* device, + PjRtCrossHostRecvNotifier&& notifier) override; + + StatusOr CreateChannelHandle() override { return client()->CreateChannelHandle(); } - virtual StatusOr CreateDeviceToHostChannelHandle() { + StatusOr CreateDeviceToHostChannelHandle() override { return client()->CreateDeviceToHostChannelHandle(); } - virtual StatusOr CreateHostToDeviceChannelHandle() { + StatusOr CreateHostToDeviceChannelHandle() override { return client()->CreateHostToDeviceChannelHandle(); } + LocalDeviceState& device_state(int device_ordinal) const { + return *tensorflow::down_cast( + local_devices_.at(device_ordinal)) + ->local_device_state(); + } + LocalClient* client() const { return client_; } + se::DeviceMemoryAllocator* allocator() const { return allocator_; } + tensorflow::Allocator* host_memory_allocator() const { + return host_memory_allocator_.get(); + } + bool should_stage_host_to_device_transfers() const { + return should_stage_host_to_device_transfers_; + } + + gpu::GpuExecutableRunOptions* gpu_run_options() const { + return gpu_run_options_.get(); + } + + tensorflow::thread::ThreadPool* h2d_transfer_pool() { + return &h2d_transfer_pool_; + } + protected: friend class PjRtBuffer; virtual void EnqueueCrossHostReceive( @@ -342,7 +481,9 @@ class PjRtClient { std::unique_ptr host_memory_allocator_; // Includes all devices, including non-local devices on multi-host platforms. - std::vector> devices_; + std::vector> owned_devices_; + // Pointers to `owned_devices_`. + std::vector devices_; // Maps Device::id() to the corresponding Device. Includes all devices. std::map id_to_device_; // Local devices indexed by local device ordinal. 
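A hedged sketch of a host-to-device transfer through this interface. BufferFromHostBuffer, HostBufferSemantics::kZeroCopy and BlockHostUntilReady all appear in this change; the f32 payload, shape construction and error handling are illustrative only:

#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"

// Copy a small f32 vector onto `device`. kZeroCopy means the runtime may keep
// referring to `data`, so the vector must stay alive while the buffer is used.
xla::StatusOr<std::unique_ptr<xla::PjRtBuffer>> PutVectorOnDevice(
    xla::PjRtClient* client, xla::PjRtDevice* device,
    const std::vector<float>& data) {
  xla::Shape shape = xla::ShapeUtil::MakeShape(
      xla::F32, {static_cast<xla::int64>(data.size())});
  auto buffer_or = client->BufferFromHostBuffer(
      data.data(), shape, xla::PjRtClient::HostBufferSemantics::kZeroCopy,
      /*buffer_reference=*/nullptr, device);
  if (!buffer_or.ok()) return buffer_or.status();
  std::unique_ptr<xla::PjRtBuffer> buffer = buffer_or.ConsumeValueOrDie();
  xla::Status ready = buffer->BlockHostUntilReady();
  if (!ready.ok()) return ready;
  return buffer;
}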
@@ -509,7 +650,7 @@ class PjRtBuffer { private: friend class PjRtBuffer; - friend class PjRtClient; + friend class PjRtStreamExecutorClient; // Helper struct that makes it possible to move a ScopedHold through a // closure. @@ -769,7 +910,7 @@ class PjRtExecutable { virtual PjRtClient* client() const = 0; // Unique name for this executable, e.g., HloModule name. - virtual const string& name() const = 0; + virtual const std::string& name() const = 0; virtual int num_replicas() const = 0; @@ -791,6 +932,7 @@ class PjRtExecutable { virtual absl::Span addressable_device_logical_ids() const = 0; + // An addressable_device is one which the client can issue commands to. // addressable_devices()[i] is the Device to which // addressable_device_logical_ids()[i] is assigned. virtual absl::Span addressable_devices() const = 0; @@ -833,13 +975,14 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable { bool parameter_is_tupled_arguments, std::shared_ptr device_assignment, std::vector addressable_device_logical_ids, - std::vector addressable_devices, PjRtClient* client); + std::vector addressable_devices, + PjRtStreamExecutorClient* client); ~PjRtStreamExecutorExecutable() override = default; - PjRtClient* client() const override { return client_; } + PjRtStreamExecutorClient* client() const override { return client_; } - const string& name() const override; + const std::string& name() const override; int num_replicas() const override { return executables_[0]->build_options().num_replicas(); @@ -898,7 +1041,7 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable { } private: - friend class PjRtClient; + friend class PjRtStreamExecutorClient; // Initializes information about which arguments to which executables must be // donated due to aliases that were specified by the computation. Status SetUpDonation(bool tuple_inputs); @@ -933,7 +1076,7 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable { // Create shared pointers so we can free them after the execution: with // asynchronous execution, the process being executed can outlive the // executable itself. - PjRtClient* const client_; + PjRtStreamExecutorClient* const client_; // One executable per partition. 
std::vector> executables_; // Per-executable set of parameters that have any aliased buffers and thus diff --git a/tensorflow/compiler/xla/pjrt/tpu_client.cc b/tensorflow/compiler/xla/pjrt/tpu_client.cc index 9d6fa929a19..8222874a229 100644 --- a/tensorflow/compiler/xla/pjrt/tpu_client.cc +++ b/tensorflow/compiler/xla/pjrt/tpu_client.cc @@ -94,10 +94,11 @@ Status TpuDeviceState::ThenMemcpyDeviceToDevice( return Status::OK(); } -class PjRtTpuClient : public PjRtClient { +class PjRtTpuClient : public PjRtStreamExecutorClient { public: PjRtTpuClient(LocalClient* client, - std::vector> devices, int host_id); + std::vector> devices, + int host_id); StatusOr GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const override; @@ -108,14 +109,14 @@ class PjRtTpuClient : public PjRtClient { const PjRtExecutable& executable) const override; }; -PjRtTpuClient::PjRtTpuClient(LocalClient* client, - std::vector> devices, - int host_id) - : PjRtClient(kTpuName, client, std::move(devices), host_id, - /*allocator=*/nullptr, - /*host_memory_allocator=*/nullptr, - /*should_stage_host_to_device_transfers=*/false, - /*gpu_run_options=*/nullptr) {} +PjRtTpuClient::PjRtTpuClient( + LocalClient* client, + std::vector> devices, int host_id) + : PjRtStreamExecutorClient(kTpuName, client, std::move(devices), host_id, + /*allocator=*/nullptr, + /*host_memory_allocator=*/nullptr, + /*should_stage_host_to_device_transfers=*/false, + /*gpu_run_options=*/nullptr) {} StatusOr PjRtTpuClient::GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const { @@ -128,7 +129,8 @@ StatusOr PjRtTpuClient::GetDefaultDeviceAssignment( num_partitions); } // Fallback to default global device assignment if we can't run locally. - return PjRtClient::GetDefaultDeviceAssignment(num_replicas, num_partitions); + return PjRtStreamExecutorClient::GetDefaultDeviceAssignment(num_replicas, + num_partitions); } StatusOr> PjRtTpuClient::ExecutableFingerprint( @@ -152,10 +154,10 @@ StatusOr> PjRtTpuClient::ExecutableFingerprint( return absl::optional(tpu_executable->fingerprint()); } -StatusOr>> GetTpuDevices( +StatusOr>> GetTpuDevices( LocalClient* client, std::vector> local_device_states) { - std::vector> devices; + std::vector> devices; tf_tpu::TpuTopologyExternal topology = tf_tpu::TpuPlatformInterface::GetRegisteredPlatform()->topology(); diff --git a/tensorflow/compiler/xla/pjrt/tpu_client.h b/tensorflow/compiler/xla/pjrt/tpu_client.h index cdc68bc9606..f17d82a270e 100644 --- a/tensorflow/compiler/xla/pjrt/tpu_client.h +++ b/tensorflow/compiler/xla/pjrt/tpu_client.h @@ -26,14 +26,14 @@ limitations under the License. 
namespace xla { -class PjRtTpuDevice : public PjRtDevice { +class PjRtTpuDevice : public PjRtStreamExecutorDevice { public: PjRtTpuDevice(const tensorflow::tpu::TpuCoreLocationExternal core, std::unique_ptr local_device_state, int host_id, const std::array& coords, std::string device_kind) - : PjRtDevice(core.Id(), std::move(local_device_state), - std::move(device_kind), host_id), + : PjRtStreamExecutorDevice(core.Id(), std::move(local_device_state), + std::move(device_kind), host_id), core_(core), coords_(coords) {} diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 050f3009bfc..e8a61c0e916 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -97,13 +97,13 @@ cc_library( name = "types", srcs = ["types.cc"], hdrs = ["types.h"], + compatible_with = [], copts = [ "-fexceptions", "-fno-strict-aliasing", ], features = ["-use_header_modules"], deps = [ - ":bfloat16", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", @@ -113,6 +113,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/pjrt:pjrt_client", "//tensorflow/core:lib", + "//tensorflow/python:bfloat16_lib", "//third_party/py/numpy:headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", @@ -158,42 +159,6 @@ cc_library( ], ) -cc_library( - name = "bfloat16", - srcs = ["bfloat16.cc"], - hdrs = ["bfloat16.h"], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - deps = [ - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/core/platform:bfloat16", - "//tensorflow/core/platform:logging", - "//third_party/py/numpy:headers", - "//third_party/python_runtime:headers", # buildcleaner: keep - "@com_google_absl//absl/strings", - "@pybind11", - ], -) - -py_test( - name = "bfloat16_test", - srcs = ["bfloat16_test.py"], - main = "bfloat16_test.py", - python_version = "PY3", - tags = ["no_oss"], - deps = [ - ":xla_client", - ":xla_extension", - "@absl_py//absl/testing:absltest", - "@absl_py//absl/testing:parameterized", - ] + xla_py_test_deps(), -) - cc_library( name = "py_client", srcs = [ @@ -206,6 +171,7 @@ cc_library( "py_client.h", "py_executable.h", ], + compatible_with = [], copts = [ "-fexceptions", "-fno-strict-aliasing", @@ -232,6 +198,7 @@ cc_library( name = "dlpack", srcs = ["dlpack.cc"], hdrs = ["dlpack.h"], + compatible_with = [], copts = [ "-fexceptions", "-fno-strict-aliasing", @@ -263,6 +230,7 @@ cc_library( name = "jax_jit", srcs = ["jax_jit.cc"], hdrs = ["jax_jit.h"], + compatible_with = [], copts = [ "-fexceptions", "-fno-strict-aliasing", @@ -292,6 +260,7 @@ cc_library( name = "ops", srcs = ["ops.cc"], hdrs = ["ops.h"], + compatible_with = [], copts = [ "-fexceptions", "-fno-strict-aliasing", @@ -356,6 +325,7 @@ cc_library( name = "outfeed_receiver_py", srcs = ["outfeed_receiver_py.cc"], hdrs = ["outfeed_receiver_py.h"], + compatible_with = [], copts = [ "-fexceptions", "-fno-strict-aliasing", @@ -379,12 +349,14 @@ cc_library( name = "pytree", srcs = ["pytree.cc"], hdrs = ["pytree.h"], + compatible_with = [], copts = [ "-fexceptions", "-fno-strict-aliasing", ], features = ["-use_header_modules"], deps = [ + ":types", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/hash", @@ -435,6 +407,7 @@ cc_library( 
name = "xla_compiler", srcs = ["xla_compiler.cc"], hdrs = ["xla_compiler.h"], + compatible_with = [], copts = [ "-fexceptions", "-fno-strict-aliasing", @@ -481,7 +454,6 @@ pybind_extension( features = ["-use_header_modules"], module_name = "xla_extension", deps = [ - ":bfloat16", ":dlpack", ":jax_jit", ":ops", @@ -534,6 +506,7 @@ pybind_extension( # without any TF dependencies as "jaxlib" on Pypi, and "jaxlib" does # not require Tensorflow. "//tensorflow/core:lib_internal_impl", # buildcleaner: keep + "//tensorflow/python:bfloat16_lib", "//tensorflow/stream_executor:device_memory_allocator", "//tensorflow/stream_executor:platform", ] + select({ diff --git a/tensorflow/compiler/xla/python/bfloat16.cc b/tensorflow/compiler/xla/python/bfloat16.cc deleted file mode 100644 index 5f96c494c25..00000000000 --- a/tensorflow/compiler/xla/python/bfloat16.cc +++ /dev/null @@ -1,1576 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/python/bfloat16.h" - -#include -#include -// Place `` before to avoid a build failure in macOS. -#include - -#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION - -#include "numpy/arrayobject.h" -#include "numpy/ufuncobject.h" -#include "absl/strings/str_cat.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/platform/bfloat16.h" -#include "tensorflow/core/platform/logging.h" - -namespace xla { -namespace { - -namespace py = pybind11; - -struct PyDecrefDeleter { - void operator()(PyObject* p) const { Py_DECREF(p); } -}; - -// Safe container for an owned PyObject. On destruction, the reference count of -// the contained object will be decremented. -using Safe_PyObjectPtr = std::unique_ptr; -Safe_PyObjectPtr make_safe(PyObject* object) { - return Safe_PyObjectPtr(object); -} - -bool PyLong_CheckNoOverflow(PyObject* object) { - if (!PyLong_Check(object)) { - return false; - } - int overflow = 0; - PyLong_AsLongAndOverflow(object, &overflow); - return (overflow == 0); -} - -// Registered numpy type ID. Global variable populated by the registration code. -// Protected by the GIL. -int npy_bfloat16 = -1; - -// Forward declaration. -extern PyTypeObject PyBfloat16_Type; - -// Representation of a Python bfloat16 object. -struct PyBfloat16 { - PyObject_HEAD; // Python object header - bfloat16 value; -}; - -// Returns true if 'object' is a PyBfloat16. -bool PyBfloat16_Check(PyObject* object) { - return PyObject_IsInstance(object, - reinterpret_cast(&PyBfloat16_Type)); -} - -// Extracts the value of a PyBfloat16 object. -bfloat16 PyBfloat16_Bfloat16(PyObject* object) { - return reinterpret_cast(object)->value; -} - -// Constructs a PyBfloat16 object from a bfloat16. 
-Safe_PyObjectPtr PyBfloat16_FromBfloat16(bfloat16 x) { - Safe_PyObjectPtr ref = - make_safe(PyBfloat16_Type.tp_alloc(&PyBfloat16_Type, 0)); - PyBfloat16* p = reinterpret_cast(ref.get()); - if (p) { - p->value = x; - } - return ref; -} - -// Converts a Python object to a bfloat16 value. Returns true on success, -// returns false and reports a Python error on failure. -bool CastToBfloat16(PyObject* arg, bfloat16* output) { - if (PyBfloat16_Check(arg)) { - *output = PyBfloat16_Bfloat16(arg); - return true; - } - if (PyFloat_Check(arg)) { - double d = PyFloat_AsDouble(arg); - if (PyErr_Occurred()) { - return false; - } - // TODO(phawkins): check for overflow - *output = bfloat16(d); - return true; - } - if (PyLong_CheckNoOverflow(arg)) { - long l = PyLong_AsLong(arg); // NOLINT - if (PyErr_Occurred()) { - return false; - } - // TODO(phawkins): check for overflow - *output = bfloat16(static_cast(l)); - return true; - } - if (PyArray_IsScalar(arg, Half)) { - Eigen::half f; - PyArray_ScalarAsCtype(arg, &f); - *output = bfloat16(f); - return true; - } - if (PyArray_IsScalar(arg, Float)) { - float f; - PyArray_ScalarAsCtype(arg, &f); - *output = bfloat16(f); - return true; - } - if (PyArray_IsScalar(arg, Double)) { - double f; - PyArray_ScalarAsCtype(arg, &f); - *output = bfloat16(f); - return true; - } - if (PyArray_IsZeroDim(arg)) { - Safe_PyObjectPtr ref; - PyArrayObject* arr = reinterpret_cast(arg); - if (PyArray_TYPE(arr) != npy_bfloat16) { - ref = make_safe(PyArray_Cast(arr, npy_bfloat16)); - if (PyErr_Occurred()) { - return false; - } - arg = ref.get(); - arr = reinterpret_cast(arg); - } - *output = *reinterpret_cast(PyArray_DATA(arr)); - return true; - } - return false; -} - -bool SafeCastToBfloat16(PyObject* arg, bfloat16* output) { - if (PyBfloat16_Check(arg)) { - *output = PyBfloat16_Bfloat16(arg); - return true; - } - return false; -} - -// Converts a PyBfloat16 into a PyFloat. -PyObject* PyBfloat16_Float(PyObject* self) { - bfloat16 x = PyBfloat16_Bfloat16(self); - return PyFloat_FromDouble(static_cast(x)); -} - -// Converts a PyBfloat16 into a PyInt. -PyObject* PyBfloat16_Int(PyObject* self) { - bfloat16 x = PyBfloat16_Bfloat16(self); - long y = static_cast(x); // NOLINT - return PyLong_FromLong(y); -} - -// Negates a PyBfloat16. -PyObject* PyBfloat16_Negative(PyObject* self) { - bfloat16 x = PyBfloat16_Bfloat16(self); - return PyBfloat16_FromBfloat16(-x).release(); -} - -PyObject* PyBfloat16_Add(PyObject* a, PyObject* b) { - bfloat16 x, y; - if (SafeCastToBfloat16(a, &x) && SafeCastToBfloat16(b, &y)) { - return PyBfloat16_FromBfloat16(x + y).release(); - } - return PyArray_Type.tp_as_number->nb_add(a, b); -} - -PyObject* PyBfloat16_Subtract(PyObject* a, PyObject* b) { - bfloat16 x, y; - if (SafeCastToBfloat16(a, &x) && SafeCastToBfloat16(b, &y)) { - return PyBfloat16_FromBfloat16(x - y).release(); - } - return PyArray_Type.tp_as_number->nb_subtract(a, b); -} - -PyObject* PyBfloat16_Multiply(PyObject* a, PyObject* b) { - bfloat16 x, y; - if (SafeCastToBfloat16(a, &x) && SafeCastToBfloat16(b, &y)) { - return PyBfloat16_FromBfloat16(x * y).release(); - } - return PyArray_Type.tp_as_number->nb_multiply(a, b); -} - -PyObject* PyBfloat16_TrueDivide(PyObject* a, PyObject* b) { - bfloat16 x, y; - if (SafeCastToBfloat16(a, &x) && SafeCastToBfloat16(b, &y)) { - return PyBfloat16_FromBfloat16(x / y).release(); - } - return PyArray_Type.tp_as_number->nb_true_divide(a, b); -} - -// Python number methods for PyBfloat16 objects. 
-PyNumberMethods PyBfloat16_AsNumber = { - PyBfloat16_Add, // nb_add - PyBfloat16_Subtract, // nb_subtract - PyBfloat16_Multiply, // nb_multiply - nullptr, // nb_remainder - nullptr, // nb_divmod - nullptr, // nb_power - PyBfloat16_Negative, // nb_negative - nullptr, // nb_positive - nullptr, // nb_absolute - nullptr, // nb_nonzero - nullptr, // nb_invert - nullptr, // nb_lshift - nullptr, // nb_rshift - nullptr, // nb_and - nullptr, // nb_xor - nullptr, // nb_or - PyBfloat16_Int, // nb_int - nullptr, // reserved - PyBfloat16_Float, // nb_float - - nullptr, // nb_inplace_add - nullptr, // nb_inplace_subtract - nullptr, // nb_inplace_multiply - nullptr, // nb_inplace_remainder - nullptr, // nb_inplace_power - nullptr, // nb_inplace_lshift - nullptr, // nb_inplace_rshift - nullptr, // nb_inplace_and - nullptr, // nb_inplace_xor - nullptr, // nb_inplace_or - - nullptr, // nb_floor_divide - PyBfloat16_TrueDivide, // nb_true_divide - nullptr, // nb_inplace_floor_divide - nullptr, // nb_inplace_true_divide - nullptr, // nb_index -}; - -// Constructs a new PyBfloat16. -PyObject* PyBfloat16_New(PyTypeObject* type, PyObject* args, PyObject* kwds) { - if (kwds && PyDict_Size(kwds)) { - PyErr_SetString(PyExc_TypeError, "constructor takes no keyword arguments"); - return nullptr; - } - Py_ssize_t size = PyTuple_Size(args); - if (size != 1) { - PyErr_SetString(PyExc_TypeError, - "expected number as argument to bfloat16 constructor"); - return nullptr; - } - PyObject* arg = PyTuple_GetItem(args, 0); - - bfloat16 value; - if (PyBfloat16_Check(arg)) { - Py_INCREF(arg); - return arg; - } else if (CastToBfloat16(arg, &value)) { - return PyBfloat16_FromBfloat16(value).release(); - } else if (PyArray_Check(arg)) { - PyArrayObject* arr = reinterpret_cast(arg); - if (PyArray_TYPE(arr) != npy_bfloat16) { - return PyArray_Cast(arr, npy_bfloat16); - } else { - Py_INCREF(arg); - return arg; - } - } - PyErr_Format(PyExc_TypeError, "expected number, got %s", - arg->ob_type->tp_name); - return nullptr; -} - -// Comparisons on PyBfloat16s. -PyObject* PyBfloat16_RichCompare(PyObject* a, PyObject* b, int op) { - bfloat16 x, y; - if (!SafeCastToBfloat16(a, &x) || !SafeCastToBfloat16(b, &y)) { - return PyGenericArrType_Type.tp_richcompare(a, b, op); - } - bool result; - switch (op) { - case Py_LT: - result = x < y; - break; - case Py_LE: - result = x <= y; - break; - case Py_EQ: - result = x == y; - break; - case Py_NE: - result = x != y; - break; - case Py_GT: - result = x > y; - break; - case Py_GE: - result = x >= y; - break; - default: - LOG(FATAL) << "Invalid op type " << op; - } - return PyBool_FromLong(result); -} - -// Implementation of repr() for PyBfloat16. -PyObject* PyBfloat16_Repr(PyObject* self) { - bfloat16 x = reinterpret_cast(self)->value; - std::string v = absl::StrCat(static_cast(x)); - return PyUnicode_FromString(v.c_str()); -} - -// Implementation of str() for PyBfloat16. -PyObject* PyBfloat16_Str(PyObject* self) { - bfloat16 x = reinterpret_cast(self)->value; - std::string v = absl::StrCat(static_cast(x)); - return PyUnicode_FromString(v.c_str()); -} - -// Hash function for PyBfloat16. We use the identity function, which is a weak -// hash function. -Py_hash_t PyBfloat16_Hash(PyObject* self) { - bfloat16 x = reinterpret_cast(self)->value; - return x.value; -} - -// Python type for PyBfloat16 objects. 
-PyTypeObject PyBfloat16_Type = { - PyVarObject_HEAD_INIT(nullptr, 0) "bfloat16", // tp_name - sizeof(PyBfloat16), // tp_basicsize - 0, // tp_itemsize - nullptr, // tp_dealloc -#if PY_VERSION_HEX < 0x03080000 - nullptr, // tp_print -#else - 0, // tp_vectorcall_offset -#endif - nullptr, // tp_getattr - nullptr, // tp_setattr - nullptr, // tp_compare / tp_reserved - PyBfloat16_Repr, // tp_repr - &PyBfloat16_AsNumber, // tp_as_number - nullptr, // tp_as_sequence - nullptr, // tp_as_mapping - PyBfloat16_Hash, // tp_hash - nullptr, // tp_call - PyBfloat16_Str, // tp_str - nullptr, // tp_getattro - nullptr, // tp_setattro - nullptr, // tp_as_buffer - // tp_flags - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, - "bfloat16 floating-point values", // tp_doc - nullptr, // tp_traverse - nullptr, // tp_clear - PyBfloat16_RichCompare, // tp_richcompare - 0, // tp_weaklistoffset - nullptr, // tp_iter - nullptr, // tp_iternext - nullptr, // tp_methods - nullptr, // tp_members - nullptr, // tp_getset - nullptr, // tp_base - nullptr, // tp_dict - nullptr, // tp_descr_get - nullptr, // tp_descr_set - 0, // tp_dictoffset - nullptr, // tp_init - nullptr, // tp_alloc - PyBfloat16_New, // tp_new - nullptr, // tp_free - nullptr, // tp_is_gc - nullptr, // tp_bases - nullptr, // tp_mro - nullptr, // tp_cache - nullptr, // tp_subclasses - nullptr, // tp_weaklist - nullptr, // tp_del - 0, // tp_version_tag -}; - -// Numpy support - -PyArray_ArrFuncs NPyBfloat16_ArrFuncs; - -PyArray_Descr NPyBfloat16_Descr = { - PyObject_HEAD_INIT(nullptr) // - /*typeobj=*/ - (&PyBfloat16_Type), - // We must register bfloat16 with a kind other than "f", because numpy - // considers two types with the same kind and size to be equal, but - // float16 != bfloat16. - // The downside of this is that NumPy scalar promotion does not work with - // bfloat16 values. - /*kind=*/'V', - // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type - // character is unique. - /*type=*/'E', - /*byteorder=*/'=', - /*flags=*/NPY_NEEDS_PYAPI | NPY_USE_GETITEM | NPY_USE_SETITEM, - /*type_num=*/0, - /*elsize=*/sizeof(bfloat16), - /*alignment=*/alignof(bfloat16), - /*subarray=*/nullptr, - /*fields=*/nullptr, - /*names=*/nullptr, - /*f=*/&NPyBfloat16_ArrFuncs, - /*metadata=*/nullptr, - /*c_metadata=*/nullptr, - /*hash=*/-1, // -1 means "not computed yet". -}; - -// Implementations of NumPy array methods. - -PyObject* NPyBfloat16_GetItem(void* data, void* arr) { - bfloat16 x; - memcpy(&x, data, sizeof(bfloat16)); - return PyBfloat16_FromBfloat16(x).release(); -} - -int NPyBfloat16_SetItem(PyObject* item, void* data, void* arr) { - bfloat16 x; - if (!CastToBfloat16(item, &x)) { - PyErr_Format(PyExc_TypeError, "expected number, got %s", - item->ob_type->tp_name); - return -1; - } - memcpy(data, &x, sizeof(bfloat16)); - return 0; -} - -void ByteSwap16(void* value) { - char* p = reinterpret_cast(value); - std::swap(p[0], p[1]); -} - -int NPyBfloat16_Compare(const void* a, const void* b, void* arr) { - bfloat16 x; - memcpy(&x, a, sizeof(bfloat16)); - - bfloat16 y; - memcpy(&y, b, sizeof(bfloat16)); - - if (x < y) { - return -1; - } - if (y < x) { - return 1; - } - // NaNs sort to the end. 
- if (!Eigen::numext::isnan(x) && Eigen::numext::isnan(y)) { - return -1; - } - if (Eigen::numext::isnan(x) && !Eigen::numext::isnan(y)) { - return 1; - } - return 0; -} - -void NPyBfloat16_CopySwapN(void* dstv, npy_intp dstride, void* srcv, - npy_intp sstride, npy_intp n, int swap, void* arr) { - char* dst = reinterpret_cast(dstv); - char* src = reinterpret_cast(srcv); - if (!src) { - return; - } - if (swap) { - for (npy_intp i = 0; i < n; i++) { - char* r = dst + dstride * i; - memcpy(r, src + sstride * i, sizeof(uint16_t)); - ByteSwap16(r); - } - } else if (dstride == sizeof(uint16_t) && sstride == sizeof(uint16_t)) { - memcpy(dst, src, n * sizeof(uint16_t)); - } else { - for (npy_intp i = 0; i < n; i++) { - memcpy(dst + dstride * i, src + sstride * i, sizeof(uint16_t)); - } - } -} - -void NPyBfloat16_CopySwap(void* dst, void* src, int swap, void* arr) { - if (!src) { - return; - } - memcpy(dst, src, sizeof(uint16_t)); - if (swap) { - ByteSwap16(dst); - } -} - -npy_bool NPyBfloat16_NonZero(void* data, void* arr) { - bfloat16 x; - memcpy(&x, data, sizeof(x)); - return x != static_cast(0); -} - -int NPyBfloat16_Fill(void* buffer_raw, npy_intp length, void* ignored) { - bfloat16* const buffer = reinterpret_cast(buffer_raw); - const float start(buffer[0]); - const float delta = static_cast(buffer[1]) - start; - for (npy_intp i = 2; i < length; ++i) { - buffer[i] = static_cast(start + i * delta); - } - return 0; -} - -void NPyBfloat16_DotFunc(void* ip1, npy_intp is1, void* ip2, npy_intp is2, - void* op, npy_intp n, void* arr) { - char* c1 = reinterpret_cast(ip1); - char* c2 = reinterpret_cast(ip2); - float acc = 0.0f; - for (npy_intp i = 0; i < n; ++i) { - bfloat16* const b1 = reinterpret_cast(c1); - bfloat16* const b2 = reinterpret_cast(c2); - acc += static_cast(*b1) * static_cast(*b2); - c1 += is1; - c2 += is2; - } - bfloat16* out = reinterpret_cast(op); - *out = static_cast(acc); -} - -int NPyBfloat16_CompareFunc(const void* v1, const void* v2, void* arr) { - bfloat16 b1 = *reinterpret_cast(v1); - bfloat16 b2 = *reinterpret_cast(v2); - if (b1 < b2) { - return -1; - } - if (b1 > b2) { - return 1; - } - return 0; -} - -int NPyBfloat16_ArgMaxFunc(void* data, npy_intp n, npy_intp* max_ind, - void* arr) { - const bfloat16* bdata = reinterpret_cast(data); - float max_val = -std::numeric_limits::infinity(); - for (npy_intp i = 0; i < n; ++i) { - if (static_cast(bdata[i]) > max_val) { - max_val = static_cast(bdata[i]); - *max_ind = i; - } - } - return 0; -} - -int NPyBfloat16_ArgMinFunc(void* data, npy_intp n, npy_intp* min_ind, - void* arr) { - const bfloat16* bdata = reinterpret_cast(data); - float min_val = std::numeric_limits::infinity(); - for (npy_intp i = 0; i < n; ++i) { - if (static_cast(bdata[i]) < min_val) { - min_val = static_cast(bdata[i]); - *min_ind = i; - } - } - return 0; -} - -// NumPy casts - -template -struct TypeDescriptor { - // typedef ... T; // Representation type in memory for NumPy values of type - // static int Dtype() { return NPY_...; } // Numpy type number for T. 
-}; - -template <> -struct TypeDescriptor { - typedef bfloat16 T; - static int Dtype() { return npy_bfloat16; } -}; - -template <> -struct TypeDescriptor { - typedef uint8 T; - static int Dtype() { return NPY_UINT8; } -}; - -template <> -struct TypeDescriptor { - typedef uint16 T; - static int Dtype() { return NPY_UINT16; } -}; - -// We register "int", "long", and "long long" types for portability across -// Linux, where "int" and "long" are the same type, and Windows, where "long" -// and "longlong" are the same type. -template <> -struct TypeDescriptor { - typedef unsigned int T; - static int Dtype() { return NPY_UINT; } -}; - -template <> -struct TypeDescriptor { // NOLINT - typedef unsigned long T; // NOLINT - static int Dtype() { return NPY_ULONG; } -}; - -template <> -struct TypeDescriptor { // NOLINT - typedef unsigned long long T; // NOLINT - static int Dtype() { return NPY_ULONGLONG; } -}; - -template <> -struct TypeDescriptor { - typedef int8 T; - static int Dtype() { return NPY_INT8; } -}; - -template <> -struct TypeDescriptor { - typedef int16 T; - static int Dtype() { return NPY_INT16; } -}; - -template <> -struct TypeDescriptor { - typedef int T; - static int Dtype() { return NPY_INT; } -}; - -template <> -struct TypeDescriptor { // NOLINT - typedef long T; // NOLINT - static int Dtype() { return NPY_LONG; } -}; - -template <> -struct TypeDescriptor { // NOLINT - typedef long long T; // NOLINT - static int Dtype() { return NPY_LONGLONG; } -}; - -template <> -struct TypeDescriptor { - typedef int8 T; - static int Dtype() { return NPY_BOOL; } -}; - -template <> -struct TypeDescriptor { - typedef Eigen::half T; - static int Dtype() { return NPY_HALF; } -}; - -template <> -struct TypeDescriptor { - typedef float T; - static int Dtype() { return NPY_FLOAT; } -}; - -template <> -struct TypeDescriptor { - typedef double T; - static int Dtype() { return NPY_DOUBLE; } -}; - -template <> -struct TypeDescriptor { - typedef complex64 T; - static int Dtype() { return NPY_COMPLEX64; } -}; - -template <> -struct TypeDescriptor { - typedef complex128 T; - static int Dtype() { return NPY_COMPLEX128; } -}; - -// Performs a NumPy array cast from type 'From' to 'To'. -template -void NPyCast(void* from_void, void* to_void, npy_intp n, void* fromarr, - void* toarr) { - const auto* from = - reinterpret_cast::T*>(from_void); - auto* to = reinterpret_cast::T*>(to_void); - for (npy_intp i = 0; i < n; ++i) { - to[i] = - static_cast::T>(static_cast(from[i])); - } -} - -// Registers a cast between bfloat16 and type 'T'. 'numpy_type' is the NumPy -// type corresponding to 'T'. If 'cast_is_safe', registers that bfloat16 can be -// safely coerced to T. 
-template -bool RegisterBfloat16Cast(int numpy_type, bool cast_is_safe) { - if (PyArray_RegisterCastFunc(PyArray_DescrFromType(numpy_type), npy_bfloat16, - NPyCast) < 0) { - return false; - } - if (PyArray_RegisterCastFunc(&NPyBfloat16_Descr, numpy_type, - NPyCast) < 0) { - return false; - } - if (cast_is_safe && PyArray_RegisterCanCast(&NPyBfloat16_Descr, numpy_type, - NPY_NOSCALAR) < 0) { - return false; - } - return true; -} - -template -struct UnaryUFunc { - static std::vector Types() { - return {TypeDescriptor::Dtype(), TypeDescriptor::Dtype()}; - } - static void Call(char** args, const npy_intp* dimensions, - const npy_intp* steps, void* data) { - const char* i0 = args[0]; - char* o = args[1]; - for (npy_intp k = 0; k < *dimensions; k++) { - auto x = *reinterpret_cast::T*>(i0); - *reinterpret_cast::T*>(o) = Functor()(x); - i0 += steps[0]; - o += steps[1]; - } - } -}; - -template -struct UnaryUFunc2 { - static std::vector Types() { - return {TypeDescriptor::Dtype(), TypeDescriptor::Dtype(), - TypeDescriptor::Dtype()}; - } - static void Call(char** args, const npy_intp* dimensions, - const npy_intp* steps, void* data) { - const char* i0 = args[0]; - char* o0 = args[1]; - char* o1 = args[2]; - for (npy_intp k = 0; k < *dimensions; k++) { - auto x = *reinterpret_cast::T*>(i0); - std::tie(*reinterpret_cast::T*>(o0), - *reinterpret_cast::T*>(o1)) = - Functor()(x); - i0 += steps[0]; - o0 += steps[1]; - o1 += steps[2]; - } - } -}; - -template -struct BinaryUFunc { - static std::vector Types() { - return {TypeDescriptor::Dtype(), TypeDescriptor::Dtype(), - TypeDescriptor::Dtype()}; - } - static void Call(char** args, const npy_intp* dimensions, - const npy_intp* steps, void* data) { - const char* i0 = args[0]; - const char* i1 = args[1]; - char* o = args[2]; - for (npy_intp k = 0; k < *dimensions; k++) { - auto x = *reinterpret_cast::T*>(i0); - auto y = *reinterpret_cast::T*>(i1); - *reinterpret_cast::T*>(o) = - Functor()(x, y); - i0 += steps[0]; - i1 += steps[1]; - o += steps[2]; - } - } -}; - -template -struct BinaryUFunc2 { - static std::vector Types() { - return {TypeDescriptor::Dtype(), TypeDescriptor::Dtype(), - TypeDescriptor::Dtype()}; - } - static void Call(char** args, const npy_intp* dimensions, - const npy_intp* steps, void* data) { - const char* i0 = args[0]; - const char* i1 = args[1]; - char* o = args[2]; - for (npy_intp k = 0; k < *dimensions; k++) { - auto x = *reinterpret_cast::T*>(i0); - auto y = - *reinterpret_cast::T*>(i1); - *reinterpret_cast::T*>(o) = - Functor()(x, y); - i0 += steps[0]; - i1 += steps[1]; - o += steps[2]; - } - } -}; - -template -bool RegisterUFunc(PyObject* numpy, const char* name) { - std::vector types = UFunc::Types(); - PyUFuncGenericFunction fn = - reinterpret_cast(UFunc::Call); - Safe_PyObjectPtr ufunc_obj = make_safe(PyObject_GetAttrString(numpy, name)); - if (!ufunc_obj) { - return false; - } - PyUFuncObject* ufunc = reinterpret_cast(ufunc_obj.get()); - if (static_cast(types.size()) != ufunc->nargs) { - PyErr_Format(PyExc_AssertionError, - "ufunc %s takes %d arguments, loop takes %lu", name, - ufunc->nargs, types.size()); - return false; - } - if (PyUFunc_RegisterLoopForType(ufunc, npy_bfloat16, fn, - const_cast(types.data()), - nullptr) < 0) { - return false; - } - return true; -} - -namespace ufuncs { - -struct Add { - bfloat16 operator()(bfloat16 a, bfloat16 b) { return a + b; } -}; -struct Subtract { - bfloat16 operator()(bfloat16 a, bfloat16 b) { return a - b; } -}; -struct Multiply { - bfloat16 operator()(bfloat16 a, bfloat16 b) { return a 
* b; } -}; -struct TrueDivide { - bfloat16 operator()(bfloat16 a, bfloat16 b) { return a / b; } -}; - -std::pair divmod(float a, float b) { - if (b == 0.0f) { - float nan = std::numeric_limits::quiet_NaN(); - return {nan, nan}; - } - float mod = std::fmod(a, b); - float div = (a - mod) / b; - if (mod != 0.0f) { - if ((b < 0.0f) != (mod < 0.0f)) { - mod += b; - div -= 1.0f; - } - } else { - mod = std::copysign(0.0f, b); - } - - float floordiv; - if (div != 0.0f) { - floordiv = std::floor(div); - if (div - floordiv > 0.5f) { - floordiv += 1.0f; - } - } else { - floordiv = std::copysign(0.0f, a / b); - } - return {floordiv, mod}; -} - -struct FloorDivide { - bfloat16 operator()(bfloat16 a, bfloat16 b) { - return bfloat16(divmod(static_cast(a), static_cast(b)).first); - } -}; -struct Remainder { - bfloat16 operator()(bfloat16 a, bfloat16 b) { - return bfloat16( - divmod(static_cast(a), static_cast(b)).second); - } -}; -struct DivmodUFunc { - static std::vector Types() { - return {npy_bfloat16, npy_bfloat16, npy_bfloat16, npy_bfloat16}; - } - static void Call(char** args, npy_intp* dimensions, npy_intp* steps, - void* data) { - const char* i0 = args[0]; - const char* i1 = args[1]; - char* o0 = args[2]; - char* o1 = args[3]; - for (npy_intp k = 0; k < *dimensions; k++) { - bfloat16 x = *reinterpret_cast(i0); - bfloat16 y = *reinterpret_cast(i1); - float floordiv, mod; - std::tie(floordiv, mod) = - divmod(static_cast(x), static_cast(y)); - *reinterpret_cast(o0) = bfloat16(floordiv); - *reinterpret_cast(o1) = bfloat16(mod); - i0 += steps[0]; - i1 += steps[1]; - o0 += steps[2]; - o1 += steps[3]; - } - } -}; -struct Fmod { - bfloat16 operator()(bfloat16 a, bfloat16 b) { - return bfloat16(std::fmod(static_cast(a), static_cast(b))); - } -}; -struct Negative { - bfloat16 operator()(bfloat16 a) { return -a; } -}; -struct Positive { - bfloat16 operator()(bfloat16 a) { return a; } -}; -struct Power { - bfloat16 operator()(bfloat16 a, bfloat16 b) { - return bfloat16(std::pow(static_cast(a), static_cast(b))); - } -}; -struct Abs { - bfloat16 operator()(bfloat16 a) { - return bfloat16(std::abs(static_cast(a))); - } -}; -struct Cbrt { - bfloat16 operator()(bfloat16 a) { - return bfloat16(std::cbrt(static_cast(a))); - } -}; -struct Ceil { - bfloat16 operator()(bfloat16 a) { - return bfloat16(std::ceil(static_cast(a))); - } -}; -struct CopySign { - bfloat16 operator()(bfloat16 a, bfloat16 b) { - return bfloat16( - std::copysign(static_cast(a), static_cast(b))); - } -}; -struct Exp { - bfloat16 operator()(bfloat16 a) { - return bfloat16(std::exp(static_cast(a))); - } -}; -struct Exp2 { - bfloat16 operator()(bfloat16 a) { - return bfloat16(std::exp2(static_cast(a))); - } -}; -struct Expm1 { - bfloat16 operator()(bfloat16 a) { - return bfloat16(std::expm1(static_cast(a))); - } -}; -struct Floor { - bfloat16 operator()(bfloat16 a) { - return bfloat16(std::floor(static_cast(a))); - } -}; -struct Frexp { - std::pair operator()(bfloat16 a) { - int exp; - float f = std::frexp(static_cast(a), &exp); - return {bfloat16(f), exp}; - } -}; -struct Heaviside { - bfloat16 operator()(bfloat16 bx, bfloat16 h0) { - float x = static_cast(bx); - if (Eigen::numext::isnan(x)) { - return bx; - } - if (x < 0) { - return bfloat16(0.0f); - } - if (x > 0) { - return bfloat16(1.0f); - } - return h0; // x == 0 - } -}; -struct Conjugate { - bfloat16 operator()(bfloat16 a) { return a; } -}; -struct IsFinite { - bool operator()(bfloat16 a) { return std::isfinite(static_cast(a)); } -}; -struct IsInf { - bool operator()(bfloat16 a) { return 
std::isinf(static_cast(a)); } -}; -struct IsNan { - bool operator()(bfloat16 a) { - return Eigen::numext::isnan(static_cast(a)); - } -}; -struct Ldexp { - bfloat16 operator()(bfloat16 a, int exp) { - return bfloat16(std::ldexp(static_cast(a), exp)); - } -}; -struct Log { - bfloat16 operator()(bfloat16 a) { - return bfloat16(std::log(static_cast(a))); - } -}; -struct Log2 { - bfloat16 operator()(bfloat16 a) { - return bfloat16(std::log2(static_cast(a))); - } -}; -struct Log10 { - bfloat16 operator()(bfloat16 a) { - return bfloat16(std::log10(static_cast(a))); - } -}; -struct Log1p { - bfloat16 operator()(bfloat16 a) { - return bfloat16(std::log1p(static_cast(a))); - } -}; -struct LogAddExp { - bfloat16 operator()(bfloat16 bx, bfloat16 by) { - float x = static_cast(bx); - float y = static_cast(by); - if (x == y) { - // Handles infinities of the same sign. - return bfloat16(x + std::log(2.0f)); - } - float out = std::numeric_limits::quiet_NaN(); - if (x > y) { - out = x + std::log1p(std::exp(y - x)); - } else if (x < y) { - out = y + std::log1p(std::exp(x - y)); - } - return bfloat16(out); - } -}; -struct LogAddExp2 { - bfloat16 operator()(bfloat16 bx, bfloat16 by) { - float x = static_cast(bx); - float y = static_cast(by); - if (x == y) { - // Handles infinities of the same sign. - return bfloat16(x + 1.0f); - } - float out = std::numeric_limits::quiet_NaN(); - if (x > y) { - out = x + std::log1p(std::exp2(y - x)) / std::log(2.0f); - } else if (x < y) { - out = y + std::log1p(std::exp2(x - y)) / std::log(2.0f); - } - return bfloat16(out); - } -}; -struct Modf { - std::pair operator()(bfloat16 a) { - float integral; - float f = std::modf(static_cast(a), &integral); - return {bfloat16(f), bfloat16(integral)}; - } -}; - -struct Reciprocal { - bfloat16 operator()(bfloat16 a) { - return bfloat16(1.f / static_cast(a)); - } -}; -struct Rint { - bfloat16 operator()(bfloat16 a) { - return bfloat16(std::rint(static_cast(a))); - } -}; -struct Sign { - bfloat16 operator()(bfloat16 a) { - float f(a); - if (f < 0) { - return bfloat16(-1); - } - if (f > 0) { - return bfloat16(1); - } - return a; - } -}; -struct SignBit { - bool operator()(bfloat16 a) { return std::signbit(static_cast(a)); } -}; -struct Sqrt { - bfloat16 operator()(bfloat16 a) { - return bfloat16(std::sqrt(static_cast(a))); - } -}; -struct Square { - bfloat16 operator()(bfloat16 a) { - float f(a); - return bfloat16(f * f); - } -}; -struct Trunc { - bfloat16 operator()(bfloat16 a) { - return bfloat16(std::trunc(static_cast(a))); - } -}; - -// Trigonometric functions -struct Sin { - bfloat16 operator()(bfloat16 a) { - return bfloat16(std::sin(static_cast(a))); - } -}; -struct Cos { - bfloat16 operator()(bfloat16 a) { - return bfloat16(std::cos(static_cast(a))); - } -}; -struct Tan { - bfloat16 operator()(bfloat16 a) { - return bfloat16(std::tan(static_cast(a))); - } -}; -struct Arcsin { - bfloat16 operator()(bfloat16 a) { - return bfloat16(std::asin(static_cast(a))); - } -}; -struct Arccos { - bfloat16 operator()(bfloat16 a) { - return bfloat16(std::acos(static_cast(a))); - } -}; -struct Arctan { - bfloat16 operator()(bfloat16 a) { - return bfloat16(std::atan(static_cast(a))); - } -}; -struct Arctan2 { - bfloat16 operator()(bfloat16 a, bfloat16 b) { - return bfloat16(std::atan2(static_cast(a), static_cast(b))); - } -}; -struct Hypot { - bfloat16 operator()(bfloat16 a, bfloat16 b) { - return bfloat16(std::hypot(static_cast(a), static_cast(b))); - } -}; -struct Sinh { - bfloat16 operator()(bfloat16 a) { - return 
bfloat16(std::sinh(static_cast(a))); - } -}; -struct Cosh { - bfloat16 operator()(bfloat16 a) { - return bfloat16(std::cosh(static_cast(a))); - } -}; -struct Tanh { - bfloat16 operator()(bfloat16 a) { - return bfloat16(std::tanh(static_cast(a))); - } -}; -struct Arcsinh { - bfloat16 operator()(bfloat16 a) { - return bfloat16(std::asinh(static_cast(a))); - } -}; -struct Arccosh { - bfloat16 operator()(bfloat16 a) { - return bfloat16(std::acosh(static_cast(a))); - } -}; -struct Arctanh { - bfloat16 operator()(bfloat16 a) { - return bfloat16(std::atanh(static_cast(a))); - } -}; -struct Deg2rad { - bfloat16 operator()(bfloat16 a) { - static constexpr float radians_per_degree = M_PI / 180.0f; - return bfloat16(static_cast(a) * radians_per_degree); - } -}; -struct Rad2deg { - bfloat16 operator()(bfloat16 a) { - static constexpr float degrees_per_radian = 180.0f / M_PI; - return bfloat16(static_cast(a) * degrees_per_radian); - } -}; - -struct Eq { - npy_bool operator()(bfloat16 a, bfloat16 b) { return a == b; } -}; -struct Ne { - npy_bool operator()(bfloat16 a, bfloat16 b) { return a != b; } -}; -struct Lt { - npy_bool operator()(bfloat16 a, bfloat16 b) { return a < b; } -}; -struct Gt { - npy_bool operator()(bfloat16 a, bfloat16 b) { return a > b; } -}; -struct Le { - npy_bool operator()(bfloat16 a, bfloat16 b) { return a <= b; } -}; -struct Ge { - npy_bool operator()(bfloat16 a, bfloat16 b) { return a >= b; } -}; -struct Maximum { - bfloat16 operator()(bfloat16 a, bfloat16 b) { - float fa(a), fb(b); - return Eigen::numext::isnan(fa) || fa > fb ? a : b; - } -}; -struct Minimum { - bfloat16 operator()(bfloat16 a, bfloat16 b) { - float fa(a), fb(b); - return Eigen::numext::isnan(fa) || fa < fb ? a : b; - } -}; -struct Fmax { - bfloat16 operator()(bfloat16 a, bfloat16 b) { - float fa(a), fb(b); - return Eigen::numext::isnan(fb) || fa > fb ? a : b; - } -}; -struct Fmin { - bfloat16 operator()(bfloat16 a, bfloat16 b) { - float fa(a), fb(b); - return Eigen::numext::isnan(fb) || fa < fb ? a : b; - } -}; - -struct LogicalNot { - npy_bool operator()(bfloat16 a) { return !a; } -}; -struct LogicalAnd { - npy_bool operator()(bfloat16 a, bfloat16 b) { return a && b; } -}; -struct LogicalOr { - npy_bool operator()(bfloat16 a, bfloat16 b) { return a || b; } -}; -struct LogicalXor { - npy_bool operator()(bfloat16 a, bfloat16 b) { - return static_cast(a) ^ static_cast(b); - } -}; - -struct NextAfter { - bfloat16 operator()(bfloat16 from, bfloat16 to) { - uint16_t from_as_int, to_as_int; - const uint16_t sign_mask = 1 << 15; - float from_as_float(from), to_as_float(to); - memcpy(&from_as_int, &from, sizeof(bfloat16)); - memcpy(&to_as_int, &to, sizeof(bfloat16)); - if (Eigen::numext::isnan(from_as_float) || - Eigen::numext::isnan(to_as_float)) { - return bfloat16(std::numeric_limits::quiet_NaN()); - } - if (from_as_int == to_as_int) { - return to; - } - if (from_as_float == 0) { - if (to_as_float == 0) { - return to; - } else { - // Smallest subnormal signed like `to`. - uint16_t out_int = (to_as_int & sign_mask) | 1; - bfloat16 out; - memcpy(&out, &out_int, sizeof(bfloat16)); - return out; - } - } - uint16_t from_sign = from_as_int & sign_mask; - uint16_t to_sign = to_as_int & sign_mask; - uint16_t from_abs = from_as_int & ~sign_mask; - uint16_t to_abs = to_as_int & ~sign_mask; - uint16_t magnitude_adjustment = - (from_abs > to_abs || from_sign != to_sign) ? 
0xFFFF : 0x0001; - uint16_t out_int = from_as_int + magnitude_adjustment; - bfloat16 out; - memcpy(&out, &out_int, sizeof(bfloat16)); - return out; - } -}; - -// TODO(phawkins): implement spacing - -} // namespace ufuncs - -} // namespace - -// Initializes the module. -bool Initialize() { - import_array1(false); - import_umath1(false); - - Safe_PyObjectPtr numpy_str = make_safe(PyUnicode_FromString("numpy")); - if (!numpy_str) { - return false; - } - Safe_PyObjectPtr numpy = make_safe(PyImport_Import(numpy_str.get())); - if (!numpy) { - return false; - } - - PyBfloat16_Type.tp_base = &PyGenericArrType_Type; - - if (PyType_Ready(&PyBfloat16_Type) < 0) { - return false; - } - - // Initializes the NumPy descriptor. - PyArray_InitArrFuncs(&NPyBfloat16_ArrFuncs); - NPyBfloat16_ArrFuncs.getitem = NPyBfloat16_GetItem; - NPyBfloat16_ArrFuncs.setitem = NPyBfloat16_SetItem; - NPyBfloat16_ArrFuncs.compare = NPyBfloat16_Compare; - NPyBfloat16_ArrFuncs.copyswapn = NPyBfloat16_CopySwapN; - NPyBfloat16_ArrFuncs.copyswap = NPyBfloat16_CopySwap; - NPyBfloat16_ArrFuncs.nonzero = NPyBfloat16_NonZero; - NPyBfloat16_ArrFuncs.fill = NPyBfloat16_Fill; - NPyBfloat16_ArrFuncs.dotfunc = NPyBfloat16_DotFunc; - NPyBfloat16_ArrFuncs.compare = NPyBfloat16_CompareFunc; - NPyBfloat16_ArrFuncs.argmax = NPyBfloat16_ArgMaxFunc; - NPyBfloat16_ArrFuncs.argmin = NPyBfloat16_ArgMinFunc; - - Py_TYPE(&NPyBfloat16_Descr) = &PyArrayDescr_Type; - npy_bfloat16 = PyArray_RegisterDataType(&NPyBfloat16_Descr); - if (npy_bfloat16 < 0) { - return false; - } - - // Support dtype(bfloat16) - if (PyDict_SetItemString(PyBfloat16_Type.tp_dict, "dtype", - reinterpret_cast(&NPyBfloat16_Descr)) < - 0) { - return false; - } - - // Register casts - if (!RegisterBfloat16Cast(NPY_HALF, /*cast_is_safe=*/false)) { - return false; - } - if (!RegisterBfloat16Cast(NPY_FLOAT, /*cast_is_safe=*/true)) { - return false; - } - if (!RegisterBfloat16Cast(NPY_DOUBLE, /*cast_is_safe=*/true)) { - return false; - } - if (!RegisterBfloat16Cast(NPY_BOOL, /*cast_is_safe=*/false)) { - return false; - } - if (!RegisterBfloat16Cast(NPY_UINT8, /*cast_is_safe=*/false)) { - return false; - } - if (!RegisterBfloat16Cast(NPY_UINT16, /*cast_is_safe=*/false)) { - return false; - } - if (!RegisterBfloat16Cast(NPY_UINT, /*cast_is_safe=*/false)) { - return false; - } - if (!RegisterBfloat16Cast(NPY_ULONG, // NOLINT - /*cast_is_safe=*/false)) { - return false; - } - if (!RegisterBfloat16Cast( // NOLINT - NPY_ULONGLONG, /*cast_is_safe=*/false)) { - return false; - } - if (!RegisterBfloat16Cast(NPY_UINT64, /*cast_is_safe=*/false)) { - return false; - } - if (!RegisterBfloat16Cast(NPY_INT8, /*cast_is_safe=*/false)) { - return false; - } - if (!RegisterBfloat16Cast(NPY_INT16, /*cast_is_safe=*/false)) { - return false; - } - if (!RegisterBfloat16Cast(NPY_INT, /*cast_is_safe=*/false)) { - return false; - } - if (!RegisterBfloat16Cast(NPY_LONG, // NOLINT - /*cast_is_safe=*/false)) { - return false; - } - if (!RegisterBfloat16Cast( // NOLINT - NPY_LONGLONG, /*cast_is_safe=*/false)) { - return false; - } - // Following the numpy convention. imag part is dropped when converting to - // float. 
- if (!RegisterBfloat16Cast(NPY_COMPLEX64, /*cast_is_safe=*/true)) { - return false; - } - if (!RegisterBfloat16Cast(NPY_COMPLEX128, - /*cast_is_safe=*/true)) { - return false; - } - - bool ok = - RegisterUFunc>(numpy.get(), - "add") && - RegisterUFunc>( - numpy.get(), "subtract") && - RegisterUFunc>( - numpy.get(), "multiply") && - RegisterUFunc>( - numpy.get(), "divide") && - RegisterUFunc>( - numpy.get(), "logaddexp") && - RegisterUFunc>( - numpy.get(), "logaddexp2") && - RegisterUFunc>( - numpy.get(), "negative") && - RegisterUFunc>( - numpy.get(), "positive") && - RegisterUFunc>( - numpy.get(), "true_divide") && - RegisterUFunc>( - numpy.get(), "floor_divide") && - RegisterUFunc>(numpy.get(), - "power") && - RegisterUFunc>( - numpy.get(), "remainder") && - RegisterUFunc>( - numpy.get(), "mod") && - RegisterUFunc>(numpy.get(), - "fmod") && - RegisterUFunc(numpy.get(), "divmod") && - RegisterUFunc>(numpy.get(), - "absolute") && - RegisterUFunc>(numpy.get(), - "fabs") && - RegisterUFunc>(numpy.get(), - "rint") && - RegisterUFunc>(numpy.get(), - "sign") && - RegisterUFunc>( - numpy.get(), "heaviside") && - RegisterUFunc>( - numpy.get(), "conjugate") && - RegisterUFunc>(numpy.get(), - "exp") && - RegisterUFunc>(numpy.get(), - "exp2") && - RegisterUFunc>(numpy.get(), - "expm1") && - RegisterUFunc>(numpy.get(), - "log") && - RegisterUFunc>(numpy.get(), - "log2") && - RegisterUFunc>(numpy.get(), - "log10") && - RegisterUFunc>(numpy.get(), - "log1p") && - RegisterUFunc>(numpy.get(), - "sqrt") && - RegisterUFunc>(numpy.get(), - "square") && - RegisterUFunc>(numpy.get(), - "cbrt") && - RegisterUFunc>( - numpy.get(), "reciprocal") && - - // Trigonometric functions - RegisterUFunc>(numpy.get(), - "sin") && - RegisterUFunc>(numpy.get(), - "cos") && - RegisterUFunc>(numpy.get(), - "tan") && - RegisterUFunc>(numpy.get(), - "arcsin") && - RegisterUFunc>(numpy.get(), - "arccos") && - RegisterUFunc>(numpy.get(), - "arctan") && - RegisterUFunc>( - numpy.get(), "arctan2") && - RegisterUFunc>(numpy.get(), - "hypot") && - RegisterUFunc>(numpy.get(), - "sinh") && - RegisterUFunc>(numpy.get(), - "cosh") && - RegisterUFunc>(numpy.get(), - "tanh") && - RegisterUFunc>( - numpy.get(), "arcsinh") && - RegisterUFunc>( - numpy.get(), "arccosh") && - RegisterUFunc>( - numpy.get(), "arctanh") && - RegisterUFunc>( - numpy.get(), "deg2rad") && - RegisterUFunc>( - numpy.get(), "rad2deg") && - - // Comparison functions - RegisterUFunc>(numpy.get(), - "equal") && - RegisterUFunc>(numpy.get(), - "not_equal") && - RegisterUFunc>(numpy.get(), - "less") && - RegisterUFunc>(numpy.get(), - "greater") && - RegisterUFunc>(numpy.get(), - "less_equal") && - RegisterUFunc>(numpy.get(), - "greater_equal") && - RegisterUFunc>( - numpy.get(), "maximum") && - RegisterUFunc>( - numpy.get(), "minimum") && - RegisterUFunc>(numpy.get(), - "fmax") && - RegisterUFunc>(numpy.get(), - "fmin") && - RegisterUFunc>( - numpy.get(), "logical_and") && - RegisterUFunc>( - numpy.get(), "logical_or") && - RegisterUFunc>( - numpy.get(), "logical_xor") && - RegisterUFunc>( - numpy.get(), "logical_not") && - - // Floating point functions - RegisterUFunc>(numpy.get(), - "isfinite") && - RegisterUFunc>(numpy.get(), - "isinf") && - RegisterUFunc>(numpy.get(), - "isnan") && - RegisterUFunc>(numpy.get(), - "signbit") && - RegisterUFunc>( - numpy.get(), "copysign") && - RegisterUFunc>( - numpy.get(), "modf") && - RegisterUFunc>( - numpy.get(), "ldexp") && - RegisterUFunc>( - numpy.get(), "frexp") && - RegisterUFunc>(numpy.get(), - "floor") && - 
RegisterUFunc>(numpy.get(), - "ceil") && - RegisterUFunc>(numpy.get(), - "trunc") && - RegisterUFunc>( - numpy.get(), "nextafter"); - - return ok; -} - -StatusOr Bfloat16Dtype() { - if (npy_bfloat16 < 0) { - // Not yet initialized. We assume the GIL protects npy_bfloat16. - if (!Initialize()) { - return InternalError("Bfloat16 numpy type initialization failed."); - } - } - return py::object(reinterpret_cast(&PyBfloat16_Type), - /*is_borrowed=*/true); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/python/bfloat16_test.py b/tensorflow/compiler/xla/python/bfloat16_test.py deleted file mode 100644 index 4c7321a5b7f..00000000000 --- a/tensorflow/compiler/xla/python/bfloat16_test.py +++ /dev/null @@ -1,440 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Test cases for the bfloat16 Python type.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections -import copy -import itertools -import math - -from absl.testing import absltest -from absl.testing import parameterized - -import numpy as np - -from tensorflow.compiler.xla.python import xla_client - -bfloat16 = xla_client.bfloat16 - - -def numpy_assert_allclose(a, b, **kwargs): - a = a.astype(np.float32) if a.dtype == bfloat16 else a - b = b.astype(np.float32) if b.dtype == bfloat16 else b - return np.testing.assert_allclose(a, b, **kwargs) - - -epsilon = float.fromhex("1.0p-7") - -# Values that should round trip exactly to float and back. 
-FLOAT_VALUES = [ - 0.0, 1.0, -1, 0.5, -0.5, epsilon, 1.0 + epsilon, 1.0 - epsilon, - -1.0 - epsilon, -1.0 + epsilon, 3.5, 42.0, 255.0, 256.0, - float("inf"), - float("-inf"), - float("nan") -] - - -class Bfloat16Test(parameterized.TestCase): - """Tests the non-numpy Python methods of the bfloat16 type.""" - - def testRoundTripToFloat(self): - for v in FLOAT_VALUES: - np.testing.assert_equal(v, float(bfloat16(v))) - - def testRoundTripNumpyTypes(self): - for dtype in [np.float16, np.float32, np.float64]: - np.testing.assert_equal(-3.75, dtype(bfloat16(dtype(-3.75)))) - np.testing.assert_equal(1.5, float(bfloat16(dtype(1.5)))) - np.testing.assert_equal(4.5, dtype(bfloat16(np.array(4.5, dtype)))) - np.testing.assert_equal( - np.array([2, 5, -1], bfloat16), bfloat16(np.array([2, 5, -1], dtype))) - - def testRoundTripToInt(self): - for v in [-256, -255, -34, -2, -1, 0, 1, 2, 10, 47, 128, 255, 256, 512]: - self.assertEqual(v, int(bfloat16(v))) - - # pylint: disable=g-complex-comprehension - @parameterized.named_parameters(({ - "testcase_name": "_" + dtype.__name__, - "dtype": dtype - } for dtype in [bfloat16, np.float16, np.float32, np.float64])) - def testRoundTripToNumpy(self, dtype): - for v in FLOAT_VALUES: - np.testing.assert_equal(v, bfloat16(dtype(v))) - np.testing.assert_equal(v, dtype(bfloat16(dtype(v)))) - np.testing.assert_equal(v, dtype(bfloat16(np.array(v, dtype)))) - if dtype != bfloat16: - np.testing.assert_equal( - np.array(FLOAT_VALUES, dtype), - bfloat16(np.array(FLOAT_VALUES, dtype)).astype(dtype)) - - def testStr(self): - self.assertEqual("0", str(bfloat16(0.0))) - self.assertEqual("1", str(bfloat16(1.0))) - self.assertEqual("-3.5", str(bfloat16(-3.5))) - self.assertEqual("0.0078125", str(bfloat16(float.fromhex("1.0p-7")))) - self.assertEqual("inf", str(bfloat16(float("inf")))) - self.assertEqual("-inf", str(bfloat16(float("-inf")))) - self.assertEqual("nan", str(bfloat16(float("nan")))) - - def testRepr(self): - self.assertEqual("0", repr(bfloat16(0))) - self.assertEqual("1", repr(bfloat16(1))) - self.assertEqual("-3.5", repr(bfloat16(-3.5))) - self.assertEqual("0.0078125", repr(bfloat16(float.fromhex("1.0p-7")))) - self.assertEqual("inf", repr(bfloat16(float("inf")))) - self.assertEqual("-inf", repr(bfloat16(float("-inf")))) - self.assertEqual("nan", repr(bfloat16(float("nan")))) - - def testHash(self): - self.assertEqual(0, hash(bfloat16(0.0))) - self.assertEqual(0x3f80, hash(bfloat16(1.0))) - self.assertEqual(0x7fc0, hash(bfloat16(float("nan")))) - - # Tests for Python operations - def testNegate(self): - for v in FLOAT_VALUES: - np.testing.assert_equal(-v, float(-bfloat16(v))) - - def testAdd(self): - np.testing.assert_equal(0, float(bfloat16(0) + bfloat16(0))) - np.testing.assert_equal(1, float(bfloat16(1) + bfloat16(0))) - np.testing.assert_equal(0, float(bfloat16(1) + bfloat16(-1))) - np.testing.assert_equal(5.5, float(bfloat16(2) + bfloat16(3.5))) - np.testing.assert_equal(1.25, float(bfloat16(3.5) + bfloat16(-2.25))) - np.testing.assert_equal( - float("inf"), float(bfloat16(float("inf")) + bfloat16(-2.25))) - np.testing.assert_equal( - float("-inf"), float(bfloat16(float("-inf")) + bfloat16(-2.25))) - self.assertTrue(math.isnan(float(bfloat16(3.5) + bfloat16(float("nan"))))) - - # Test type promotion against Numpy scalar values. 
- self.assertEqual(np.float32, type(bfloat16(3.5) + np.float16(2.25))) - self.assertEqual(np.float32, type(np.float16(3.5) + bfloat16(2.25))) - self.assertEqual(np.float32, type(bfloat16(3.5) + np.float32(2.25))) - self.assertEqual(np.float32, type(np.float32(3.5) + bfloat16(2.25))) - self.assertEqual(np.float64, type(bfloat16(3.5) + np.float64(2.25))) - self.assertEqual(np.float64, type(np.float64(3.5) + bfloat16(2.25))) - self.assertEqual(np.float64, type(bfloat16(3.5) + float(2.25))) - self.assertEqual(np.float64, type(float(3.5) + bfloat16(2.25))) - self.assertEqual(np.float32, - type(bfloat16(3.5) + np.array(2.25, np.float32))) - self.assertEqual(np.float32, - type(np.array(3.5, np.float32) + bfloat16(2.25))) - - def testSub(self): - np.testing.assert_equal(0, float(bfloat16(0) - bfloat16(0))) - np.testing.assert_equal(1, float(bfloat16(1) - bfloat16(0))) - np.testing.assert_equal(2, float(bfloat16(1) - bfloat16(-1))) - np.testing.assert_equal(-1.5, float(bfloat16(2) - bfloat16(3.5))) - np.testing.assert_equal(5.75, float(bfloat16(3.5) - bfloat16(-2.25))) - np.testing.assert_equal( - float("-inf"), float(bfloat16(-2.25) - bfloat16(float("inf")))) - np.testing.assert_equal( - float("inf"), float(bfloat16(-2.25) - bfloat16(float("-inf")))) - self.assertTrue(math.isnan(float(bfloat16(3.5) - bfloat16(float("nan"))))) - - def testMul(self): - np.testing.assert_equal(0, float(bfloat16(0) * bfloat16(0))) - np.testing.assert_equal(0, float(bfloat16(1) * bfloat16(0))) - np.testing.assert_equal(-1, float(bfloat16(1) * bfloat16(-1))) - np.testing.assert_equal(-7.875, float(bfloat16(3.5) * bfloat16(-2.25))) - np.testing.assert_equal( - float("-inf"), float(bfloat16(float("inf")) * bfloat16(-2.25))) - np.testing.assert_equal( - float("inf"), float(bfloat16(float("-inf")) * bfloat16(-2.25))) - self.assertTrue(math.isnan(float(bfloat16(3.5) * bfloat16(float("nan"))))) - - def testDiv(self): - self.assertTrue(math.isnan(float(bfloat16(0) / bfloat16(0)))) - np.testing.assert_equal(float("inf"), float(bfloat16(1) / bfloat16(0))) - np.testing.assert_equal(-1, float(bfloat16(1) / bfloat16(-1))) - np.testing.assert_equal(-1.75, float(bfloat16(3.5) / bfloat16(-2))) - np.testing.assert_equal( - float("-inf"), float(bfloat16(float("inf")) / bfloat16(-2.25))) - np.testing.assert_equal( - float("inf"), float(bfloat16(float("-inf")) / bfloat16(-2.25))) - self.assertTrue(math.isnan(float(bfloat16(3.5) / bfloat16(float("nan"))))) - - def testLess(self): - for v in FLOAT_VALUES: - for w in FLOAT_VALUES: - self.assertEqual(v < w, bfloat16(v) < bfloat16(w)) - - def testLessEqual(self): - for v in FLOAT_VALUES: - for w in FLOAT_VALUES: - self.assertEqual(v <= w, bfloat16(v) <= bfloat16(w)) - - def testGreater(self): - for v in FLOAT_VALUES: - for w in FLOAT_VALUES: - self.assertEqual(v > w, bfloat16(v) > bfloat16(w)) - - def testGreaterEqual(self): - for v in FLOAT_VALUES: - for w in FLOAT_VALUES: - self.assertEqual(v >= w, bfloat16(v) >= bfloat16(w)) - - def testEqual(self): - for v in FLOAT_VALUES: - for w in FLOAT_VALUES: - self.assertEqual(v == w, bfloat16(v) == bfloat16(w)) - - def testNotEqual(self): - for v in FLOAT_VALUES: - for w in FLOAT_VALUES: - self.assertEqual(v != w, bfloat16(v) != bfloat16(w)) - - def testNan(self): - a = np.isnan(bfloat16(float("nan"))) - self.assertTrue(a) - numpy_assert_allclose(np.array([1.0, a]), np.array([1.0, a])) - - a = np.array([bfloat16(1.34375), - bfloat16(1.4375), - bfloat16(float("nan"))], - dtype=bfloat16) - b = np.array( - [bfloat16(1.3359375), - bfloat16(1.4375), - 
bfloat16(float("nan"))], - dtype=bfloat16) - numpy_assert_allclose( - a, b, rtol=0.1, atol=0.1, equal_nan=True, err_msg="", verbose=True) - - def testSort(self): - values_to_sort = np.float32(FLOAT_VALUES) - sorted_f32 = np.sort(values_to_sort) - sorted_bf16 = np.sort(values_to_sort.astype(bfloat16)) - np.testing.assert_equal(sorted_f32, np.float32(sorted_bf16)) - - -BinaryOp = collections.namedtuple("BinaryOp", ["op"]) - -UNARY_UFUNCS = [ - np.negative, np.positive, np.absolute, np.fabs, np.rint, np.sign, - np.conjugate, np.exp, np.exp2, np.expm1, np.log, np.log10, np.log1p, - np.log2, np.sqrt, np.square, np.cbrt, np.reciprocal, np.sin, np.cos, np.tan, - np.arcsin, np.arccos, np.arctan, np.sinh, np.cosh, np.tanh, np.arcsinh, - np.arccosh, np.arctanh, np.deg2rad, np.rad2deg, np.floor, np.ceil, np.trunc -] - -BINARY_UFUNCS = [ - np.add, np.subtract, np.multiply, np.divide, np.logaddexp, np.logaddexp2, - np.floor_divide, np.power, np.remainder, np.fmod, np.heaviside, np.arctan2, - np.hypot, np.maximum, np.minimum, np.fmax, np.fmin, np.copysign -] - -BINARY_PREDICATE_UFUNCS = [ - np.equal, np.not_equal, np.less, np.greater, np.less_equal, - np.greater_equal, np.logical_and, np.logical_or, np.logical_xor -] - - -class Bfloat16NumPyTest(parameterized.TestCase): - """Tests the NumPy integration of the bfloat16 type.""" - - def testDtype(self): - self.assertEqual(bfloat16, np.dtype(bfloat16)) - - def testDeepCopyDoesNotAlterHash(self): - # For context, see https://github.com/google/jax/issues/4651. If the hash - # value of the type descriptor is not initialized correctly, a deep copy - # can change the type hash. - dtype = np.dtype(bfloat16) - h = hash(dtype) - _ = copy.deepcopy(dtype) - self.assertEqual(h, hash(dtype)) - - def testArray(self): - x = np.array([[1, 2, 3]], dtype=bfloat16) - self.assertEqual(bfloat16, x.dtype) - self.assertEqual("[[1 2 3]]", str(x)) - np.testing.assert_equal(x, x) - numpy_assert_allclose(x, x) - self.assertTrue((x == x).all()) - - def testComparisons(self): - x = np.array([401408, 7, -32], dtype=np.float32) - bx = x.astype(bfloat16) - y = np.array([82432, 7, 0], dtype=np.float32) - by = y.astype(bfloat16) - np.testing.assert_equal(x == y, bx == by) - np.testing.assert_equal(x != y, bx != by) - np.testing.assert_equal(x < y, bx < by) - np.testing.assert_equal(x > y, bx > by) - np.testing.assert_equal(x <= y, bx <= by) - np.testing.assert_equal(x >= y, bx >= by) - - def testEqual2(self): - a = np.array([401408], bfloat16) - b = np.array([82432], bfloat16) - self.assertFalse(a.__eq__(b)) - - def testCasts(self): - for dtype in [ - np.float16, np.float32, np.float64, np.int8, np.int16, np.int32, - np.int64, np.complex64, np.complex128, np.uint8, np.uint16, np.uint32, - np.uint64, np.intc, np.int_, np.longlong, np.uintc, np.ulonglong - ]: - x = np.array([[1, 2, 3]], dtype=dtype) - y = x.astype(bfloat16) - z = y.astype(dtype) - self.assertTrue(np.all(x == y)) - self.assertEqual(bfloat16, y.dtype) - self.assertTrue(np.all(x == z)) - self.assertEqual(dtype, z.dtype) - - def testConformNumpyComplex(self): - for dtype in [np.complex64, np.complex128]: - x = np.array([1.1, 2.2 + 2.2j, 3.3], dtype=dtype) - y_np = x.astype(np.float32) - y_tf = x.astype(bfloat16) - numpy_assert_allclose(y_np, y_tf, atol=2e-2) - - z_np = y_np.astype(dtype) - z_tf = y_tf.astype(dtype) - numpy_assert_allclose(z_np, z_tf, atol=2e-2) - - def testArange(self): - np.testing.assert_equal( - np.arange(100, dtype=np.float32).astype(bfloat16), - np.arange(100, dtype=bfloat16)) - np.testing.assert_equal( - 
np.arange(-10.5, 7.8, 0.5, dtype=np.float32).astype(bfloat16), - np.arange(-10.5, 7.8, 0.5, dtype=bfloat16)) - np.testing.assert_equal( - np.arange(-0., -7., -0.25, dtype=np.float32).astype(bfloat16), - np.arange(-0., -7., -0.25, dtype=bfloat16)) - np.testing.assert_equal( - np.arange(-16384., 16384., 64., dtype=np.float32).astype(bfloat16), - np.arange(-16384., 16384., 64., dtype=bfloat16)) - - # pylint: disable=g-complex-comprehension - @parameterized.named_parameters(({ - "testcase_name": "_" + op.__name__, - "op": op - } for op in UNARY_UFUNCS)) - def testUnaryUfunc(self, op): - rng = np.random.RandomState(seed=42) - x = rng.randn(3, 7, 10).astype(bfloat16) - numpy_assert_allclose( - op(x).astype(np.float32), op(x.astype(np.float32)), rtol=1e-2) - - @parameterized.named_parameters(({ - "testcase_name": "_" + op.__name__, - "op": op - } for op in BINARY_UFUNCS)) - def testBinaryUfunc(self, op): - rng = np.random.RandomState(seed=42) - x = rng.randn(3, 7, 10).astype(bfloat16) - y = rng.randn(4, 1, 7, 10).astype(bfloat16) - numpy_assert_allclose( - op(x, y).astype(np.float32), - op(x.astype(np.float32), y.astype(np.float32)), - rtol=1e-2) - - @parameterized.named_parameters(({ - "testcase_name": "_" + op.__name__, - "op": op - } for op in BINARY_PREDICATE_UFUNCS)) - def testBinaryPredicateUfunc(self, op): - rng = np.random.RandomState(seed=42) - x = rng.randn(3, 7).astype(bfloat16) - y = rng.randn(4, 1, 7).astype(bfloat16) - np.testing.assert_equal( - op(x, y), op(x.astype(np.float32), y.astype(np.float32))) - - @parameterized.named_parameters(({ - "testcase_name": "_" + op.__name__, - "op": op - } for op in [np.isfinite, np.isinf, np.isnan, np.signbit, np.logical_not])) - def testPredicateUfunc(self, op): - rng = np.random.RandomState(seed=42) - shape = (3, 7, 10) - posinf_flips = rng.rand(*shape) < 0.1 - neginf_flips = rng.rand(*shape) < 0.1 - nan_flips = rng.rand(*shape) < 0.1 - vals = rng.randn(*shape) - vals = np.where(posinf_flips, np.inf, vals) - vals = np.where(neginf_flips, -np.inf, vals) - vals = np.where(nan_flips, np.nan, vals) - vals = vals.astype(bfloat16) - np.testing.assert_equal(op(vals), op(vals.astype(np.float32))) - - def testDivmod(self): - rng = np.random.RandomState(seed=42) - x = rng.randn(3, 7).astype(bfloat16) - y = rng.randn(4, 1, 7).astype(bfloat16) - o1, o2 = np.divmod(x, y) - e1, e2 = np.divmod(x.astype(np.float32), y.astype(np.float32)) - numpy_assert_allclose(o1, e1, rtol=1e-2) - numpy_assert_allclose(o2, e2, rtol=1e-2) - - def testModf(self): - rng = np.random.RandomState(seed=42) - x = rng.randn(3, 7).astype(bfloat16) - o1, o2 = np.modf(x) - e1, e2 = np.modf(x.astype(np.float32)) - numpy_assert_allclose(o1.astype(np.float32), e1, rtol=1e-2) - numpy_assert_allclose(o2.astype(np.float32), e2, rtol=1e-2) - - def testLdexp(self): - rng = np.random.RandomState(seed=42) - x = rng.randn(3, 7).astype(bfloat16) - y = rng.randint(-50, 50, (1, 7)) - numpy_assert_allclose( - np.ldexp(x, y).astype(np.float32), - np.ldexp(x.astype(np.float32), y), - rtol=1e-2, - atol=1e-6) - - def testFrexp(self): - rng = np.random.RandomState(seed=42) - x = rng.randn(3, 7).astype(bfloat16) - mant1, exp1 = np.frexp(x) - mant2, exp2 = np.frexp(x.astype(np.float32)) - np.testing.assert_equal(exp1, exp2) - numpy_assert_allclose(mant1, mant2, rtol=1e-2) - - def testNextAfter(self): - one = np.array(1., dtype=bfloat16) - two = np.array(2., dtype=bfloat16) - zero = np.array(0., dtype=bfloat16) - nan = np.array(np.nan, dtype=bfloat16) - np.testing.assert_equal(np.nextafter(one, two) - one, 
epsilon) - np.testing.assert_equal(np.nextafter(one, zero) - one, -epsilon / 2) - np.testing.assert_equal(np.isnan(np.nextafter(nan, one)), True) - np.testing.assert_equal(np.isnan(np.nextafter(one, nan)), True) - np.testing.assert_equal(np.nextafter(one, one), one) - smallest_denormal = float.fromhex("1.0p-133") - np.testing.assert_equal(np.nextafter(zero, one), smallest_denormal) - np.testing.assert_equal(np.nextafter(zero, -one), -smallest_denormal) - for a, b in itertools.permutations([0., -0., nan], 2): - np.testing.assert_equal( - np.nextafter( - np.array(a, dtype=np.float32), np.array(b, dtype=np.float32)), - np.nextafter( - np.array(a, dtype=bfloat16), np.array(b, dtype=bfloat16))) - - -if __name__ == "__main__": - absltest.main() diff --git a/tensorflow/compiler/xla/python/dlpack.cc b/tensorflow/compiler/xla/python/dlpack.cc index 85252256657..fd603930d6c 100644 --- a/tensorflow/compiler/xla/python/dlpack.cc +++ b/tensorflow/compiler/xla/python/dlpack.cc @@ -214,11 +214,9 @@ StatusOr> StridesToLayout(absl::Span dims, } StatusOr DLDeviceTypeForDevice(const PjRtDevice& device) { - const se::Platform* platform = - device.local_device_state()->executor()->platform(); - if (platform->id() == se::host::kHostPlatformId) { + if (device.client()->platform_id() == kCpuId) { return kDLCPU; - } else if (platform->id() == se::cuda::kCudaPlatformId) { + } else if (device.client()->platform_id() == kGpuId) { return kDLGPU; } return InvalidArgument("Device %s cannot be used as a DLPack device.", @@ -228,7 +226,7 @@ StatusOr DLDeviceTypeForDevice(const PjRtDevice& device) { StatusOr DLContextForDevice(const PjRtDevice& device) { DLContext context; TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device)); - context.device_id = device.local_device_id(); + context.device_id = device.local_hardware_id(); return context; } @@ -241,14 +239,14 @@ StatusOr DeviceForDLContext(const PjRtClient& client, "DLPack CPU device type mismatch with PjRtClient platform %s", client.platform_name()); } - return client.LookupLocalDevice(context.device_id); + return client.LookupAddressableDevice(context.device_id); case kDLGPU: if (client.platform_id() != kGpuId) { return InvalidArgument( "DLPack GPU device type mismatch with PjRtClient platform %s", client.platform_name()); } - return client.LookupLocalDevice(context.device_id); + return client.LookupAddressableDevice(context.device_id); default: return InvalidArgument("Unknown/unsupported DLPack device type %d", context.device_type); @@ -297,7 +295,7 @@ StatusOr BufferToDLPackManagedTensor(py::handle py_buffer, pack->tensor.manager_ctx = pack.get(); pack->tensor.deleter = DLPackTensorDeleter; TF_ASSIGN_OR_RETURN(dt.ctx, DLContextForDevice(*buffer->buffer()->device())); - dt.ctx.device_id = buffer->buffer()->device()->local_device_id(); + dt.ctx.device_id = buffer->buffer()->device()->local_hardware_id(); dt.ndim = buffer->buffer()->on_host_shape().dimensions_size(); TF_ASSIGN_OR_RETURN(dt.dtype, PrimitiveTypeToDLDataType( diff --git a/tensorflow/compiler/xla/python/outfeed_receiver.cc b/tensorflow/compiler/xla/python/outfeed_receiver.cc index 3c0f9750f7f..df4bc3025f1 100644 --- a/tensorflow/compiler/xla/python/outfeed_receiver.cc +++ b/tensorflow/compiler/xla/python/outfeed_receiver.cc @@ -230,8 +230,8 @@ OutfeedReceiverImpl::OutfeedReceiverImpl( callback_ = callback; max_callback_queue_size_bytes_ = max_callback_queue_size_bytes; for (const auto& client : clients) { - for (const auto& device : client->devices()) { - devices_.push_back(device.get()); + 
for (auto device : client->devices()) { + devices_.push_back(device); } } CHECK_GT(devices_.size(), 0); @@ -342,11 +342,7 @@ StatusOr> OutfeedReceiverImpl::ReceiveRawFromOutfeed( const PjRtDevice* device, const Shape& shape) { std::shared_ptr literal_shared; - TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, - device->GetLocalDeviceState()); - TF_ASSIGN_OR_RETURN(Literal literal, - local_device->client()->TransferFromOutfeedLocal( - shape, local_device->device_ordinal())); + TF_ASSIGN_OR_RETURN(Literal literal, device->TransferFromOutfeed(shape)); return absl::make_unique(std::move(literal)); } diff --git a/tensorflow/compiler/xla/python/py_buffer.cc b/tensorflow/compiler/xla/python/py_buffer.cc index 1f39266a989..dfa312c4592 100644 --- a/tensorflow/compiler/xla/python/py_buffer.cc +++ b/tensorflow/compiler/xla/python/py_buffer.cc @@ -86,8 +86,8 @@ StatusOr PyBuffer::UnsafeBufferPointer() const { } StatusOr PyBuffer::CudaArrayInterface() const { - if (buffer_->device()->local_device_state()->executor()->platform_kind() != - se::PlatformKind::kCuda) { + // TODO(zhangqiaorjc): Differentiate between NVidia and other GPUs. + if (buffer_->client()->platform_id() != kGpuId) { return InvalidArgument( "__cuda_array_interface__ is only defined for NVidia GPU buffers."); } diff --git a/tensorflow/compiler/xla/python/py_client.cc b/tensorflow/compiler/xla/python/py_client.cc index d42bbdca154..a5638f973df 100644 --- a/tensorflow/compiler/xla/python/py_client.cc +++ b/tensorflow/compiler/xla/python/py_client.cc @@ -37,9 +37,10 @@ PyClient::PyClient(std::shared_ptr pjrt_client) std::vector> PyClient::Devices() { std::vector> devices; - devices.reserve(pjrt_client_->devices().size()); - for (const auto& device : pjrt_client_->devices()) { - devices.push_back(WrapWithClient(shared_from_this(), device.get())); + auto span = pjrt_client_->devices(); + devices.reserve(span.size()); + for (PjRtDevice* device : span) { + devices.push_back(WrapWithClient(shared_from_this(), device)); } return devices; } @@ -64,9 +65,9 @@ PyClient::GetDefaultDeviceAssignment(int num_replicas, int num_partitions) { result[r].resize(num_partitions); for (int p = 0; p < num_partitions; ++p) { int device_id = device_assignment(r, p); - auto iter = pjrt_client_->id_to_device().find(device_id); - CHECK(iter != pjrt_client_->id_to_device().end()) << device_id; - result[r][p] = WrapWithClient(shared_from_this(), iter->second); + TF_ASSIGN_OR_RETURN(PjRtDevice * device, + pjrt_client_->LookupDevice(device_id)); + result[r][p] = WrapWithClient(shared_from_this(), device); } } return result; @@ -80,9 +81,9 @@ PyClient::GetDefaultDeviceAssignment1D(int num_replicas) { std::vector> result; for (int i = 0; i < num_replicas; ++i) { int device_id = device_assignment(i, 0); - auto iter = pjrt_client_->id_to_device().find(device_id); - CHECK(iter != pjrt_client_->id_to_device().end()) << device_id; - result.push_back(WrapWithClient(shared_from_this(), iter->second)); + TF_ASSIGN_OR_RETURN(PjRtDevice * device, + pjrt_client_->LookupDevice(device_id)); + result.push_back(WrapWithClient(shared_from_this(), device)); } return result; } @@ -95,8 +96,9 @@ StatusOr> PyClient::BufferFromPyval( device = pjrt_client_->local_devices().front(); } CHECK(device != nullptr); - auto iter = pjrt_client_->id_to_device().find(device->id()); - if (iter->second != device) { + TF_ASSIGN_OR_RETURN(PjRtDevice * found_device, + pjrt_client_->LookupDevice(device->id())); + if (found_device != device) { return InvalidArgument("Cannot copy value to device '%s' with 
'%s' backend", device->DebugString(), pjrt_client_->platform_name()); diff --git a/tensorflow/compiler/xla/python/py_client.h b/tensorflow/compiler/xla/python/py_client.h index 37f5333ea1c..158171b83c7 100644 --- a/tensorflow/compiler/xla/python/py_client.h +++ b/tensorflow/compiler/xla/python/py_client.h @@ -97,7 +97,9 @@ class PyClient : public std::enable_shared_from_this { const std::string& platform_name() const { return pjrt_client_->platform_name(); } - int local_device_count() const { return pjrt_client_->local_device_count(); } + int addressable_device_count() const { + return pjrt_client_->addressable_device_count(); + } int device_count() const { return pjrt_client_->device_count(); } int host_id() const { return pjrt_client_->host_id(); } diff --git a/tensorflow/compiler/xla/python/pytree.cc b/tensorflow/compiler/xla/python/pytree.cc index bf0bb1a8d93..d9a7a05916a 100644 --- a/tensorflow/compiler/xla/python/pytree.cc +++ b/tensorflow/compiler/xla/python/pytree.cc @@ -32,6 +32,7 @@ limitations under the License. #include "pybind11/pybind11.h" #include "pybind11/pytypes.h" #include "pybind11/stl.h" +#include "tensorflow/compiler/xla/python/types.h" namespace xla { @@ -106,59 +107,66 @@ bool PyTreeDef::operator==(const PyTreeDef& other) const { } } -void PyTreeDef::FlattenInto(py::handle handle, - std::vector& leaves) { +void PyTreeDef::FlattenInto(py::handle handle, std::vector& leaves, + absl::optional leaf_predicate) { Node node; int start_num_nodes = traversal_.size(); int start_num_leaves = leaves.size(); - node.kind = GetKind(handle, &node.custom); - if (node.kind == Kind::kNone) { - // Nothing to do. - } else if (node.kind == Kind::kTuple) { - py::tuple tuple = py::reinterpret_borrow(handle); - node.arity = tuple.size(); - for (py::handle entry : tuple) { - FlattenInto(entry, leaves); - } - } else if (node.kind == Kind::kList) { - py::list list = py::reinterpret_borrow(handle); - node.arity = list.size(); - for (py::handle entry : list) { - FlattenInto(entry, leaves); - } - } else if (node.kind == Kind::kDict) { - py::dict dict = py::reinterpret_borrow(handle); - py::list keys = py::reinterpret_steal(PyDict_Keys(dict.ptr())); - if (PyList_Sort(keys.ptr())) { - throw std::runtime_error("Dictionary key sort failed."); - } - for (py::handle key : keys) { - FlattenInto(dict[key], leaves); - } - node.arity = dict.size(); - node.node_data = std::move(keys); - } else if (node.kind == Kind::kCustom) { - py::tuple out = py::cast(node.custom->to_iterable(handle)); - if (out.size() != 2) { - throw std::runtime_error( - "PyTree custom to_iterable function should return a pair"); - } - node.node_data = out[1]; - node.arity = 0; - for (py::handle entry : py::cast(out[0])) { - ++node.arity; - FlattenInto(entry, leaves); - } - } else if (node.kind == Kind::kNamedTuple) { - py::tuple tuple = py::reinterpret_borrow(handle); - node.arity = tuple.size(); - node.node_data = py::reinterpret_borrow(tuple.get_type()); - for (py::handle entry : tuple) { - FlattenInto(entry, leaves); - } + if (leaf_predicate && (*leaf_predicate)(handle).cast()) { + leaves.push_back(py::reinterpret_borrow(handle)); } else { - assert(node.kind == Kind::kLeaf); - leaves.push_back(pybind11::reinterpret_borrow(handle)); + node.kind = GetKind(handle, &node.custom); + auto recurse = [this, &leaf_predicate, &leaves](py::handle child) { + FlattenInto(child, leaves, leaf_predicate); + }; + if (node.kind == Kind::kNone) { + // Nothing to do. 
+ } else if (node.kind == Kind::kTuple) { + py::tuple tuple = py::reinterpret_borrow(handle); + node.arity = tuple.size(); + for (py::handle entry : tuple) { + recurse(entry); + } + } else if (node.kind == Kind::kList) { + py::list list = py::reinterpret_borrow(handle); + node.arity = list.size(); + for (py::handle entry : list) { + recurse(entry); + } + } else if (node.kind == Kind::kDict) { + py::dict dict = py::reinterpret_borrow(handle); + py::list keys = py::reinterpret_steal(PyDict_Keys(dict.ptr())); + if (PyList_Sort(keys.ptr())) { + throw std::runtime_error("Dictionary key sort failed."); + } + for (py::handle key : keys) { + recurse(dict[key]); + } + node.arity = dict.size(); + node.node_data = std::move(keys); + } else if (node.kind == Kind::kCustom) { + py::tuple out = py::cast(node.custom->to_iterable(handle)); + if (out.size() != 2) { + throw std::runtime_error( + "PyTree custom to_iterable function should return a pair"); + } + node.node_data = out[1]; + node.arity = 0; + for (py::handle entry : py::cast(out[0])) { + ++node.arity; + recurse(entry); + } + } else if (node.kind == Kind::kNamedTuple) { + py::tuple tuple = py::reinterpret_borrow(handle); + node.arity = tuple.size(); + node.node_data = py::reinterpret_borrow(tuple.get_type()); + for (py::handle entry : tuple) { + recurse(entry); + } + } else { + assert(node.kind == Kind::kLeaf); + leaves.push_back(py::reinterpret_borrow(handle)); + } } node.num_nodes = traversal_.size() - start_num_nodes + 1; node.num_leaves = leaves.size() - start_num_leaves; @@ -166,10 +174,10 @@ void PyTreeDef::FlattenInto(py::handle handle, } /*static*/ std::pair, std::unique_ptr> -PyTreeDef::Flatten(py::handle x) { +PyTreeDef::Flatten(py::handle x, absl::optional leaf_predicate) { std::vector leaves; auto tree = absl::make_unique(); - tree->FlattenInto(x, leaves); + tree->FlattenInto(x, leaves, leaf_predicate); return std::make_pair(std::move(leaves), std::move(tree)); } @@ -618,7 +626,8 @@ std::string PyTreeDef::ToString() const { void BuildPytreeSubmodule(py::module& m) { py::module pytree = m.def_submodule("pytree", "Python tree library"); - pytree.def("flatten", &PyTreeDef::Flatten); + pytree.def("flatten", &PyTreeDef::Flatten, py::arg("tree"), + py::arg("leaf_predicate") = absl::nullopt); pytree.def("tuple", &PyTreeDef::Tuple); pytree.def("all_leaves", &PyTreeDef::AllLeaves); diff --git a/tensorflow/compiler/xla/python/pytree.h b/tensorflow/compiler/xla/python/pytree.h index 69cd93a7d08..c0a99a1dff3 100644 --- a/tensorflow/compiler/xla/python/pytree.h +++ b/tensorflow/compiler/xla/python/pytree.h @@ -85,11 +85,13 @@ class PyTreeDef { // Flattens a Pytree into a list of leaves and a PyTreeDef. static std::pair, std::unique_ptr> - Flatten(pybind11::handle x); + Flatten(pybind11::handle x, + absl::optional leaf_predicate = absl::nullopt); // Recursive helper used to implement Flatten(). - void FlattenInto(pybind11::handle handle, - std::vector& leaves); + void FlattenInto( + pybind11::handle handle, std::vector& leaves, + absl::optional leaf_predicate = absl::nullopt); // Tests whether the given list is a flat list of leaves. 
static bool AllLeaves(const pybind11::iterable& x); diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD index 9d98d0cf654..28a491c0326 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD +++ b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD @@ -37,6 +37,7 @@ cc_library( "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/core/framework:allocator", + "//tensorflow/core/platform:casts", "//tensorflow/core/platform:env", "//tensorflow/core/profiler/lib:traceme", "@com_google_absl//absl/memory", diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc index c6a748068d4..a9aa218ca6f 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc @@ -37,8 +37,8 @@ namespace xla { TpuDevice::TpuDevice(int id, int host_id, const std::array& coords, int core_on_chip) - : xla::PjRtDevice(id, /*local_device_state=*/nullptr, - /*device_kind=*/"Cloud TPU", host_id), + : xla::PjRtStreamExecutorDevice(id, /*local_device_state=*/nullptr, + /*device_kind=*/"Cloud TPU", host_id), coords_(coords), core_on_chip_(core_on_chip) {} @@ -531,7 +531,7 @@ PyTpuExecutable::PyTpuExecutable( << "Inserting duplicate replica:" << replica; executables_[replica] = client_->driver()->LoadProgram(device_id, compiled_program.get(), {}); - addressable_device_logical_ids_.emplace_back(replica, partition); + local_logical_device_ids_.emplace_back(replica, partition); local_devices_.push_back(device); } } @@ -711,8 +711,8 @@ PyTpuExecutable::ExecuteOnLocalDevices( // long time and we want all cores to be scheduled in parallel. thread_pool->Schedule([this, i, argument_handles, &results, &results_lock, &execute_semaphore]() { - const int replica = addressable_device_logical_ids_[i].first; - const int partition = addressable_device_logical_ids_[i].second; + const int replica = local_logical_device_ids_[i].first; + const int partition = local_logical_device_ids_[i].second; RunId run_id; auto result = ExecuteHelper(argument_handles, argument_handles[i], replica, partition, run_id); diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h index 20c2f749a75..89dca53bbb6 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h @@ -32,13 +32,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/threadpool.h" namespace xla { constexpr char kTpuPlatform[] = "tpu"; -class TpuDevice : public PjRtDevice { +class TpuDevice : public PjRtStreamExecutorDevice { public: TpuDevice(int id, int host_id, const std::array& coords, int core_on_chip); @@ -298,9 +299,8 @@ class PyTpuExecutable { return device_assignment_; } - const std::vector>& addressable_device_logical_ids() - const { - return addressable_device_logical_ids_; + const std::vector>& local_logical_device_ids() const { + return local_logical_device_ids_; } const std::vector>& local_devices() const { @@ -341,14 +341,16 @@ class PyTpuExecutable { // The replica and partition indices of device_assignment_ to be run by this // client. On single-host platforms without partitioning, this is all replicas - // (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may not be the - // case on multi-host platforms. If there are 4 replicas and 2 partitions on a - // single host platform, size of addressable_device_logical_ids_ is 4*2 = 8. - std::vector> addressable_device_logical_ids_; + // (i.e. local_logical_device_ids_[i] = (i, 0)), but this may not be the case + // on multi-host platforms. + // If there are 4 replicas and 2 partitions on a single host platform, size of + // local_logical_device_ids_ is 4*2 = 8. + std::vector> local_logical_device_ids_; - // local_devices_[i] is the Device to which addressable_device_logical_ids_[i] - // is assigned. shared_ptrs instead of unique_ptrs to play well with the - // Python bindings (see xla.cc). + // local_devices_[i] is the Device to which local_logical_device_ids_[i] is + // assigned. + // shared_ptrs instead of unique_ptrs to play well with the Python bindings + // (see xla.cc). std::vector> local_devices_; xla::Shape result_shape_; diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc index 0562ff230e1..a9fd70b6475 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc @@ -186,7 +186,7 @@ PYBIND11_MODULE(tpu_client_extension, m) { py::class_(m, "TpuExecutable") .def("local_logical_device_ids", - &PyTpuExecutable::addressable_device_logical_ids) + &PyTpuExecutable::local_logical_device_ids) .def("local_devices", &PyTpuExecutable::local_devices) .def_property_readonly("client", &PyTpuExecutable::client) .def("size_of_generated_code_in_bytes", diff --git a/tensorflow/compiler/xla/python/types.cc b/tensorflow/compiler/xla/python/types.cc index 882b38d57c6..40a3e589ba5 100644 --- a/tensorflow/compiler/xla/python/types.cc +++ b/tensorflow/compiler/xla/python/types.cc @@ -16,8 +16,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/python/types.h" #include "absl/container/flat_hash_map.h" -#include "tensorflow/compiler/xla/python/bfloat16.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/python/lib/core/bfloat16.h" namespace xla { @@ -81,8 +81,8 @@ xla::StatusOr PrimitiveTypeToDtype(PrimitiveType type) { case U64: return py::dtype::of(); case BF16: { - TF_ASSIGN_OR_RETURN(py::object bfloat16, Bfloat16Dtype()); - return py::dtype::from_args(bfloat16); + py::handle bfloat16(tensorflow::Bfloat16Dtype()); + return py::dtype::from_args(py::reinterpret_borrow(bfloat16)); } case F16: return py::dtype("e"); // PEP 3118 code for "float16 @@ -237,10 +237,11 @@ StatusOr LiteralToPython(std::shared_ptr literal) { // We requested an array of uint16 since NumPy doesn't know how // to produce our custom bfloat16 type. Reinterpret the array as bfloat16 // before handing it back to the caller. - TF_ASSIGN_OR_RETURN(py::object bfloat16, Bfloat16Dtype()); + py::handle bfloat16(tensorflow::Bfloat16Dtype()); + bfloat16.inc_ref(); array = py::reinterpret_steal( PyArray_View(reinterpret_cast(array.ptr()), - reinterpret_cast(bfloat16.release().ptr()), + reinterpret_cast(bfloat16.ptr()), static_cast(nullptr))); } return array; diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index dee1b14b90f..4060df2b600 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -40,7 +40,6 @@ limitations under the License. #include "tensorflow/compiler/xla/pjrt/interpreter_device.h" #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/pjrt/tpu_client.h" -#include "tensorflow/compiler/xla/python/bfloat16.h" #include "tensorflow/compiler/xla/python/dlpack.h" #include "tensorflow/compiler/xla/python/jax_jit.h" #include "tensorflow/compiler/xla/python/ops.h" @@ -59,6 +58,7 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/python/lib/core/bfloat16.h" #include "tensorflow/stream_executor/platform.h" namespace xla { @@ -110,6 +110,8 @@ PYBIND11_MODULE(xla_extension, m) { throw std::runtime_error("Unable to initialize Numpy API"); } + CHECK(tensorflow::RegisterNumpyBfloat16()); + // Types py::enum_(m, "PrimitiveType") .value("PRIMITIVE_TYPE_INVALID", PRIMITIVE_TYPE_INVALID) @@ -132,7 +134,8 @@ PYBIND11_MODULE(xla_extension, m) { .value("OPAQUE_TYPE", OPAQUE_TYPE) .value("TOKEN", TOKEN); - m.def("bfloat16_dtype", Bfloat16Dtype); + m.def("bfloat16_dtype", + []() { return py::handle(tensorflow::Bfloat16Dtype()); }); // Must be before PyClient.compile. 
BuildXlaCompilerSubmodule(m); @@ -149,7 +152,10 @@ PYBIND11_MODULE(xla_extension, m) { .def_property_readonly("host_id", &PjRtDevice::host_id, "Integer ID of this device's host.\n\n" "This is always 0 except on multi-host platforms.") - .def_property_readonly("platform", &PjRtDevice::platform_name) + .def_property_readonly("platform", + [](const PjRtDevice& device) { + return device.client()->platform_name(); + }) .def_property_readonly("device_kind", &PjRtDevice::device_kind) .def_property_readonly( "client", @@ -234,7 +240,7 @@ PYBIND11_MODULE(xla_extension, m) { py::class_> py_local_client(m, "Client"); py_local_client.def_property_readonly("platform", &PyClient::platform_name) .def("device_count", &PyClient::device_count) - .def("local_device_count", &PyClient::local_device_count) + .def("local_device_count", &PyClient::addressable_device_count) .def("devices", &PyClient::Devices) .def("local_devices", &PyClient::LocalDevices) .def("host_id", &PyClient::host_id) @@ -381,10 +387,10 @@ PYBIND11_MODULE(xla_extension, m) { [](PyExecutable* exec) { auto span = exec->addressable_device_logical_ids(); // Not on dispatch critical path, so ok to have heap allocation. - std::vector> addressable_device_logical_ids; - addressable_device_logical_ids.reserve(span.size()); + std::vector> addressable_device_logic_ids; + addressable_device_logic_ids.reserve(span.size()); for (const auto& logical_device_id : span) { - addressable_device_logical_ids.push_back(std::make_pair( + addressable_device_logic_ids.push_back(std::make_pair( logical_device_id.replica, logical_device_id.partition)); } }) diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index 9e169fd8210..623b8262178 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -42,6 +42,7 @@ limitations under the License. #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/threadpool.h" namespace xla { @@ -158,6 +159,18 @@ class AotCompilationMetadata { // platform. class Compiler { public: + struct CompileOptions { + // If device_allocator is not null, the compiler may use it to allocate temp + // space on the device for use during compilation. For example, the + // compiler may allocate buffers on the device and then run variants of a + // given algorithm over those buffers, to see which variant is fastest. Any + // space allocated will be deallocated before the compilation returns. + se::DeviceMemoryAllocator* device_allocator = nullptr; + + // An optional thread pool for parallel compilation. + tensorflow::thread::ThreadPool* thread_pool = nullptr; + }; + virtual ~Compiler() {} // Returns the ID of the platform that this compiler targets. @@ -165,31 +178,24 @@ class Compiler { // Runs Hlo passes to optimize the given Hlo module, returns the optimized // module. - // - // If device_allocator is not null, the compiler may use it to allocate temp - // space on the device for use during compilation. For example, the compiler - // may allocate buffers on the device and then run variants of a given - // algorithm over those buffers, to see which variant is fastest. Any space - // allocated should be deallocated before this function returns. 
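// Aside: the CompileOptions struct introduced above bundles what used to be a
// bare se::DeviceMemoryAllocator* argument, so new knobs (such as the thread
// pool used for parallel device compilation) can be added without touching
// every backend override, while thin non-virtual overloads keep old call
// sites working. A standalone sketch of that pattern, with placeholder types
// standing in for the XLA ones; this is not the real XLA code:
#include <iostream>

struct Allocator {};   // stands in for se::DeviceMemoryAllocator
struct ThreadPool {};  // stands in for tensorflow::thread::ThreadPool

struct CompileOptions {
  Allocator* device_allocator = nullptr;  // optional scratch allocator
  ThreadPool* thread_pool = nullptr;      // optional pool for parallel work
};

class Compiler {
 public:
  virtual ~Compiler() = default;
  // Extensible entry point: every option travels in one struct.
  virtual void RunBackend(const CompileOptions& options) = 0;
  // Back-compat overload: legacy callers keep passing just the allocator.
  void RunBackend(Allocator* device_allocator) {
    RunBackend(CompileOptions{device_allocator});
  }
};

class MyCompiler : public Compiler {
 public:
  using Compiler::RunBackend;  // keep the legacy overload visible, in the same
                               // spirit as the `using Compiler::Compile;` that
                               // this patch adds to LLVMCompiler
  void RunBackend(const CompileOptions& options) override {
    std::cout << "allocator set: " << (options.device_allocator != nullptr)
              << ", thread pool set: " << (options.thread_pool != nullptr)
              << "\n";
  }
};

int main() {
  MyCompiler compiler;
  Allocator allocator;
  ThreadPool pool;
  compiler.RunBackend(&allocator);                         // old-style call
  compiler.RunBackend(CompileOptions{&allocator, &pool});  // new-style call
}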
virtual StatusOr> RunHloPasses( std::unique_ptr module, se::StreamExecutor* executor, - se::DeviceMemoryAllocator* device_allocator) = 0; + const CompileOptions& options) = 0; + StatusOr> RunHloPasses( + std::unique_ptr module, se::StreamExecutor* executor, + se::DeviceMemoryAllocator* device_allocator) { + return RunHloPasses(std::move(module), executor, + CompileOptions{device_allocator}); + } // Runs HLO passes to optimize the given HloModule, perform scheduling and // buffer assignment, returns the optimized module and the buffer assignments. // This interface is intentionally narrow. - // - // If device_allocator is not null, the compiler may use it to allocate temp - // space on the device for use during compilation. For example, the compiler - // may allocate buffers on the device and then run variants of a given - // algorithm over those buffers, to see which variant is fastest. Any space - // allocated should be deallocated before this function returns. virtual StatusOr< std::tuple, std::unique_ptr>> RunHloPassesAndBufferAssignement(std::unique_ptr module, - se::StreamExecutor* executor, - se::DeviceMemoryAllocator* device_allocator, - bool optimize) { + se::StreamExecutor* executor, bool optimize, + const CompileOptions& options) { return Unimplemented("This compiler does not support this method"); } @@ -201,24 +207,33 @@ class Compiler { // // The compiler may optionally specialize to the individual device // (not just type of device) indicated by the executor. - // - // device_allocator is optional; see RunHloPasses. virtual StatusOr> RunBackend( std::unique_ptr module, se::StreamExecutor* executor, - se::DeviceMemoryAllocator* device_allocator) = 0; + const CompileOptions& options) = 0; + StatusOr> RunBackend( + std::unique_ptr module, se::StreamExecutor* executor, + se::DeviceMemoryAllocator* device_allocator) { + return RunBackend(std::move(module), executor, + CompileOptions{device_allocator}); + } // Compiles a set of HLO modules that can run in parallel, potentially // communicating data between the modules, and returns a corresponding // sequence of executable objects. // - // device_allocator is optional; see RunHloPasses. - // // TODO(b/68666782): Remove this method after adding support for multiple // modules to RunHloPasses and RunBackends. virtual StatusOr>> Compile( std::unique_ptr module_group, std::vector> stream_exec, - se::DeviceMemoryAllocator* device_allocator) = 0; + const CompileOptions& options) = 0; + StatusOr>> Compile( + std::unique_ptr module_group, + std::vector> stream_exec, + se::DeviceMemoryAllocator* device_allocator) { + return Compile(std::move(module_group), stream_exec, + CompileOptions{device_allocator}); + } // Returns the backend configurations that the backend will consider for the // given HLO. 
Returns no configurations if the backend does not support diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index ca67fe66994..ee1f573c867 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -553,7 +553,7 @@ Status CreateHloProfilingArtifacts( StatusOr> CpuCompiler::RunHloPasses( std::unique_ptr module, se::StreamExecutor* /*stream_exec*/, - se::DeviceMemoryAllocator* /*device_allocator*/) { + const CompileOptions& /*options*/) { std::unique_ptr jit_target_machine = SimpleOrcJIT::InferTargetMachineForJIT( CompilerTargetOptions(module->config()), @@ -566,12 +566,13 @@ StatusOr> CpuCompiler::RunHloPasses( StatusOr< std::tuple, std::unique_ptr>> -CpuCompiler::RunHloPassesAndBufferAssignement( - std::unique_ptr module, se::StreamExecutor* executor, - se::DeviceMemoryAllocator* device_allocator, bool optimize) { +CpuCompiler::RunHloPassesAndBufferAssignement(std::unique_ptr module, + se::StreamExecutor* executor, + bool optimize, + const CompileOptions& options) { if (optimize) { - TF_ASSIGN_OR_RETURN( - module, RunHloPasses(std::move(module), executor, device_allocator)); + TF_ASSIGN_OR_RETURN(module, + RunHloPasses(std::move(module), executor, options)); } // Select an order for emitting the HLO instructions for each computation. @@ -632,7 +633,7 @@ struct OrcJITPostCompilationHook { StatusOr> CpuCompiler::RunBackend( std::unique_ptr module, se::StreamExecutor* stream_exec, - se::DeviceMemoryAllocator* /*device_allocator*/) { + const CompileOptions& options) { VLOG(1) << "Compiling: " << module->name(); XLA_SCOPED_LOGGING_TIMER( absl::StrFormat("Compiling [%s] for CPU using JIT", module->name())); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index 5c056fcacaa..9f5e6a92909 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -134,18 +134,17 @@ class CpuCompiler : public LLVMCompiler { StatusOr> RunHloPasses( std::unique_ptr module, se::StreamExecutor* stream_exec, - se::DeviceMemoryAllocator* device_allocator) override; + const CompileOptions& options) override; StatusOr< std::tuple, std::unique_ptr>> RunHloPassesAndBufferAssignement(std::unique_ptr module, - se::StreamExecutor* executor, - se::DeviceMemoryAllocator* device_allocator, - bool optimize) override; + se::StreamExecutor* executor, bool optimize, + const CompileOptions& options) override; StatusOr> RunBackend( std::unique_ptr module, se::StreamExecutor* stream_exec, - se::DeviceMemoryAllocator* device_allocator) override; + const CompileOptions& options) override; StatusOr>> CompileAheadOfTime(std::unique_ptr module_group, diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 4ca094e82a5..f9bacdd8145 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -440,7 +440,7 @@ filegroup( name = "nccl_collective_thunk_src", srcs = if_nccl( ["nccl_collective_thunk.cc"], - ["dummy_collective_thunk.cc"], + ["nccl_collective_thunk_dummy.cc"], ), ) @@ -448,7 +448,7 @@ tf_cuda_library( name = "nccl_collective_thunk", srcs = if_cuda_or_rocm( [":nccl_collective_thunk_src"], - ["dummy_collective_thunk.cc"], + ["nccl_collective_thunk_dummy.cc"], ), hdrs = ["nccl_collective_thunk.h"], deps = [ @@ -480,7 +480,7 @@ filegroup( name = "nccl_all_gather_thunk_src", srcs = 
if_nccl( ["nccl_all_gather_thunk.cc"], - ["dummy_all_gather_thunk.cc"], + ["nccl_all_gather_thunk_dummy.cc"], ), ) @@ -488,7 +488,7 @@ tf_cuda_library( name = "nccl_all_gather_thunk", srcs = if_cuda_or_rocm( [":nccl_all_gather_thunk_src"], - ["dummy_all_gather_thunk.cc"], + ["nccl_all_gather_thunk_dummy.cc"], ), hdrs = ["nccl_all_gather_thunk.h"], deps = [ @@ -520,7 +520,7 @@ filegroup( name = "nccl_all_reduce_thunk_src", srcs = if_nccl( ["nccl_all_reduce_thunk.cc"], - ["dummy_all_reduce_thunk.cc"], + ["nccl_all_reduce_thunk_dummy.cc"], ), ) @@ -528,7 +528,7 @@ tf_cuda_library( name = "nccl_all_reduce_thunk", srcs = if_cuda_or_rocm( [":nccl_all_reduce_thunk_src"], - ["dummy_all_reduce_thunk.cc"], + ["nccl_all_reduce_thunk_dummy.cc"], ), hdrs = ["nccl_all_reduce_thunk.h"], deps = [ @@ -560,7 +560,7 @@ filegroup( name = "nccl_all_to_all_thunk_src", srcs = if_nccl( ["nccl_all_to_all_thunk.cc"], - ["dummy_all_to_all_thunk.cc"], + ["nccl_all_to_all_thunk_dummy.cc"], ), ) @@ -568,7 +568,7 @@ tf_cuda_library( name = "nccl_all_to_all_thunk", srcs = if_cuda_or_rocm( [":nccl_all_to_all_thunk_src"], - ["dummy_all_to_all_thunk.cc"], + ["nccl_all_to_all_thunk_dummy.cc"], ), hdrs = ["nccl_all_to_all_thunk.h"], deps = [ @@ -600,7 +600,7 @@ filegroup( name = "nccl_test_utils_src", srcs = if_nccl( ["nccl_test_utils.cc"], - ["dummy_nccl_test_utils.cc"], + ["nccl_test_utils_dummy.cc"], ), ) @@ -608,7 +608,7 @@ tf_cuda_library( name = "nccl_test_utils", srcs = if_cuda_or_rocm( [":nccl_test_utils_src"], - ["dummy_nccl_test_utils.cc"], + ["nccl_test_utils_dummy.cc"], ), hdrs = ["nccl_test_utils.h"], deps = [ @@ -1452,7 +1452,11 @@ cc_library( "//tensorflow/stream_executor:stream_executor_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@llvm-project//llvm:AsmParser", + "@llvm-project//llvm:BitReader", + "@llvm-project//llvm:BitWriter", "@llvm-project//llvm:Core", + "@llvm-project//llvm:TransformUtils", "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", "@llvm-project//mlir:IR", ], @@ -1517,7 +1521,7 @@ cc_library( "//tensorflow/stream_executor:stream_executor_headers", "//tensorflow/stream_executor/cuda:cuda_diagnostics", "//tensorflow/stream_executor/gpu:asm_compiler", - ]), + ]) + ["//tensorflow/stream_executor/gpu:gpu_driver_header"], ) cc_library( diff --git a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc index 974db02b1b3..f6409b476b5 100644 --- a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc @@ -108,12 +108,17 @@ StatusOr>> AMDGPUCompiler::CompileTargetBinary(const HloModule* module, llvm::Module* llvm_module, GpuVersion gpu_version, - se::StreamExecutor* stream_exec) { + se::StreamExecutor* stream_exec, + bool relocatable) { if (rocdl_dir_.empty()) { // Compute rocdl_dir_ just once and cache it in this member. 
rocdl_dir_ = GetROCDLDir(module->config()); } + if (relocatable) { + return Unimplemented("relocatable target binary is not implemented"); + } + std::vector hsaco; { XLA_SCOPED_LOGGING_TIMER( diff --git a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h index acc5e021e3d..36318badeef 100644 --- a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h @@ -41,7 +41,8 @@ class AMDGPUCompiler : public GpuCompiler { StatusOr>> CompileTargetBinary( const HloModule* hlo_module, llvm::Module* llvm_module, - GpuVersion gpu_version, se::StreamExecutor* stream_exec) override; + GpuVersion gpu_version, se::StreamExecutor* stream_exec, + bool relocatable) override; private: // The parent directory of ROCm-Device-Libs IR libraries. diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 3eee882cdf0..b5e3c14c791 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -24,11 +24,15 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" +#include "llvm/AsmParser/Parser.h" +#include "llvm/Bitcode/BitcodeReader.h" +#include "llvm/Bitcode/BitcodeWriter.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/DiagnosticPrinter.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/Verifier.h" +#include "llvm/Transforms/Utils/SplitModule.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/InitAllDialects.h" // from @llvm-project #include "tensorflow/compiler/xla/protobuf_util.h" @@ -114,11 +118,13 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/blocking_counter.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/subprocess.h" +#include "tensorflow/core/platform/threadpool.h" #include "tensorflow/core/platform/tracing.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/util/env_var.h" @@ -470,14 +476,14 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment( StatusOr> GpuCompiler::RunHloPasses( std::unique_ptr module, se::StreamExecutor* stream_exec, - se::DeviceMemoryAllocator* device_allocator) { + const CompileOptions& options) { // We dump the post-optimization HLO in RunBackend so no need to dump it here. 
XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunHloPasses"); tensorflow::profiler::TraceMe activity( [&] { return absl::StrCat("HLO Transforms:", module->name()); }, tensorflow::profiler::TraceMeLevel::kInfo); TF_RETURN_IF_ERROR( - OptimizeHloModule(module.get(), stream_exec, device_allocator)); + OptimizeHloModule(module.get(), stream_exec, options.device_allocator)); TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get())); @@ -494,10 +500,10 @@ StatusOr< std::tuple, std::unique_ptr>> GpuCompiler::RunHloPassesAndBufferAssignement( std::unique_ptr hlo_module, se::StreamExecutor* executor, - se::DeviceMemoryAllocator* device_allocator, bool optimize) { + bool optimize, const CompileOptions& options) { if (optimize) { - TF_ASSIGN_OR_RETURN(hlo_module, RunHloPasses(std::move(hlo_module), - executor, device_allocator)); + TF_ASSIGN_OR_RETURN(hlo_module, + RunHloPasses(std::move(hlo_module), executor, options)); } std::unique_ptr stream_assignment = @@ -641,24 +647,133 @@ static Status CompileModuleToLlvmIrImpl( return Status::OK(); } +StatusOr>> +GpuCompiler::CompileToTargetBinary(const HloModule& module, + std::unique_ptr llvm_module, + se::StreamExecutor* stream_exec, + const CompileOptions& options) { + using BackendCompileResult = std::pair>; + + const auto compile_single_module = + [this, stream_exec, &module]( + llvm::Module* llvm_module, + bool relocatable) -> StatusOr { + { + XLA_SCOPED_LOGGING_TIMER( + "GpuCompiler::RunBackend - Running LLVM verifier"); + + std::string err; + llvm::raw_string_ostream err_stream(err); + + // verifyModule() returns true if the module is broken. + TF_RET_CHECK(!llvm::verifyModule(*llvm_module, &err_stream)) + << "Invalid LLVM IR before optimizations:\n" + << err_stream.str() + << "\nThis probably indicates a bug in the HLO -> LLVM IR " + "lowering. " + "Rerun with --xla_dump_to to get the IR and looks for files " + "with " + "name containing: *" + << FilenameFor(module, "", "") << "*"; + } + GpuVersion gpu_version = GetGpuVersion(stream_exec); + return CompileTargetBinary(&module, llvm_module, gpu_version, stream_exec, + relocatable); + }; + + tensorflow::thread::ThreadPool* thread_pool = options.thread_pool; + if (!thread_pool) { + return compile_single_module(llvm_module.get(), /*relocatable=*/false); + } + + // Test whether LinkModules is supported. 
+ if (this->LinkModules(stream_exec, {}).status().code() == + tensorflow::error::Code::UNIMPLEMENTED) { + return compile_single_module(llvm_module.get(), /*relocatable=*/false); + } + + std::vector> llvm_modules; + int num_functions = 0; + for (llvm::Function& func : llvm_module->functions()) { + if (!func.isDeclaration() && + func.getLinkage() == llvm::GlobalValue::LinkageTypes::ExternalLinkage) { + num_functions++; + } + } + + llvm::SplitModule( + std::move(llvm_module), + std::max( + 1, std::min(thread_pool->NumThreads(), num_functions)), + [&](std::unique_ptr module) { + llvm_modules.push_back(std::move(module)); + }, + /*PreserveLocals=*/true); + + std::vector> compile_results( + llvm_modules.size()); + tensorflow::BlockingCounter counter(llvm_modules.size()); + for (int i = 0; i < llvm_modules.size(); i++) { + thread_pool->Schedule([&compile_results, compile_single_module, i, + &llvm_modules, &counter] { + llvm::Module* original_module = llvm_modules[i].get(); + llvm::LLVMContext context; + std::string buffer; + llvm::raw_string_ostream error(buffer); + llvm::DiagnosticPrinterRawOStream printer(error); + auto DiagnosticHandler = [](const llvm::DiagnosticInfo& diag_info, + void* Context) { + auto printer = static_cast(Context); + diag_info.print(*printer); + }; + context.setDiagnosticHandlerCallBack(DiagnosticHandler, &printer); + + std::unique_ptr new_llvm_module; + { + std::string ir; + { + llvm::raw_string_ostream os(ir); + original_module->print(os, nullptr); + } + llvm::SMDiagnostic err; + new_llvm_module = llvm::parseAssemblyString(ir, err, context); + } + + compile_results[i] = + compile_single_module(new_llvm_module.get(), /*relocatable=*/true); + counter.DecrementCount(); + }); + } + counter.Wait(); + + std::string ptx_snippets; + std::vector> submodule_compile_results; + for (auto& maybe_result : compile_results) { + TF_ASSIGN_OR_RETURN(auto result, maybe_result); + if (result.second.empty()) { + continue; + } + ptx_snippets += result.first; + ptx_snippets += "\n"; + submodule_compile_results.push_back(result.second); + } + + TF_ASSIGN_OR_RETURN( + std::vector backend_result, + this->LinkModules(stream_exec, std::move(submodule_compile_results))); + + return std::make_pair(ptx_snippets, backend_result); +} + StatusOr> GpuCompiler::RunBackend( std::unique_ptr module, se::StreamExecutor* stream_exec, - se::DeviceMemoryAllocator* device_allocator) { + const CompileOptions& options) { XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend"); auto slow_compile_alarm = SlowCompilationAlarm(); TF_RET_CHECK(stream_exec != nullptr); llvm::LLVMContext llvm_context; - std::string buffer; - llvm::raw_string_ostream error(buffer); - llvm::DiagnosticPrinterRawOStream printer(error); - auto DiagnosticHandler = [](const llvm::DiagnosticInfo& diag_info, - void* Context) { - auto printer = static_cast(Context); - diag_info.print(*printer); - }; - llvm_context.setDiagnosticHandlerCallBack(DiagnosticHandler, &printer); GpuDeviceInfo gpu_device_info; gpu_device_info.threads_per_block_limit = @@ -724,34 +839,16 @@ StatusOr> GpuCompiler::RunBackend( llvm_ir::DumpIrIfEnabled(*module, *llvm_module, /*optimized=*/false); - { - XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - Running LLVM verifier"); - - std::string err; - llvm::raw_string_ostream err_stream(err); - - // verifyModule() returns true if the module is broken. 
- TF_RET_CHECK(!llvm::verifyModule(*llvm_module, &err_stream)) - << "Invalid LLVM IR before optimizations:\n" - << err_stream.str() - << "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. " - "Rerun with --xla_dump_to to get the IR and looks for files with " - "name containing: *" - << FilenameFor(*module, "", "") << "*"; - } - - GpuVersion gpu_version = GetGpuVersion(stream_exec); - using BackendCompileResult = std::pair>; TF_ASSIGN_OR_RETURN(BackendCompileResult backend_result, - CompileTargetBinary(module.get(), llvm_module.get(), - gpu_version, stream_exec)); - + CompileToTargetBinary(*module, std::move(llvm_module), + stream_exec, options)); if (DumpingEnabledForHloModule(*module)) { DumpToFileInDirOrStdout(*module, "", "thunk_schedule", thunk_schedule->ToString()); } + GpuVersion gpu_version = GetGpuVersion(stream_exec); auto* gpu_executable = new GpuExecutable( backend_result.first, backend_result.second, gpu_version, std::move(thunk_schedule), std::move(module), diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h index 824d7404ebe..1d42976e352 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h @@ -53,14 +53,13 @@ class GpuCompiler : public LLVMCompiler { StatusOr> RunHloPasses( std::unique_ptr module, se::StreamExecutor* stream_exec, - se::DeviceMemoryAllocator* device_allocator) override; + const CompileOptions& options) override; StatusOr< std::tuple, std::unique_ptr>> RunHloPassesAndBufferAssignement(std::unique_ptr hlo_module, - se::StreamExecutor* executor, - se::DeviceMemoryAllocator* device_allocator, - bool optimize) override; + se::StreamExecutor* executor, bool optimize, + const CompileOptions& options) override; Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, @@ -84,19 +83,23 @@ class GpuCompiler : public LLVMCompiler { virtual StatusOr>> CompileTargetBinary(const HloModule* hlo_module, llvm::Module* llvm_module, - GpuVersion gpu_version, - se::StreamExecutor* stream_exec) = 0; + GpuVersion gpu_version, se::StreamExecutor* stream_exec, + bool relocatable) = 0; Status PrepareHloModuleForIrEmitting(HloModule* hlo_module); StatusOr> RunBackend( std::unique_ptr module, se::StreamExecutor* stream_exec, - se::DeviceMemoryAllocator* device_allocator) override; + const CompileOptions& options) override; StatusOr>> CompileAheadOfTime(std::unique_ptr module_group, AotCompilationOptions const& options) override; + StatusOr>> CompileToTargetBinary( + const HloModule& module, std::unique_ptr llvm_module, + se::StreamExecutor* stream_exec, const CompileOptions& options); + se::Platform::Id PlatformId() const override { return platform_id_; } HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override { @@ -116,6 +119,12 @@ class GpuCompiler : public LLVMCompiler { } private: + virtual StatusOr> LinkModules( + se::StreamExecutor* stream_exec, + std::vector> modules) { + return Unimplemented("LinkModules is not implemented."); + } + se::Platform::Id platform_id_; // The triple that represents our target. 
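Aside: the CompileToTargetBinary path added in gpu_compiler.cc above splits the
LLVM module with llvm::SplitModule, compiles each piece on the workers of
options.thread_pool, and stitches the results back together through the new
virtual LinkModules, which defaults to Unimplemented and doubles as a
capability probe via the empty-module call. Because an llvm::LLVMContext is
not thread-safe, each worker re-parses its submodule from textual IR into a
private context. Below is a minimal standalone sketch of that split-and-reparse
step, assuming an LLVM of the same vintage as this patch (where SplitModule
still takes a std::unique_ptr<llvm::Module>); it is an illustration, not the
XLA code itself.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/SplitModule.h"

int main() {
  llvm::LLVMContext ctx;
  llvm::SMDiagnostic err;
  // A toy module with two externally visible functions.
  std::unique_ptr<llvm::Module> m = llvm::parseAssemblyString(
      "define void @f() { ret void }\n"
      "define void @g() { ret void }\n",
      err, ctx);
  if (!m) return 1;

  // Split into up to two submodules, as CompileToTargetBinary does before
  // handing each piece to a worker thread.
  std::vector<std::unique_ptr<llvm::Module>> parts;
  llvm::SplitModule(std::move(m), /*N=*/2,
                    [&](std::unique_ptr<llvm::Module> part) {
                      parts.push_back(std::move(part));
                    },
                    /*PreserveLocals=*/true);

  // Each worker would re-parse its part into a private LLVMContext so that
  // threads never share a context; done serially here to show the round trip.
  for (auto& part : parts) {
    std::string ir;
    llvm::raw_string_ostream os(ir);
    part->print(os, nullptr);
    os.flush();

    llvm::LLVMContext worker_ctx;
    llvm::SMDiagnostic worker_err;
    auto reparsed = llvm::parseAssemblyString(ir, worker_err, worker_ctx);
    std::cout << "submodule with " << reparsed->size() << " function(s)\n";
  }
}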
diff --git a/tensorflow/compiler/xla/service/gpu/dummy_all_gather_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk_dummy.cc similarity index 100% rename from tensorflow/compiler/xla/service/gpu/dummy_all_gather_thunk.cc rename to tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk_dummy.cc diff --git a/tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk_dummy.cc similarity index 100% rename from tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc rename to tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk_dummy.cc diff --git a/tensorflow/compiler/xla/service/gpu/dummy_all_to_all_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk_dummy.cc similarity index 100% rename from tensorflow/compiler/xla/service/gpu/dummy_all_to_all_thunk.cc rename to tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk_dummy.cc diff --git a/tensorflow/compiler/xla/service/gpu/dummy_collective_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk_dummy.cc similarity index 100% rename from tensorflow/compiler/xla/service/gpu/dummy_collective_thunk.cc rename to tensorflow/compiler/xla/service/gpu/nccl_collective_thunk_dummy.cc diff --git a/tensorflow/compiler/xla/service/gpu/dummy_nccl_test_utils.cc b/tensorflow/compiler/xla/service/gpu/nccl_test_utils_dummy.cc similarity index 100% rename from tensorflow/compiler/xla/service/gpu/dummy_nccl_test_utils.cc rename to tensorflow/compiler/xla/service/gpu/nccl_test_utils_dummy.cc diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index 3225cd2d5a5..070b8a1fcfb 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -51,6 +51,7 @@ limitations under the License. 
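// Aside: in the nvptx_compiler.cc and nvptx_compiler.h hunks below,
// CompileGpuAsmOrGetCachedResult gains a `relocatable` flag. Relocatable
// compiles add "-c" to the ptxas options, producing cubins intended for a
// later device-link step through LinkModules, and the PTX -> cubin cache key
// grows a matching bit so relocatable and fully linked images never alias in
// the cache. A standalone sketch of why the key needs that bit, using
// std::hash and a boost-style combiner rather than XLA's Hash64Combine:
#include <cstddef>
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

struct CacheKey {
  std::string ptx;
  int cc_major;
  int cc_minor;
  bool relocatable;
  bool operator==(const CacheKey& o) const {
    return ptx == o.ptx && cc_major == o.cc_major && cc_minor == o.cc_minor &&
           relocatable == o.relocatable;
  }
};

struct CacheKeyHash {
  size_t operator()(const CacheKey& k) const {
    auto combine = [](size_t seed, size_t v) {
      return seed ^ (v + 0x9e3779b9 + (seed << 6) + (seed >> 2));
    };
    size_t h = std::hash<std::string>()(k.ptx);
    h = combine(h, std::hash<int>()(k.cc_major));
    h = combine(h, std::hash<int>()(k.cc_minor));
    h = combine(h, std::hash<bool>()(k.relocatable));
    return h;
  }
};

int main() {
  std::unordered_map<CacheKey, std::vector<unsigned char>, CacheKeyHash> cache;
  // Same PTX and compute capability, different relocatable bit: two entries.
  cache[{".version 7.0 ...", 7, 0, /*relocatable=*/true}] = {0x01};
  cache[{".version 7.0 ...", 7, 0, /*relocatable=*/false}] = {0x02};
  std::cout << "entries: " << cache.size() << "\n";  // prints 2
}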
#include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/stream_executor/cuda/cuda_diagnostics.h" #include "tensorflow/stream_executor/gpu/asm_compiler.h" +#include "tensorflow/stream_executor/gpu/gpu_driver.h" namespace xla { namespace gpu { @@ -299,7 +300,8 @@ StatusOr>> NVPTXCompiler::CompileTargetBinary(const HloModule* module, llvm::Module* llvm_module, GpuVersion gpu_version, - se::StreamExecutor* stream_exec) { + se::StreamExecutor* stream_exec, + bool relocatable) { std::pair compute_capability = absl::get>(gpu_version); @@ -338,7 +340,7 @@ NVPTXCompiler::CompileTargetBinary(const HloModule* module, std::vector cubin = CompileGpuAsmOrGetCachedResult( stream_exec, ptx, compute_capability.first, compute_capability.second, - module->config()); + module->config(), relocatable); return std::pair>(std::move(ptx), std::move(cubin)); @@ -346,7 +348,7 @@ NVPTXCompiler::CompileTargetBinary(const HloModule* module, std::vector NVPTXCompiler::CompileGpuAsmOrGetCachedResult( se::StreamExecutor* stream_exec, const string& ptx, int cc_major, - int cc_minor, const HloModuleConfig& hlo_module_config) { + int cc_minor, const HloModuleConfig& hlo_module_config, bool relocatable) { XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::CompileGpuAsmOrGetCachedResult"); tensorflow::profiler::TraceMe activity( "PTX->CUBIN", tensorflow::profiler::TraceMeLevel::kInfo); @@ -361,7 +363,7 @@ std::vector NVPTXCompiler::CompileGpuAsmOrGetCachedResult( tensorflow::mutex_lock lock(mutex_); std::tie(iter, inserted) = compilation_cache_.emplace( std::piecewise_construct, - std::forward_as_tuple(ptx, cc_major, cc_minor), + std::forward_as_tuple(ptx, cc_major, cc_minor, relocatable), std::forward_as_tuple()); cache_ptx = &iter->first.ptx; cache_value = &iter->second; @@ -375,9 +377,13 @@ std::vector NVPTXCompiler::CompileGpuAsmOrGetCachedResult( if (inserted) { CHECK(!cache_value->compilation_done); if (!ptx.empty()) { - StatusOr> maybe_cubin = - se::CompileGpuAsm(stream_exec->device_ordinal(), cache_ptx->c_str(), - PtxOptsFromConfig(hlo_module_config)); + auto ptxas_config = PtxOptsFromConfig(hlo_module_config); + if (relocatable) { + ptxas_config.extra_flags.push_back("-c"); + } + StatusOr> maybe_cubin = se::CompileGpuAsm( + stream_exec->device_ordinal(), cache_ptx->c_str(), ptxas_config); + if (maybe_cubin.ok()) { cache_value->cubin_data = std::move(maybe_cubin).ValueOrDie(); VLOG(2) << "Compiled PTX size:" << ptx.size() @@ -445,5 +451,17 @@ std::vector NVPTXCompiler::CompileGpuAsmOrGetCachedResult( return cache_value->cubin_data; } +StatusOr> NVPTXCompiler::LinkModules( + se::StreamExecutor* stream_exec, std::vector> modules) { + std::vector images; + images.reserve(modules.size()); + for (auto& module : modules) { + images.push_back({"", std::move(module)}); + } + return LinkGpuAsm(static_cast( + stream_exec->implementation()->GpuContextHack()), + images); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h index 3e19b35af19..5c78b48b9c6 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h @@ -52,9 +52,14 @@ class NVPTXCompiler : public GpuCompiler { StatusOr>> CompileTargetBinary( const HloModule* hlo_module, llvm::Module* llvm_module, - GpuVersion gpu_version, se::StreamExecutor* stream_exec) override; + GpuVersion gpu_version, se::StreamExecutor* stream_exec, + bool relocatable) override; private: + StatusOr> LinkModules( + 
se::StreamExecutor* stream_exec, + std::vector> modules) override; + tensorflow::mutex mutex_; // When compiling an HLO module, we need to find a path to the nvvm libdevice @@ -71,7 +76,7 @@ class NVPTXCompiler : public GpuCompiler { // compiled cubin. If compilation was unsuccessful, returns an empty vector. std::vector CompileGpuAsmOrGetCachedResult( se::StreamExecutor* stream_exec, const string& ptx, int cc_major, - int cc_minor, const HloModuleConfig& hlo_module_config); + int cc_minor, const HloModuleConfig& hlo_module_config, bool relocatable); // The compilation_cache_ map is a cache from {ptx string, cc_major, cc_minor} // -> cubin so we don't recompile the same ptx twice. This is important for @@ -86,24 +91,32 @@ class NVPTXCompiler : public GpuCompiler { // If compiling the ptx fails, we return an empty cubin, cross our fingers, // and leave compilation up to the driver. struct CompilationCacheKey { - CompilationCacheKey(std::string ptx, int cc_major, int cc_minor) - : ptx(std::move(ptx)), cc_major(cc_major), cc_minor(cc_minor) {} + CompilationCacheKey(std::string ptx, int cc_major, int cc_minor, + bool relocatable) + : ptx(std::move(ptx)), + cc_major(cc_major), + cc_minor(cc_minor), + relocatable(relocatable) {} string ptx; int cc_major; int cc_minor; + bool relocatable; }; struct CompilationCacheHash { size_t operator()(const CompilationCacheKey& key) const { return tensorflow::Hash64Combine( - tensorflow::Hash64Combine(tensorflow::Hash64(key.ptx), key.cc_major), - key.cc_minor); + tensorflow::Hash64Combine( + tensorflow::Hash64Combine(tensorflow::Hash64(key.ptx), + key.cc_major), + key.cc_minor), + key.relocatable); } }; struct CompilationCacheEq { size_t operator()(const CompilationCacheKey& a, const CompilationCacheKey& b) const { return a.cc_major == b.cc_major && a.cc_minor == b.cc_minor && - a.ptx == b.ptx; + a.ptx == b.ptx && a.relocatable == b.relocatable; } }; struct CompilationCacheValue { diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index 3f3e74dbb62..8b0a046ffa9 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -95,7 +95,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { StatusOr> InterpreterCompiler::RunHloPasses( std::unique_ptr hlo_module, se::StreamExecutor* /*stream_exec*/, - se::DeviceMemoryAllocator* /*device_allocator*/) { + const CompileOptions& /*options*/) { VLOG(1) << "Run hlo passes on graph " << hlo_module->name(); TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get())); return std::move(hlo_module); @@ -103,7 +103,7 @@ StatusOr> InterpreterCompiler::RunHloPasses( StatusOr> InterpreterCompiler::RunBackend( std::unique_ptr hlo_module, se::StreamExecutor* stream_exec, - se::DeviceMemoryAllocator* /*device_allocator*/) { + const CompileOptions& /*options*/) { TF_RET_CHECK(stream_exec != nullptr); VLOG(1) << "Run backend " << hlo_module->name(); @@ -128,7 +128,7 @@ StatusOr> InterpreterCompiler::RunBackend( StatusOr>> InterpreterCompiler::Compile( std::unique_ptr module_group, std::vector> stream_exec, - se::DeviceMemoryAllocator* device_allocator) { + const CompileOptions& options) { if (module_group->empty()) { return std::vector>(); } @@ -141,12 +141,10 @@ StatusOr>> InterpreterCompiler::Compile( "Unexpected number of StreamExecutor's."); } auto hlo_modules = module_group->ConsumeModules(); - TF_ASSIGN_OR_RETURN(auto module, - 
RunHloPasses(std::move(hlo_modules[0]), stream_exec[0][0], - device_allocator)); - TF_ASSIGN_OR_RETURN( - auto executable, - RunBackend(std::move(module), stream_exec[0][0], device_allocator)); + TF_ASSIGN_OR_RETURN(auto module, RunHloPasses(std::move(hlo_modules[0]), + stream_exec[0][0], options)); + TF_ASSIGN_OR_RETURN(auto executable, RunBackend(std::move(module), + stream_exec[0][0], options)); std::vector> ret; ret.push_back(std::move(executable)); return std::move(ret); diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.h b/tensorflow/compiler/xla/service/interpreter/compiler.h index 824594dfd84..2136bc9ca4a 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.h +++ b/tensorflow/compiler/xla/service/interpreter/compiler.h @@ -45,14 +45,14 @@ class InterpreterCompiler : public Compiler { StatusOr> RunHloPasses( std::unique_ptr hlo_module, se::StreamExecutor* stream_exec, - se::DeviceMemoryAllocator* device_allocator) override; + const CompileOptions& options) override; StatusOr> RunBackend( std::unique_ptr hlo_module, se::StreamExecutor* stream_exec, - se::DeviceMemoryAllocator* device_allocator) override; + const CompileOptions& options) override; StatusOr>> Compile( std::unique_ptr module_group, std::vector> stream_exec, - se::DeviceMemoryAllocator* device_allocator) override; + const CompileOptions& options) override; StatusOr>> CompileAheadOfTime(std::unique_ptr module_group, diff --git a/tensorflow/compiler/xla/service/llvm_compiler.cc b/tensorflow/compiler/xla/service/llvm_compiler.cc index aa759b26226..8603cd5f05c 100644 --- a/tensorflow/compiler/xla/service/llvm_compiler.cc +++ b/tensorflow/compiler/xla/service/llvm_compiler.cc @@ -24,7 +24,7 @@ namespace xla { StatusOr>> LLVMCompiler::Compile( std::unique_ptr module_group, std::vector> stream_execs, - se::DeviceMemoryAllocator* device_allocator) { + const CompileOptions& options) { // Tensorflow tries to enable the following behaviors in all its threads: // // - Denormals are zero (DAZ): roughly, operations treat denormal floats as @@ -48,10 +48,10 @@ StatusOr>> LLVMCompiler::Compile( TF_ASSIGN_OR_RETURN(modules[i], RunHloPasses(std::move(modules[i]), stream_execs[i][0], - device_allocator)); + options.device_allocator)); TF_ASSIGN_OR_RETURN(std::unique_ptr executable, RunBackend(std::move(modules[i]), stream_execs[i][0], - device_allocator)); + options.device_allocator)); result.push_back(std::move(executable)); } diff --git a/tensorflow/compiler/xla/service/llvm_compiler.h b/tensorflow/compiler/xla/service/llvm_compiler.h index bddda50d3e1..7f0c617da6b 100644 --- a/tensorflow/compiler/xla/service/llvm_compiler.h +++ b/tensorflow/compiler/xla/service/llvm_compiler.h @@ -66,13 +66,14 @@ class LLVMCompiler : public Compiler { // std::unique_ptr module, // se::StreamExecutor* stream_exec, // se::DeviceMemoryAllocator* device_allocator) + using Compiler::Compile; using Compiler::RunBackend; using Compiler::RunHloPasses; StatusOr>> Compile( std::unique_ptr module_group, std::vector> stream_execs, - se::DeviceMemoryAllocator* device_allocator) override; + const CompileOptions& options) override; protected: ModuleHook user_pre_optimization_hook_; diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 0eff81c9a0d..ea8c45d3d46 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -190,11 +190,12 @@ LocalService::CompileExecutables( // single partition computations are built 
using `BuildExecutables`, fix it, // and remove this special case (provided the performance if similar). if (build_options.num_partitions() == 1) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - BuildExecutable(proto, std::move(module_config), execute_backend_.get(), - executor, build_options.device_allocator(), - build_options.run_backend_only())); + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + BuildExecutable(proto, std::move(module_config), + execute_backend_.get(), executor, + {build_options.device_allocator(), + build_options.compile_thread_pool()}, + build_options.run_backend_only())); std::vector> executables; executables.push_back(std::move(executable)); return executables; @@ -206,10 +207,12 @@ LocalService::CompileExecutables( std::vector executors(build_options.num_partitions(), executor); - return BuildExecutables({&proto}, std::move(module_configs), - execute_backend_.get(), {executors}, - build_options.device_allocator(), - build_options.run_backend_only()); + return BuildExecutables( + /*module_protos=*/{&proto}, std::move(module_configs), + execute_backend_.get(), {executors}, + Compiler::CompileOptions{build_options.device_allocator(), + build_options.compile_thread_pool()}, + build_options.run_backend_only()); } } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.cc b/tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.cc index f71267935cd..41c75dc86a5 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.cc @@ -28,28 +28,24 @@ bool IsUnimplemented(StatusOr& result) { StatusOr> FailoverCompiler::RunHloPasses( std::unique_ptr module, se::StreamExecutor* stream_exec, - se::DeviceMemoryAllocator* device_allocator) { - auto result = - primary_->RunHloPasses(module->Clone(), stream_exec, device_allocator); + const CompileOptions& options) { + auto result = primary_->RunHloPasses(module->Clone(), stream_exec, options); if (IsUnimplemented(result)) { VLOG(2) << "RunHloPasses resulted in " << result.status() << ", falling back to secondary backend"; - return secondary_->RunHloPasses(std::move(module), stream_exec, - device_allocator); + return secondary_->RunHloPasses(std::move(module), stream_exec, options); } return result; } StatusOr> FailoverCompiler::RunBackend( std::unique_ptr module, se::StreamExecutor* stream_exec, - se::DeviceMemoryAllocator* device_allocator) { - auto result = - primary_->RunBackend(module->Clone(), stream_exec, device_allocator); + const CompileOptions& options) { + auto result = primary_->RunBackend(module->Clone(), stream_exec, options); if (IsUnimplemented(result)) { VLOG(2) << "RunBackend resulted in " << result.status() << ", falling back to secondary backend"; - return secondary_->RunBackend(std::move(module), stream_exec, - device_allocator); + return secondary_->RunBackend(std::move(module), stream_exec, options); } return result; } @@ -57,7 +53,7 @@ StatusOr> FailoverCompiler::RunBackend( StatusOr>> FailoverCompiler::Compile( std::unique_ptr module_group, std::vector> stream_execs, - se::DeviceMemoryAllocator* device_allocator) { + const CompileOptions& options) { std::vector> result; std::vector> modules = module_group->ConsumeModules(); @@ -67,17 +63,15 @@ StatusOr>> FailoverCompiler::Compile( return Unimplemented( "Model partitioning not implemented for the failover compiler!"); } - auto executable = [stream_execs, device_allocator, i, + auto executable = [stream_execs, &options, i, this](std::unique_ptr module) -> 
StatusOr> { - TF_ASSIGN_OR_RETURN( - auto processed_module, - primary_->RunHloPasses(std::move(module), stream_execs[i][0], - device_allocator)); - TF_ASSIGN_OR_RETURN( - auto result, - primary_->RunBackend(std::move(processed_module), stream_execs[i][0], - device_allocator)); + TF_ASSIGN_OR_RETURN(auto processed_module, + primary_->RunHloPasses(std::move(module), + stream_execs[i][0], options)); + TF_ASSIGN_OR_RETURN(auto result, + primary_->RunBackend(std::move(processed_module), + stream_execs[i][0], options)); return result; }(modules[i]->Clone()); @@ -85,13 +79,11 @@ StatusOr>> FailoverCompiler::Compile( VLOG(2) << "Compile resulted in " << executable.status() << ", falling back to secondary backend"; TF_ASSIGN_OR_RETURN( - modules[i], - secondary_->RunHloPasses(std::move(modules[i]), stream_execs[i][0], - device_allocator)); - TF_ASSIGN_OR_RETURN( - executable, - secondary_->RunBackend(std::move(modules[i]), stream_execs[i][0], - device_allocator)); + modules[i], secondary_->RunHloPasses(std::move(modules[i]), + stream_execs[i][0], options)); + TF_ASSIGN_OR_RETURN(executable, + secondary_->RunBackend(std::move(modules[i]), + stream_execs[i][0], options)); } if (!executable.ok()) { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h b/tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h index 05badaa98e1..805a116acf2 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h @@ -51,16 +51,16 @@ class FailoverCompiler final : public Compiler { StatusOr> RunHloPasses( std::unique_ptr module, se::StreamExecutor* stream_exec, - se::DeviceMemoryAllocator* device_allocator) override; + const CompileOptions& options) override; StatusOr> RunBackend( std::unique_ptr module, se::StreamExecutor* stream_exec, - se::DeviceMemoryAllocator* device_allocator) override; + const CompileOptions& options) override; StatusOr>> Compile( std::unique_ptr module_group, std::vector> stream_execs, - se::DeviceMemoryAllocator* device_allocator) override; + const CompileOptions& options) override; StatusOr>> CompileAheadOfTime(std::unique_ptr module_group, diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc index 79525a3e29e..4e41737def0 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc @@ -87,16 +87,16 @@ class MlirCompilerImpl : public MlirCompiler { public: StatusOr> RunHloPasses( std::unique_ptr module, se::StreamExecutor* stream_exec, - se::DeviceMemoryAllocator* device_allocator) override; + const CompileOptions& options) override; StatusOr> RunBackend( std::unique_ptr module, se::StreamExecutor* stream_exec, - se::DeviceMemoryAllocator* device_allocator) override; + const CompileOptions& options) override; StatusOr>> Compile( std::unique_ptr module_group, std::vector> stream_execs, - se::DeviceMemoryAllocator* device_allocator) override; + const CompileOptions& options) override; StatusOr>> CompileAheadOfTime(std::unique_ptr module_group, @@ -155,12 +155,12 @@ std::string GetLibdeviceDir(const HloModuleConfig& hlo_module_config) { StatusOr> MlirCompilerImpl::RunHloPasses( std::unique_ptr module, se::StreamExecutor* stream_exec, - se::DeviceMemoryAllocator* device_allocator) { + const CompileOptions& options) { // Until we find a reason to do something different, run the same passes // that the normal GPU backend 
runs. gpu::NVPTXCompiler xla_compiler; TF_RETURN_IF_ERROR(xla_compiler.OptimizeHloModule(module.get(), stream_exec, - device_allocator)); + options.device_allocator)); TF_RETURN_IF_ERROR(xla_compiler.PrepareHloModuleForIrEmitting(module.get())); return std::move(module); @@ -454,7 +454,7 @@ StatusOr> TransformKernelToXlaThunk( StatusOr> MlirCompilerImpl::RunBackend( std::unique_ptr module, se::StreamExecutor* stream_exec, - se::DeviceMemoryAllocator* device_allocator) { + const CompileOptions& options) { // Determine the HLO schedule, which is an ordering of HLO instructions. This // is used by buffer assignment to enable buffer reuse, and the same ordering // must also be used to determine the thunk launch schedule. @@ -595,7 +595,7 @@ StatusOr> MlirCompilerImpl::RunBackend( StatusOr>> MlirCompilerImpl::Compile( std::unique_ptr module_group, std::vector> stream_execs, - se::DeviceMemoryAllocator* device_allocator) { + const CompileOptions& options) { return Unimplemented("Not yet implemented in MLIR compiler"); } diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index a6d23c18797..cf781b4fcdd 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -357,7 +357,7 @@ StatusOr>> Service::BuildExecutables( const std::vector& module_protos, std::vector> module_configs, Backend* backend, std::vector> executors, - se::DeviceMemoryAllocator* device_allocator, bool run_backend_only) { + const Compiler::CompileOptions& options, bool run_backend_only) { VLOG(1) << StrFormat("BuildExecutable on service %p", this); // Dump computation proto state if flag is set. @@ -387,17 +387,15 @@ StatusOr>> Service::BuildExecutables( std::vector> executables; if (!run_backend_only) { - TF_ASSIGN_OR_RETURN( - executables, - backend->compiler()->Compile(std::move(module_group), - std::move(executors), device_allocator)); + TF_ASSIGN_OR_RETURN(executables, backend->compiler()->Compile( + std::move(module_group), + std::move(executors), options)); } else { auto modules = module_group->ConsumeModules(); for (std::unique_ptr& module : modules) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - backend->compiler()->RunBackend(std::move(module), executors[0][0], - device_allocator)); + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + backend->compiler()->RunBackend( + std::move(module), executors[0][0], options)); executables.push_back(std::move(executable)); } } @@ -710,7 +708,7 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, TF_ASSIGN_OR_RETURN(std::vector> executables, BuildExecutables(module_protos, std::move(module_configs), execute_backend_.get(), all_executors, - /*device_allocator=*/nullptr)); + {/*device_allocator=*/nullptr})); std::vector executable_ptrs; executable_ptrs.reserve(executables.size()); for (const auto& executable : executables) { @@ -810,7 +808,7 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, StatusOr> Service::BuildExecutable( const HloModuleProto& module_proto, std::unique_ptr module_config, Backend* backend, - se::StreamExecutor* executor, se::DeviceMemoryAllocator* device_allocator, + se::StreamExecutor* executor, const Compiler::CompileOptions& options, bool run_backend_only) { VLOG(1) << StrFormat( "BuildExecutable on service %p with serialized module proto: %s", this, @@ -822,14 +820,13 @@ StatusOr> Service::BuildExecutable( DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName); if (!run_backend_only) { - 
TF_ASSIGN_OR_RETURN( - module, backend->compiler()->RunHloPasses(std::move(module), executor, - device_allocator)); + TF_ASSIGN_OR_RETURN(module, backend->compiler()->RunHloPasses( + std::move(module), executor, options)); } - TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - backend->compiler()->RunBackend( - std::move(module), executor, device_allocator)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr executable, + backend->compiler()->RunBackend(std::move(module), executor, options)); const auto& debug_opts = module_config->debug_options(); if (DumpingEnabledForHloModule(module_proto.name(), debug_opts) && @@ -875,7 +872,7 @@ Status Service::Compile(const CompileRequest* arg, CompileResponse* result) { BuildExecutable(arg->computation(), std::move(module_config), execute_backend_.get(), execute_backend_->default_stream_executor(), - /*device_allocator=*/nullptr)); + {/*device_allocator=*/nullptr})); *result->mutable_handle() = compilation_cache_.Insert(std::move(executable)); diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 712ccc44d91..02288bba475 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -235,8 +235,7 @@ class Service : public ServiceInterface { StatusOr> BuildExecutable( const HloModuleProto& module_proto, std::unique_ptr module_config, Backend* backend, - se::StreamExecutor* executor, - se::DeviceMemoryAllocator* device_allocator = nullptr, + se::StreamExecutor* executor, const Compiler::CompileOptions& options, bool run_backend_only = false); // Same as BuildExecutable() above, but builds a list of Executables for the @@ -245,8 +244,7 @@ class Service : public ServiceInterface { const std::vector& module_protos, std::vector> module_configs, Backend* backend, std::vector> executors, - se::DeviceMemoryAllocator* device_allocator, - bool run_backend_only = false); + const Compiler::CompileOptions& options, bool run_backend_only = false); // Runs the given executable with the given arguments and register the result // in the allocation tracker. The handle of the result from the tracker is diff --git a/tensorflow/compiler/xla/service/sharding_propagation.cc b/tensorflow/compiler/xla/service/sharding_propagation.cc index 4dff4dc593e..6e5d77d067d 100644 --- a/tensorflow/compiler/xla/service/sharding_propagation.cc +++ b/tensorflow/compiler/xla/service/sharding_propagation.cc @@ -102,27 +102,8 @@ bool IsShardingMoreSpecific(const HloSharding& lhs, const HloSharding& rhs) { } } -// Returns a sharding where each tuple element is chosen as the more specific -// one of the corresponding elements in a and b. Requires a an b to have the -// same tuple nesting. -HloSharding MergeForMoreSpecificSharding(const HloSharding& a, - const HloSharding& b) { - if (a.IsTuple()) { - HloSharding result = a; - CHECK(b.IsTuple()); - CHECK_EQ(a.tuple_elements().size(), b.tuple_elements().size()); - for (int64 i = 0; i < result.tuple_elements().size(); ++i) { - result.tuple_elements()[i] = MergeForMoreSpecificSharding( - a.tuple_elements()[i], b.tuple_elements()[i]); - } - return result; - } - return IsShardingMoreSpecific(a, b) ? a : b; -} - // Tries to refine `to_merge` by combining with `old`. Returns if the final -// `to_merge` is more specific than `old`. May combine partial sharding in -// addition to MergeForMoreSpecificSharding(). +// `to_merge` is more specific than `old`. 
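// For intuition on the merge: with may_combine_partial_sharding, two partial
// shardings can merge into one that is more specific than either input. The
// While test added later in this change combines
//   {devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}   (tiled along dim 0)
// with
//   {devices=[1,2,2]0,2,1,3 last_tile_dim_replicate}   (tiled along dim 1)
// into the full tiling {devices=[2,2]0,1,2,3}. Without partial combining, the
// result is simply the more specific of the two inputs, which is all the
// removed MergeForMoreSpecificSharding did (element-wise for tuples).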
bool MergeSharding(const HloSharding& old, HloSharding* to_merge, bool may_combine_partial_sharding) { if (old.IsTuple()) { @@ -1093,8 +1074,8 @@ bool InferShardingFromOperands(HloInstruction* instruction, } auto sharding = instruction->operand(0)->sharding(); if (instruction->has_sharding()) { - sharding = - MergeForMoreSpecificSharding(sharding, instruction->sharding()); + MergeSharding(instruction->sharding(), &sharding, + may_combine_partial_sharding); } return MaybeImproveInstructionSharding(std::move(sharding), instruction, may_combine_partial_sharding); @@ -1320,6 +1301,12 @@ absl::optional GetShardingFromUser( return hlo_sharding_util::ReshapeSharding( user.shape(), instruction.shape(), user.sharding()); } + case HloOpcode::kPad: { + if (&instruction != user.operand(0)) { + return absl::nullopt; + } + return user.sharding(); + } case HloOpcode::kSlice: { return user.sharding(); } @@ -1673,8 +1660,10 @@ StatusOr ShardingPropagation::Run(HloModule* module) { // If instruction is a while, or the root or a parameter of a while body, // then propagate its sharding to the while instruction, to its body root, // and to its condition parameter. - std::function maybe_computation_propagation = - [&](HloInstruction* instruction) { + std::function*)> + maybe_computation_propagation = [&](HloInstruction* instruction, + absl::flat_hash_set* + changed) { auto propagate_to_instruction = [&](HloInstruction* search_inst) { auto related_instructions = get_related_instructions(search_inst); if (absl::c_count(related_instructions, instruction)) { @@ -1683,7 +1672,8 @@ StatusOr ShardingPropagation::Run(HloModule* module) { inst->sharding() != instruction->sharding()) { VLOG(2) << "Add computation sharding: " << inst->name(); inst->set_sharding(instruction->sharding()); - maybe_computation_propagation(inst); + changed->insert(inst); + maybe_computation_propagation(inst, changed); } } } @@ -1785,6 +1775,14 @@ StatusOr ShardingPropagation::Run(HloModule* module) { for (const HloInstruction* instruction : instructions) { already_sharded_counter += (instruction->has_sharding() ? 1 : 0); } + auto clear_cache = [&](HloInstruction* hlo) { + for (auto operand : hlo->operands()) { + already_inferred_from_users.erase(operand); + } + for (auto user : hlo->users()) { + already_inferred_from_operands.erase(user); + } + }; // First iterate the HLO graph in post order taking shardings from // operands. 
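// A note on the cache bookkeeping above, for a chain A -> B -> C: when B's
// sharding improves, clear_cache(B) erases A from already_inferred_from_users
// (the backward pass will revisit A against B's new sharding) and erases C
// from already_inferred_from_operands (the forward pass will revisit C).
// Instructions updated inside maybe_computation_propagation, e.g. a while body
// root or condition parameter, are reported through the changed set and get
// the same treatment, so the fixed-point iteration below revisits them too.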
for (HloInstruction* instruction : instructions) { @@ -1799,12 +1797,11 @@ StatusOr ShardingPropagation::Run(HloModule* module) { any_changed = true; VLOG(2) << "Add sharding (forward-pass): " << instruction->ToString(); - maybe_computation_propagation(instruction); - for (auto operand : instruction->operands()) { - already_inferred_from_users.erase(operand); - } - for (auto user : instruction->users()) { - already_inferred_from_operands.erase(user); + absl::flat_hash_set changed_in_comp_prop; + maybe_computation_propagation(instruction, &changed_in_comp_prop); + clear_cache(instruction); + for (auto hlo : changed_in_comp_prop) { + clear_cache(hlo); } changed_last_iter = true; } @@ -1823,12 +1820,11 @@ StatusOr ShardingPropagation::Run(HloModule* module) { ++inferred_from_user_counter; any_changed = true; VLOG(2) << "Add sharding (backward-pass): " << (*it)->ToString(); - maybe_computation_propagation(*it); - for (auto operand : (*it)->operands()) { - already_inferred_from_users.erase(operand); - } - for (auto user : (*it)->users()) { - already_inferred_from_operands.erase(user); + absl::flat_hash_set changed_in_comp_prop; + maybe_computation_propagation(*it, &changed_in_comp_prop); + clear_cache(*it); + for (auto hlo : changed_in_comp_prop) { + clear_cache(hlo); } changed_last_iter = true; } diff --git a/tensorflow/compiler/xla/service/sharding_propagation_test.cc b/tensorflow/compiler/xla/service/sharding_propagation_test.cc index ec83f99db32..1645e01eb33 100644 --- a/tensorflow/compiler/xla/service/sharding_propagation_test.cc +++ b/tensorflow/compiler/xla/service/sharding_propagation_test.cc @@ -514,6 +514,26 @@ ENTRY %pad { op::Sharding("{devices=[2,2]0,1,2,3}")); } +TEST_F(ShardingPropagationTest, PadBackwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %pad { + %input = f32[11,17]{1,0} parameter(0) + %copy = f32[11,17]{1,0} copy(%input) + %pad_value = f32[] parameter(1) + %pad = f32[27,51]{1,0} pad(%copy, %pad_value), padding=2_4_1x1_1_2, + sharding={devices=[2,2]0,1,2,3} + ROOT %result = f32[27,51]{1,0} copy(%pad) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "copy"), + op::Sharding("{devices=[2,2]0,1,2,3}")); +} + TEST_F(ShardingPropagationTest, PartialReplicatedPadForwardPass) { const char* const hlo_string = R"( HloModule module @@ -856,40 +876,41 @@ TEST_F(ShardingPropagationTest, While) { HloModule module %cond { - %vars.cond = (u32[], f32[10]{0}) parameter(0) - %count.cond = u32[] get-tuple-element((u32[], f32[10]{0}) %vars.cond), index=0 + %vars.cond = (u32[], f32[10,10]) parameter(0) + %count.cond = u32[] get-tuple-element((u32[], f32[10,10]) %vars.cond), index=0 %limit = u32[] constant(10) ROOT %lt = pred[] compare(u32[] %count.cond, u32[] %limit), direction=LT } %body { - %vars = (u32[], f32[10]{0}) parameter(0) + %vars = (u32[], f32[10,10]) parameter(0) %count = u32[] get-tuple-element(%vars), index=0 - %acc = f32[10]{0} get-tuple-element((u32[], f32[10]{0}) %vars), index=1 + %acc = f32[10,10] get-tuple-element((u32[], f32[10,10]) %vars), index=1 %one = u32[] constant(1) %count.1 = u32[] add(u32[] %count, u32[] %one), sharding={replicated} - %acc.1 = f32[10]{0} add(f32[10]{0} %acc, f32[10]{0} %acc) - ROOT %tuple = (u32[], f32[10]{0}) tuple(u32[] %count.1, f32[10]{0} %acc.1) + %acc.1 = f32[10,10] add(f32[10,10] %acc, f32[10,10] %acc) + ROOT %tuple = 
(u32[], f32[10,10]) tuple(u32[] %count.1, f32[10,10] %acc.1) } ENTRY %entry { - %p0 = f32[10]{0} parameter(0) - %p0.copy = f32[10]{0} copy(f32[10]{0} %p0) - %p1 = f32[10]{0} parameter(1) + %p0 = f32[10,10] parameter(0) + %p0.copy = f32[10,10] copy(f32[10,10] %p0) + %p1 = f32[10,10] parameter(1) %zero = u32[] constant(0) - %init = (u32[], f32[10]{0}) tuple(u32[] %zero, f32[10]{0} %p0.copy) - %while = (u32[], f32[10]{0}) while((u32[], f32[10]{0}) %init), + %init = (u32[], f32[10,10]) tuple(u32[] %zero, f32[10,10] %p0.copy) + %while = (u32[], f32[10,10]) while((u32[], f32[10,10]) %init), body=%body, condition=%cond - %res = f32[10]{0} get-tuple-element((u32[], f32[10]{0}) %while), index=1 - %prev = f32[10]{0} get-tuple-element((u32[], f32[10]{0}) %init), index=1 - %res.1 = f32[10]{0} multiply(f32[10]{0} %res, %prev) - ROOT %res_tuple = (f32[10]{0}) tuple(f32[10]{0} %res.1) + %res = f32[10,10] get-tuple-element((u32[], f32[10,10]) %while), index=1 + %prev = f32[10,10] get-tuple-element((u32[], f32[10,10]) %init), index=1 + %res.1 = f32[10,10] multiply(f32[10,10] %res, %prev) + ROOT %res_tuple = (f32[10,10]) tuple(f32[10,10] %res.1) })"; auto while_is_sharded = [this](HloModule* module, const HloSharding& sharding) { - TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardingPropagation().Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation(/*is_spmd=*/true).Run(module)); EXPECT_TRUE(changed); auto while_instr = FindInstruction(module, "while"); EXPECT_NE(nullptr, while_instr); @@ -911,7 +932,7 @@ ENTRY %entry { auto body_root = FindInstruction(module.get(), "tuple"); EXPECT_NE(nullptr, body_root); auto sharding = - ParseSharding("{{replicated}, {devices=[2]0,1}}").ConsumeValueOrDie(); + ParseSharding("{{replicated}, {devices=[2,1]0,1}}").ConsumeValueOrDie(); body_root->set_sharding(sharding); while_is_sharded(module.get(), sharding); } @@ -921,11 +942,30 @@ ENTRY %entry { ParseAndReturnVerifiedModule(hlo_string)); auto acc_1 = FindInstruction(module.get(), "acc.1"); EXPECT_NE(nullptr, acc_1); - acc_1->set_sharding(ParseSharding("{devices=[2]0,1}").ConsumeValueOrDie()); + acc_1->set_sharding( + ParseSharding("{devices=[2,1]0,1}").ConsumeValueOrDie()); - while_is_sharded( - module.get(), - ParseSharding("{{replicated}, {devices=[2]0,1}}").ConsumeValueOrDie()); + while_is_sharded(module.get(), + ParseSharding("{{replicated}, {devices=[2,1]0,1}}") + .ConsumeValueOrDie()); + } + { + // Merge partial sharding from operand and body. 
+ TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + auto acc_1 = FindInstruction(module.get(), "acc.1"); + EXPECT_NE(nullptr, acc_1); + acc_1->set_sharding( + ParseSharding("{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}") + .ConsumeValueOrDie()); + auto p0 = FindInstruction(module.get(), "p0"); + p0->set_sharding( + ParseSharding("{devices=[1,2,2]0,2,1,3 last_tile_dim_replicate}") + .ConsumeValueOrDie()); + + while_is_sharded(module.get(), + ParseSharding("{{replicated}, {devices=[2,2]0,1,2,3}}") + .ConsumeValueOrDie()); } } diff --git a/tensorflow/compiler/xla/service/space_to_batch_converter.cc b/tensorflow/compiler/xla/service/space_to_batch_converter.cc index e33888cc29c..afef62dd5c5 100644 --- a/tensorflow/compiler/xla/service/space_to_batch_converter.cc +++ b/tensorflow/compiler/xla/service/space_to_batch_converter.cc @@ -182,6 +182,10 @@ class ConvolutionVisitor { return permute_dims[id]; } + int64 ReverseDimLookUp(absl::Span permute_dims, int64 id) { + return std::distance(permute_dims.begin(), absl::c_find(permute_dims, id)); + } + HloInstruction* DoesConvolutionFeedReduceWindowOrSelectAndScatter( HloInstruction* instr, int64 depth); @@ -215,9 +219,10 @@ class ConvolutionVisitor { // Limit on batch size to apply this technique on. int64 limit_on_batch_size_; - // We choose the new batch size to be a constant so that space-to-batch - // propagation through several convolutional layers is consistent. - static constexpr int64 kNewBatchSize = 8; + // We choose the new batch size to be kNumSplits times that of the old batch + // so that space-to-batch propagation through several convolutional layers is + // consistent. + static constexpr int64 kNumSplits = 8; // Depth for searching reduce window static constexpr int64 kReduceWindowSearchDepth = 10; @@ -301,17 +306,12 @@ bool ConvolutionVisitor::IsConvSuitableForSpaceToBatch( if (old_batch_size > limit_on_batch_size_) { return false; } - // We currently only cater to evenly divisible cases. - if (kNewBatchSize % old_batch_size != 0) { - return false; - } VLOG(1) << "spatial size " << c.spatial_size; - const int64 num_splits = kNewBatchSize / old_batch_size; // If the ratio is not within the 2X range, we can't Halo Pad from the next // split. - if (c.halo_size > CeilOfRatio(c.spatial_size, num_splits)) { + if (c.halo_size > CeilOfRatio(c.spatial_size, kNumSplits)) { return false; } VLOG(1) << "Legal space-to-batch convolution " << convolution->ToString(); @@ -323,6 +323,24 @@ StatusOr ConvolutionVisitor::HaloDuplicateWithSlice( int64 activations_batch_dim, int64 old_batch_size, int64 low_padding, int64 high_padding, int64 halo_size, int64 original_split_dim_size, HloInstruction* pad_val) { + const int64 original_batch_size = + activations->shape().dimensions(activations_batch_dim) / kNumSplits; + + if (original_batch_size > 1) { + std::vector new_dimensions(activations->shape().dimensions().begin(), + activations->shape().dimensions().end()); + new_dimensions[activations_batch_dim] = kNumSplits; + new_dimensions.insert(new_dimensions.begin() + activations_batch_dim, + original_batch_size); + + // Reshape the output of the new conv into the old convolutions shape. 
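// Illustrative numbers for the reshape being set up here: with kNumSplits = 8
// and a batch dimension of size 16, original_batch_size is 2, so the
// activations are viewed as [..., 2, 8, <space>, ...]. activations_batch_dim
// and spatial_dimension_to_split each shift right by one, the halo slice and
// concat below then act within each original batch element, and the matching
// reshape at the end of HaloDuplicateWithSlice flattens the batch back to 16.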
+ TF_ASSIGN_OR_RETURN(activations, + MakeReshapeHlo(new_dimensions, activations)); + + spatial_dimension_to_split++; + activations_batch_dim++; + } + const int64 rank = activations->shape().rank(); const int64 spatial_split_size = activations->shape().dimensions(spatial_dimension_to_split); @@ -415,6 +433,21 @@ StatusOr ConvolutionVisitor::HaloDuplicateWithSlice( TF_ASSIGN_OR_RETURN(activations, MakeConcatHlo({activations, halo_region}, spatial_dimension_to_split)); } + + if (original_batch_size > 1) { + std::vector new_dimensions(activations->shape().dimensions().begin(), + activations->shape().dimensions().end()); + new_dimensions[activations_batch_dim] = original_batch_size * kNumSplits; + new_dimensions.erase(new_dimensions.begin() + activations_batch_dim - 1); + + // Reshape the output of the new conv into the old convolutions shape. + TF_ASSIGN_OR_RETURN(activations, + MakeReshapeHlo(new_dimensions, activations)); + + spatial_dimension_to_split++; + activations_batch_dim++; + } + VLOG(1) << "HaloDuplicated activations " << activations->ToString(); return activations; } @@ -424,17 +457,20 @@ ConvolutionVisitor::BringSpaceNextToBatch( HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers, int64& spatial_dimension_to_split, int64& activations_batch_dim, bool is_backprop) { - std::vector transpose_dims; - ConvolutionDimensionNumbers new_dim_numbers = dim_numbers; - if (spatial_dimension_to_split != activations_batch_dim + 1) { + std::vector transpose_dims(activations->shape().rank()); + if (spatial_dimension_to_split == activations_batch_dim + 1) { + absl::c_iota(transpose_dims, 0); + } else { + ConvolutionDimensionNumbers new_dim_numbers = dim_numbers; int64 pushed_counter = 0; int64 new_batch_dim, new_spatial_dim; + int64 dim_counter = 0; for (int i = 0; i < activations->shape().rank(); ++i) { if (i == activations_batch_dim) { continue; } if (i == spatial_dimension_to_split) { - transpose_dims.push_back(activations_batch_dim); + transpose_dims[dim_counter++] = activations_batch_dim; new_batch_dim = pushed_counter; pushed_counter++; new_spatial_dim = pushed_counter; @@ -452,7 +488,7 @@ ConvolutionVisitor::BringSpaceNextToBatch( } } } - transpose_dims.push_back(i); + transpose_dims[dim_counter++] = i; pushed_counter++; } @@ -460,14 +496,14 @@ ConvolutionVisitor::BringSpaceNextToBatch( spatial_dimension_to_split = new_spatial_dim; TF_ASSIGN_OR_RETURN(activations, MakeTransposeHlo(activations, transpose_dims)); - } - if (is_backprop) { - new_dim_numbers.set_input_feature_dimension(activations_batch_dim); - } else { - new_dim_numbers.set_input_batch_dimension(activations_batch_dim); + if (is_backprop) { + new_dim_numbers.set_input_feature_dimension(activations_batch_dim); + } else { + new_dim_numbers.set_input_batch_dimension(activations_batch_dim); + } + dim_numbers = new_dim_numbers; } - dim_numbers = new_dim_numbers; return SpaceNextToBatchDetails{activations, transpose_dims}; } @@ -586,12 +622,23 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer, VLOG(1) << "Checking if conv is supported for propagation " << consumer->ToString(); if (IsConvSuitableForSpaceToBatch(consumer)) { - for (int64 i = 0; i < consumer->operand_count(); ++i) { - auto old_producer = consumer->mutable_operand(i); - if (i == 0 && !old_to_new_instrs_.contains(old_producer)) { - return false; - } + if (!old_to_new_instrs_.contains(consumer->mutable_operand(0))) { + return false; } + auto dim_map_val_op_0 = instr_to_dim_map_[consumer->mutable_operand(0)]; + // Make sure that the space 
dimension is the same across the producer + // and consumer. + if (consumer->convolution_dimension_numbers().input_spatial_dimensions( + get_chosen_spatial_dim(consumer)) != dim_map_val_op_0.second) { + return false; + } + // Make sure that the batch dimension is the same across the producer + // and consumer. + if (consumer->convolution_dimension_numbers().input_batch_dimension() != + dim_map_val_op_0.first) { + return false; + } + return true; } @@ -611,13 +658,35 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer, VLOG(2) << "Checking for backprop filter conv operands " << consumer->operand_count(); - if (!old_to_new_instrs_.contains(consumer->mutable_operand(1))) { + auto activations = consumer->mutable_operand(0); + auto kernel = consumer->mutable_operand(1); + + if (!old_to_new_instrs_.contains(kernel)) { VLOG(2) << "Backprop filter conv not ready for propagation because of " "kernel is not space-to-batched"; return false; } - if (!old_to_new_instrs_.contains(consumer->mutable_operand(0))) { + if (!old_to_new_instrs_.contains(activations)) { + const int64 lhs_batch = activations->shape().dimensions( + consumer->convolution_dimension_numbers().input_feature_dimension()); + auto dim_map_val_op_1 = instr_to_dim_map_[consumer->mutable_operand(1)]; + const int64 old_batch_dim = dim_map_val_op_1.first; + auto second_operand = old_to_new_instrs_[kernel]; + auto permute_dims_second_operand = + instr_to_dim_permute_map_[second_operand]; + const int64 new_batch_dim = + DimLookUp(permute_dims_second_operand, old_batch_dim); + const int64 rhs_batch = second_operand->shape().dimensions(new_batch_dim); + + // Because we want to convert activations into a space-to-batched version + // only for backprop filter convolutions, we want to make sure that the + // batch dimensions (feature dimensions, technically) are same sized. + // Since RHS is already space-to-batched, we need to account for it too. + if (rhs_batch != kNumSplits * lhs_batch) { + return false; + } + // If activations have not been propagated through, we can do // space-to-batch on them provided kernel has been propagated. 
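// For instance (referring to the rhs_batch check above): with kNumSplits = 8
// and activations whose feature dimension (the effective batch for a
// backprop-filter convolution) has size 4, the already space-to-batched kernel
// must have a batch of 32 at its mapped batch dimension; any other size means
// the operands were split incompatibly, so propagation returns false and waits.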
VLOG(2) << "Backprop filter conv ready for propagation: kernel ready, " @@ -625,10 +694,10 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer, return true; } - auto first_operand = old_to_new_instrs_[consumer->mutable_operand(0)]; - auto dim_map_val_op_0 = instr_to_dim_map_[consumer->mutable_operand(0)]; - auto second_operand = old_to_new_instrs_[consumer->mutable_operand(1)]; - auto dim_map_val_op_1 = instr_to_dim_map_[consumer->mutable_operand(1)]; + auto first_operand = old_to_new_instrs_[activations]; + auto dim_map_val_op_0 = instr_to_dim_map_[activations]; + auto second_operand = old_to_new_instrs_[kernel]; + auto dim_map_val_op_1 = instr_to_dim_map_[kernel]; auto permute_dims_first_operand = instr_to_dim_permute_map_[first_operand]; auto permute_dims_second_operand = @@ -1119,7 +1188,7 @@ StatusOr ConvolutionVisitor::Propagate(HloInstruction* consumer, Window new_win; for (int64 i = 0; i < consumer->window().dimensions().size(); ++i) { - auto dim = DimLookUp(permute_dims, i); + auto dim = ReverseDimLookUp(permute_dims, i); new_win.add_dimensions(); new_win.mutable_dimensions(i)->set_stride( consumer->window().dimensions(dim).stride()); @@ -1339,7 +1408,9 @@ StatusOr ConvolutionVisitor::SelectValidPortion( const int64 new_space_size = new_shape.dimensions(new_space_dim); const int64 old_batch_size = old_shape.dimensions(old_batch_dim); const int64 old_space_size = old_shape.dimensions(old_space_dim); - CHECK_EQ(new_batch_size % old_batch_size, 0); + CHECK_EQ(new_batch_size % old_batch_size, 0) + << " New batch size " << new_batch_size << " old batch size " + << old_batch_size; const int64 num_splits = new_batch_size / old_batch_size; // Build a constant PRED to decide which elements in the split dimension // are from halo. @@ -1394,8 +1465,10 @@ StatusOr ConvolutionVisitor::BatchToSpace( CHECK(old_to_new_instrs_.contains(old_instr)); auto new_instr = old_to_new_instrs_[old_instr]; VLOG(2) << "old_batch_dim " << old_batch_dim << " old_space_dim " - << old_space_dim << " new_instr " << new_instr->ToString() - << " permute dims " << instr_to_dim_permute_map_.count(new_instr); + << old_space_dim << " old_instr " << old_instr->ToString() + << "\n new_instr " << new_instr->ToString() << " permute dims " + << instr_to_dim_permute_map_.count(new_instr) << " old_batch_size " + << old_batch_size; CHECK(instr_to_dim_permute_map_.contains(new_instr)); auto permute_dims = instr_to_dim_permute_map_[new_instr]; const int64 batch_dim = DimLookUp(permute_dims, old_batch_dim); @@ -1565,6 +1638,7 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) { c.spatial_dimension_to_split, activations_batch_dim)); activations_new = retval.instr; std::vector trans_dims = retval.transpose_dims; + CHECK(!trans_dims.empty()); auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant( LiteralUtil::Zero(activations_new->shape().element_type()))); @@ -1578,8 +1652,7 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) { VLOG(1) << "spatial size " << c.spatial_size; - const int64 num_splits = kNewBatchSize / old_batch_size; - + const int64 num_splits = kNumSplits; const int64 output_offsets = convolution->shape().dimensions( permuted_conv_dims_numbers.output_spatial_dimensions( get_chosen_spatial_dim(convolution))); @@ -1614,6 +1687,8 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) { activations_new->shape().dimensions().end()); const int64 reshaped_space_size = new_space_size * new_batch_size / old_batch_size; + 
VLOG(3) << "Increasing the spatial size while propagating new_batch_size " + << new_batch_size << " old_batch_size " << old_batch_size; new_dimensions[c.spatial_dimension_to_split] = reshaped_space_size; new_dimensions[activations_batch_dim] = old_batch_size; @@ -1621,10 +1696,12 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) { TF_ASSIGN_OR_RETURN(HloInstruction * reshaped_activations, MakeReshapeHlo(new_dimensions, activations_new)); + VLOG(3) << "First reshape done"; PaddingConfig padding_config = MakeNoPaddingConfig(reshaped_activations->shape().dimensions_size()); padding_config.mutable_dimensions(c.spatial_dimension_to_split) - ->set_edge_padding_high(spatial_split_size * new_batch_size - + ->set_edge_padding_high(spatial_split_size * new_batch_size / + old_batch_size - reshaped_space_size); padding_config.mutable_dimensions(c.spatial_dimension_to_split) ->set_edge_padding_low(0); @@ -1647,6 +1724,8 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) { reshaped_activations, MakeReshapeHlo(reshape_back_dims, reshaped_activations)); + VLOG(3) << "Second reshape done"; + TF_ASSIGN_OR_RETURN( activations_new, HaloDuplicateWithSlice( @@ -1664,6 +1743,7 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) { // additional space available, and adjust the required slice size (and // thereby the halo size). if (spatial_split_size < new_space_size) { + VLOG(3) << "Decreasing the spatial size while propagating"; const int64 additional_space_present = spatial_split_size % c.stride; spatial_split_size = new_space_size; slice_size = @@ -1758,6 +1838,7 @@ ConvolutionVisitor::SplitSpace(HloInstruction* activations, activations = retval.instr; std::vector transpose_dims = retval.transpose_dims; + CHECK(!transpose_dims.empty()); // Because we are splitting the spatial dimension, if convolution needed // padding in the spatial dimension, we materialize it. if (high_padding || low_padding) { @@ -1774,7 +1855,9 @@ ConvolutionVisitor::SplitSpace(HloInstruction* activations, MakePadHlo(activations, padding, padding_config)); } VLOG(1) << "Initial padded activations shape " - << activations->shape().ToString(); + << activations->shape().ToString() << " old_batch_size " + << old_batch_size << " activations_batch_dim " + << activations_batch_dim; // Now we reorganize the activations. E.g. if the shape [B, SPACE] was [1, 16] // and 4 splits were needed, we first create [4, 4]. Next, to deal with halo @@ -1829,7 +1912,10 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv( CHECK(old_to_new_instrs_.contains(kernel_old)); auto kernel_new = old_to_new_instrs_[kernel_old]; + auto permute_dims_kernel = instr_to_dim_permute_map_[kernel_new]; + HloInstruction* activations_new = nullptr; + bool activations_locally_space_to_batched = false; // If activations were no space-to-batched, we space-to-batch them below. 
if (!old_to_new_instrs_.contains(activations_old)) { VLOG(1) << "Space-to-batching activations to enable space-to-depth"; @@ -1838,28 +1924,34 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv( instr_to_dim_map_[activations_old] = std::make_pair(prev_feature_dim, prev_batch_dim); - int64 activations_batch_dim = original_conv_dims.input_feature_dimension(); - const int64 old_batch_size = - activations_old->shape().dimensions(activations_batch_dim); - const int64 num_splits = kNewBatchSize / old_batch_size; + const int64 new_kernel_space_dim = + DimLookUp(permute_dims_kernel, kernel_space_dim); + const int64 new_kernel_split_dim_size = - kernel_new->shape().dimensions(kernel_space_dim); + kernel_new->shape().dimensions(new_kernel_space_dim); const int64 needed_spatial_size = rhs_dilation * new_kernel_split_dim_size; const int64 pad_size = - needed_spatial_size * num_splits - old_split_dim_size; + needed_spatial_size * kNumSplits - old_split_dim_size; ConvolutionDimensionNumbers tmp_dim_numbers; tmp_dim_numbers = original_conv_dims; TF_ASSIGN_OR_RETURN( auto retval, SplitSpace(activations_old, tmp_dim_numbers, old_space_dim, - activations_batch_dim, + old_batch_dim, /*high_padding=*/pad_size, /*low_padding=*/0, - needed_spatial_size, num_splits, /*is_backprop=*/true)); + needed_spatial_size, kNumSplits, /*is_backprop=*/true)); old_to_new_instrs_[activations_old] = retval.first; - instr_to_dim_permute_map_[retval.first] = retval.second; - VLOG(3) << "Edited conv dims " << original_conv_dims.DebugString(); + std::vector reversed_transpose_dims(retval.second.size()); + for (int64 i = 0; i < retval.second.size(); ++i) { + reversed_transpose_dims[i] = ReverseDimLookUp(retval.second, i); + } + instr_to_dim_permute_map_[retval.first] = reversed_transpose_dims; + + VLOG(3) << "New Activations " << retval.first->ToString(); + + activations_locally_space_to_batched = true; } CHECK(old_to_new_instrs_.contains(activations_old)); @@ -1884,7 +1976,7 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv( i, DimLookUp(permute_dims, original_conv_dims.input_spatial_dimensions(i))); permuted_conv_dims_numbers.set_kernel_spatial_dimensions( - i, DimLookUp(permute_dims, + i, DimLookUp(permute_dims_kernel, original_conv_dims.kernel_spatial_dimensions(i))); } @@ -1905,10 +1997,11 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv( previous_spatial_dim_count, previous_chosen_spatial_dim_in_output); const int64 kernel_input_feature_dim = DimLookUp( - permute_dims, original_conv_dims.kernel_input_feature_dimension()); + permute_dims_kernel, original_conv_dims.kernel_input_feature_dimension()); - const int64 kernel_output_feature_dim = DimLookUp( - permute_dims, original_conv_dims.kernel_output_feature_dimension()); + const int64 kernel_output_feature_dim = + DimLookUp(permute_dims_kernel, + original_conv_dims.kernel_output_feature_dimension()); permuted_conv_dims_numbers.set_kernel_input_feature_dimension( kernel_input_feature_dim); @@ -1931,7 +2024,8 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv( VLOG(1) << "Propagating on conv activations_batch_dim " << activations_batch_dim << " spatial_dimension_to_split " - << spatial_dimension_to_split << " old_batch_size " << old_batch_size; + << spatial_dimension_to_split << " old_batch_size " << old_batch_size + << " new_split_dim_size " << new_split_dim_size; TF_ASSIGN_OR_RETURN( auto retval, @@ -1939,6 +2033,7 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv( spatial_dimension_to_split, activations_batch_dim, 
/*is_backprop=*/true)); std::vector transpose_dims = retval.transpose_dims; + CHECK(!transpose_dims.empty()); activations_new = retval.instr; VLOG(1) << "Activations_new post BringSpaceNextToBatch " @@ -1949,13 +2044,15 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv( auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant( LiteralUtil::Zero(activations_new->shape().element_type()))); - // Select activations correctly by masking additional space. - TF_ASSIGN_OR_RETURN( - activations_new, - SelectValidPortion(activations_new, activations_old, select_val, - activations_batch_dim, spatial_dimension_to_split, - old_batch_dim, old_space_dim)); - + if (!activations_locally_space_to_batched) { + // Select activations correctly by masking additional space. + TF_ASSIGN_OR_RETURN( + activations_new, + SelectValidPortion(activations_new, activations_old, select_val, + activations_batch_dim, spatial_dimension_to_split, + old_batch_dim, old_space_dim)); + } + VLOG(3) << "Selecting the valid kernel area"; // Select kernel correctly by masking additional space. TF_ASSIGN_OR_RETURN( kernel_new, @@ -2238,7 +2335,6 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution( VLOG(1) << "spatial size " << c.spatial_size; - const int64 num_splits = kNewBatchSize / old_batch_size; auto original_conv = convolution; const int64 output_spatial_dim = dim_numbers.output_spatial_dimensions( @@ -2246,13 +2342,13 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution( const int64 output_offsets = convolution->shape().dimensions(output_spatial_dim); const int64 output_offsets_per_split = - CeilOfRatio(output_offsets, num_splits); + CeilOfRatio(output_offsets, kNumSplits); int64 spatial_split_size = CeilOfRatio(output_offsets_per_split, c.base_dilation_factor) * c.stride; // Keep increasing the split size so that overall size isn't smaller than the // original spatial dimension. - while (spatial_split_size * num_splits - c.spatial_size < 0) { + while (spatial_split_size * kNumSplits - c.spatial_size < 0) { spatial_split_size += c.stride; } @@ -2276,12 +2372,12 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution( const int64 slice_size = spatial_split_size + c.halo_size; // Pad spatial dim. - const int64 pad_size = spatial_split_size * num_splits - c.spatial_size; + const int64 pad_size = spatial_split_size * kNumSplits - c.spatial_size; VLOG(1) << "spatial_split_size " << spatial_split_size << " stride " << c.stride << " slice_size " << slice_size; VLOG(1) << "spatial_dimension_to_split " << c.spatial_dimension_to_split - << " num_splits " << num_splits << " kernel_spatial_dim_size " + << " num_splits " << kNumSplits << " kernel_spatial_dim_size " << c.kernel_spatial_dim_size; int64 spatial_dimension_to_split = c.spatial_dimension_to_split; TF_ASSIGN_OR_RETURN( @@ -2292,7 +2388,7 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution( /*low_padding=*/c.base_dilation_factor == 1 ? 
c.inherent_low_padding : 0, - spatial_split_size, num_splits)); + spatial_split_size, kNumSplits)); HloInstruction* batch_increased_reshape = retval.first; convolution->SetupDerivedInstruction(batch_increased_reshape); diff --git a/tensorflow/compiler/xla/tests/dynamism_inference_test.cc b/tensorflow/compiler/xla/tests/dynamism_inference_test.cc index 96ba73ac9f0..1763ed6090e 100644 --- a/tensorflow/compiler/xla/tests/dynamism_inference_test.cc +++ b/tensorflow/compiler/xla/tests/dynamism_inference_test.cc @@ -317,5 +317,27 @@ TEST_F(DynamismInferenceTest, GatherWithSharedConstantParent) { } } +TEST_F(DynamismInferenceTest, InferThroughPad) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + XlaBuilder b(TestName()); + // Test the analysis on a gather. + auto operand1 = ConstantR1(&b, {1, 2}); + auto parameter = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {}), "p0"); + PaddingConfig padding_config; + padding_config.add_dimensions()->set_edge_padding_high(1); + // After pad the value is [constant, constant, parameter]. + auto pad = Pad(operand1, parameter, padding_config); + ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message(); + // Everything is constant, result is also contant. + EXPECT_FALSE( + ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get({0})); + EXPECT_FALSE( + ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get({1})); + EXPECT_TRUE( + ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get({2})); + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc index 0ddb01fc6ab..49e7560d2a8 100644 --- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc +++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc @@ -57,7 +57,8 @@ class GpuDummyCompiler : public GpuCompiler { StatusOr>> CompileTargetBinary( const HloModule* hlo_module, llvm::Module* llvm_module, - GpuVersion gpu_version, se::StreamExecutor* stream_exec) { + GpuVersion gpu_version, se::StreamExecutor* stream_exec, + bool relocatable) { if (user_post_optimization_hook_) { user_post_optimization_hook_(*llvm_module); } diff --git a/tensorflow/core/api_def/base_api/api_def_InplaceAdd.pbtxt b/tensorflow/core/api_def/base_api/api_def_InplaceAdd.pbtxt index 3654286cc35..d5fde0537d8 100644 --- a/tensorflow/core/api_def/base_api/api_def_InplaceAdd.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_InplaceAdd.pbtxt @@ -20,9 +20,8 @@ op { "A `Tensor` of type T. An alias of `x`. The content " "of `y` is undefined if there are duplicates in `i`." } - summary: <RefCountIsOne(); } if (is_first_ref) { - TF_RETURN_IF_ERROR( - func_lib_def_.AddFunctionDef(fdef, graph_with_debug_info)); + TF_RETURN_IF_ERROR(func_lib_def_.AddFunctionDef(fdef, stack_traces)); TF_RETURN_IF_ERROR(func_lib_def_.AddLibrary(library)); if (!add_to_local_only) { return MaybeRegisterFunctionRemotely(fdef); diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index 9e0d9fcfe5d..886ed498c07 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -234,8 +234,8 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { // entry to the KernelAndDevice cache for it if it's not exist. 
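// The eager-context change below replaces the old `const Graph*
// graph_with_debug_info` plumbing with a StackTracesMap mapping node names to
// AbstractStackTrace objects. A sketch of how a caller might register a
// function together with its traces; the CapturedStackTrace class and the
// "two" node name are illustrative only:

#include <memory>
#include <string>

#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/function.h"

class CapturedStackTrace : public tensorflow::AbstractStackTrace {
 public:
  absl::Span<tensorflow::StackFrame const> ToFrames() const override {
    return {};
  }
  std::string ToString(const TracePrintingOptions& opts) const override {
    return "captured at trace time";
  }
};

tensorflow::Status RegisterWithTraces(tensorflow::EagerContext* ctx,
                                      const tensorflow::FunctionDef& fdef) {
  tensorflow::StackTracesMap traces;
  // Keyed by node name inside the function body.
  traces["two"] = std::make_shared<CapturedStackTrace>();
  return ctx->AddFunctionDefWithStackTraces(fdef, traces);
}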
Status AddFunctionDef(const FunctionDef& fdef) override; - Status AddFunctionDefWithDebugInfo( - const FunctionDef& fdef, const Graph* graph_with_debug_info) override; + Status AddFunctionDefWithStackTraces( + const FunctionDef& fdef, const StackTracesMap& stack_traces) override; // `library` contains all FunctionDefs and GradientDefs to expand `fdef`. Add // it to the local FunctionLibraryDefinition as well, but no need to add it @@ -244,7 +244,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { Status AddFunctionDef(const FunctionDef& fdef, const FunctionDefLibrary& library, const bool add_to_local_only = false, - const Graph* graph_with_debug_info = nullptr); + const StackTracesMap& stack_traces = {}); const FunctionDef* GetFunctionDef(const string& function_name); diff --git a/tensorflow/core/common_runtime/function_def_utils.cc b/tensorflow/core/common_runtime/function_def_utils.cc index d5ada59a8f0..5a0e60cf8f0 100644 --- a/tensorflow/core/common_runtime/function_def_utils.cc +++ b/tensorflow/core/common_runtime/function_def_utils.cc @@ -35,14 +35,22 @@ Status FunctionDefToBodyHelper( InstantiationResult result; TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, get_func_sig, &result)); - std::unique_ptr graph(new Graph(lib_def)); + auto graph = absl::make_unique(lib_def); graph->SetConstructionContext(ConstructionContext::kFunctionDef); - GraphConstructorOptions opts; opts.allow_internal_ops = true; opts.expect_device_spec = false; TF_RETURN_IF_ERROR(ConvertNodeDefsToGraph(opts, result.nodes, graph.get())); + const StackTracesMap& stack_traces = + lib_def->GetStackTraces(fdef.signature().name()); + for (Node* n : graph->nodes()) { + auto it = stack_traces.find(n->name()); + if (n && it != stack_traces.end()) { + n->SetStackTrace(it->second); + } + } + // Call BuildControlFlowInfo to validate that this function body has // well-formed control flow. 
std::vector dummy; diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index c52b4c50893..ff326e172b2 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -393,6 +393,35 @@ TEST_F(FunctionLibraryRuntimeTest, XTimesTwo) { test::ExpectTensorEqual(y, test::AsTensor({2, 4, 6, 8})); } +TEST_F(FunctionLibraryRuntimeTest, InstantiationStackTraceCopying) { + class DummyStackTrace : public AbstractStackTrace { + absl::Span ToFrames() const override { return {}; } + + std::string ToString(const TracePrintingOptions& opts) const override { + return "DummyStackTrace"; + } + }; + + FunctionDef func = test::function::XTimesTwo(); + Init({}); + + StackTracesMap stack_traces; + stack_traces["two"] = std::make_shared(); + + TF_CHECK_OK(lib_def_->AddFunctionDef(func, stack_traces)); + + FunctionLibraryRuntime::Handle handle; + TF_CHECK_OK(Instantiate(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, {}, &handle)); + + const FunctionBody* func_body = flr0_->GetFunctionBody(handle); + for (const Node* node : func_body->graph->nodes()) { + if (node->name() == "two") { + EXPECT_EQ(node->GetStackTrace()->ToString({}), "DummyStackTrace"); + } + } + TF_CHECK_OK(flr0_->ReleaseHandle(handle)); +} + TEST_F(FunctionLibraryRuntimeTest, XTimesTwo_MultiDeviceBacked) { Init({test::function::XTimesTwo()}); auto x = test::AsTensor({1, 2, 3, 4}); diff --git a/tensorflow/core/common_runtime/graph_constructor.cc b/tensorflow/core/common_runtime/graph_constructor.cc index 92b07682d76..639739e9cac 100644 --- a/tensorflow/core/common_runtime/graph_constructor.cc +++ b/tensorflow/core/common_runtime/graph_constructor.cc @@ -44,6 +44,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/strings/scanner.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/public/version.h" @@ -1425,6 +1426,17 @@ void GraphConstructor::Undo() { Status GraphConstructor::MakeEdge(Node* src, int output_index, Node* dst, int input_index) { + if (output_index >= src->num_outputs()) { + return errors::InvalidArgument( + "Output ", output_index, " of node ", src->name(), + " does not exist. Node only has ", src->num_outputs(), " outputs."); + } + if (input_index >= dst->num_inputs()) { + return errors::InvalidArgument( + "Input ", input_index, " of node ", dst->name(), + " does not exist. Node only has ", dst->num_inputs(), " inputs."); + } + DataType src_out = src->output_type(output_index); DataType dst_in = dst->input_type(input_index); if (!TypesCompatible(dst_in, src_out)) { diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index 50d2b96b506..b37a9ca2535 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -1174,13 +1174,13 @@ Status FunctionCallFrame::SetRetval(int index, const Tensor& val) { FunctionLibraryDefinition::FunctionDefAndOpRegistration:: FunctionDefAndOpRegistration(const FunctionDef& fdef_in, - const Graph* graph_with_debug_info) + const StackTracesMap& stack_traces) : fdef(fdef_in), // Exact shape inference for functions is handled by ShapeRefiner. // Here we pass a dummy shape inference function for legacy code paths. 
op_registration_data(fdef.signature(), shape_inference::UnknownShape, true /* is_function */), - graph_with_debug_info(graph_with_debug_info) {} + stack_traces(stack_traces) {} FunctionLibraryDefinition::FunctionLibraryDefinition( const FunctionLibraryDefinition& other) @@ -1233,14 +1233,14 @@ FunctionLibraryDefinition::FindHelper(const string& func) const { } Status FunctionLibraryDefinition::AddFunctionDef( - const FunctionDef& fdef, const Graph* graph_with_debug_info) { + const FunctionDef& fdef, const StackTracesMap& stack_traces) { mutex_lock l(mu_); bool added; - return AddFunctionDefHelper(fdef, graph_with_debug_info, &added); + return AddFunctionDefHelper(fdef, stack_traces, &added); } Status FunctionLibraryDefinition::AddFunctionDefHelper( - const FunctionDef& fdef, const Graph* graph_with_debug_info, bool* added) { + const FunctionDef& fdef, const StackTracesMap& stack_traces, bool* added) { *added = false; std::shared_ptr& entry = function_defs_[fdef.signature().name()]; @@ -1260,8 +1260,7 @@ Status FunctionLibraryDefinition::AddFunctionDefHelper( "Cannot add function '", fdef.signature().name(), "' because an op with the same name already exists."); } - entry = std::make_shared(fdef, - graph_with_debug_info); + entry = std::make_shared(fdef, stack_traces); *added = true; return Status::OK(); } @@ -1403,7 +1402,7 @@ Status FunctionLibraryDefinition::AddLibrary( Status s; bool added; for (const FunctionDef& fdef : lib_def.function()) { - s = AddFunctionDefHelper(fdef, /*graph_with_debug_info=*/nullptr, &added); + s = AddFunctionDefHelper(fdef, /*stack_traces=*/{}, &added); if (!s.ok()) { Remove(funcs, funcs_with_grads); return s; @@ -1430,8 +1429,7 @@ Status FunctionLibraryDefinition::ReplaceFunction(const string& func, mutex_lock l(mu_); bool added; TF_RETURN_IF_ERROR(RemoveFunctionHelper(func)); - TF_RETURN_IF_ERROR( - AddFunctionDefHelper(fdef, /*graph_with_debug_info=*/nullptr, &added)); + TF_RETURN_IF_ERROR(AddFunctionDefHelper(fdef, /*stack_traces=*/{}, &added)); return Status::OK(); } diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 3951caa7500..544fab8a715 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -348,6 +348,10 @@ class AbstractStackTrace { virtual std::string ToString(const TracePrintingOptions& opts) const = 0; }; +using StackTracesMap = + std::unordered_map>; + // Helper to maintain a map between function names in a given // FunctionDefLibrary and function definitions. // @@ -397,7 +401,7 @@ class FunctionLibraryDefinition : public OpRegistryInterface { // Associates `graph` with a function `func_name`. Lifetime assumption: // `graph` has to outlive all instantiated graphs. Status AddFunctionDef(const FunctionDef& fdef, - const Graph* graph_with_debug_info = nullptr) + const StackTracesMap& stack_traces = {}) TF_LOCKS_EXCLUDED(mu_); // Adds gradient definition 'grad' to this function library. @@ -509,10 +513,14 @@ class FunctionLibraryDefinition : public OpRegistryInterface { // Returns graph with debug stack traces for the given function, or `nullptr` // if none found. - const Graph* GetGraphWithDebugInfo(const std::string& func_name) const { + const StackTracesMap& GetStackTraces(const std::string& func_name) const { tf_shared_lock l(mu_); std::shared_ptr entry = FindHelper(func_name); - return entry ? 
entry->graph_with_debug_info : nullptr; + if (entry) { + return entry->stack_traces; + } + static const auto* empty_map = new StackTracesMap; + return *empty_map; } private: @@ -520,12 +528,11 @@ class FunctionLibraryDefinition : public OpRegistryInterface { struct FunctionDefAndOpRegistration { explicit FunctionDefAndOpRegistration( - const FunctionDef& fdef_in, - const Graph* graph_with_debug_info = nullptr); + const FunctionDef& fdef_in, const StackTracesMap& stack_traces = {}); const FunctionDef fdef; const OpRegistrationData op_registration_data; - const Graph* graph_with_debug_info; + const StackTracesMap stack_traces; }; std::shared_ptr FindHelper( @@ -539,7 +546,7 @@ class FunctionLibraryDefinition : public OpRegistryInterface { // Same as AddFunctionDef/AddGradientDef except these methods set // `added` to true if the `fdef`/`grad` were actually added to this. Status AddFunctionDefHelper(const FunctionDef& fdef, - const Graph* graph_with_debug_info, bool* added) + const StackTracesMap& stack_traces, bool* added) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); Status AddGradientDefHelper(const GradientDef& grad, bool* added) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc index cc985284dac..7b60864a085 100644 --- a/tensorflow/core/framework/model.cc +++ b/tensorflow/core/framework/model.cc @@ -37,6 +37,88 @@ inline bool IsAutotuneNode(const std::shared_ptr node) { // Wrapper for the square function to reduce verbosity. inline double Square(double x) { return x * x; } +// Collects "essential" parallelism parameters and buffer size parameters in the +// tree rooted in the given node. Which parallelism parameters are essential is +// determined by the relative processing time spent in the corresponding +// transformation. The collected parameters are returned via maps that map node +// names to their respective parameters. +inline void CollectParameters( + std::shared_ptr node, + const absl::flat_hash_map>& parameters, + absl::flat_hash_map>* + parallelism_parameters, + absl::flat_hash_map>* + buffer_size_parameters) { + // Parallelism parameter is considered to be essential if the corresponding + // transformations's processing time is greater than essential rate times the + // average transformation self processing time. + constexpr double kEssentialRate = 0.3L; + + absl::flat_hash_map processing_times; + double processing_time = node->TotalProcessingTime(&processing_times); + double uniform_share = + processing_time / static_cast(processing_times.size()); + for (auto& pair : parameters) { + if (pair.second->name == kParallelism && + processing_times[pair.first] > kEssentialRate * uniform_share) { + parallelism_parameters->insert(pair); + } else if (pair.second->name == kBufferSize) { + buffer_size_parameters->insert(pair); + } + } +} + +// Applies the gradient descent method once and updates the parameter values. If +// the new value is out of the range, bound it within the range between the +// minimal and maximum values. +inline void UpdateParameterValues( + const absl::flat_hash_map& gradients, + absl::flat_hash_map>* parameters) { + // Gradient descent step size. 
+ constexpr double kDescentStep = 0.1L; + double new_value; + + double max_abs_derivative = 1.0; + for (auto& pair : *parameters) { + if (std::round(pair.second->value) != pair.second->max) { + auto* gradient = gtl::FindOrNull(gradients, pair.first); + if (gradient) { + max_abs_derivative = std::max(max_abs_derivative, std::abs(*gradient)); + } + } + } + for (auto& pair : *parameters) { + auto* gradient = gtl::FindOrNull(gradients, pair.first); + if (gradient) { + new_value = + pair.second->value - kDescentStep * (*gradient) / max_abs_derivative; + // Projection on a feasible interval. + if (new_value > pair.second->max) { + pair.second->value = pair.second->max; + } else if (new_value < pair.second->min) { + pair.second->value = pair.second->min; + } else { + pair.second->value = new_value; + } + } + } +} + +// Copies the parameter values (which are for optimization tuning) and updates +// the state values (which are for the input pipeline to follow). +inline void UpdateStateValues( + absl::flat_hash_map>* parameters) { + VLOG(2) << "Number of tunable parameters: " << parameters->size(); + for (auto& pair : *parameters) { + auto& parameter = pair.second; + VLOG(2) << "Setting tunable parameter " << pair.first << " to " + << parameter->value; + mutex_lock l(*parameter->state->mu); + parameter->state->value = parameter->value; + parameter->state->cond_var->notify_all(); + } +} + // The first input of InterleaveMany corresponds to the input dataset whose // elements are used to create the (derived) input datasets whose elements are // interleaved as output. @@ -1406,27 +1488,37 @@ Model::CollectTunableParameters(std::shared_ptr node) { return parameters; } -absl::flat_hash_map> -Model::CollectEssentialParallelism( - std::shared_ptr node, - const absl::flat_hash_map>& parameters) { - // Parallelism parameter is considered to be essential if the corresponding - // transformations's processing time is greater than essential rate times the - // average transformation self processing time. - constexpr double kEssentialRate = 0.3L; +bool Model::ShouldStop( + int64 cpu_budget, int64 ram_budget, + const absl::flat_hash_map>& parameters, + const absl::flat_hash_map>& + parallelism_parameters, + const absl::flat_hash_map>& + buffer_size_parameters, + std::shared_ptr snapshot, bool* cpu_budget_reached) { + if (!(*cpu_budget_reached)) { + // If those essential transformations' parallelism reaches the CPU + // budget, we will only tune the buffer size parameters in future + // iterations. + int64 model_parallelism = 0; + for (auto& pair : parallelism_parameters) { + model_parallelism += std::round(pair.second->value); + } + *cpu_budget_reached = (model_parallelism > cpu_budget); + } - absl::flat_hash_map processing_times; - double processing_time = node->TotalProcessingTime(&processing_times); - double uniform_share = - processing_time / static_cast(processing_times.size()); - absl::flat_hash_map> essential_parameters; - for (auto& pair : parameters) { - if (pair.second->name == kParallelism && - processing_times[pair.first] > kEssentialRate * uniform_share) { - essential_parameters.insert(pair); + bool all_max = true; + for (auto& pair : + (*cpu_budget_reached ? buffer_size_parameters : parameters)) { + if (std::round(pair.second->value) < pair.second->max) { + all_max = false; + break; } } - return essential_parameters; + + // If all parameters have reached their maximum values or RAM budget is + // reached, we stop the iterations. 
+ return all_max || TotalMaximumBufferedBytes(snapshot) > ram_budget; } void Model::OptimizeGradientDescent(int64 cpu_budget, int64 ram_budget, @@ -1438,12 +1530,16 @@ void Model::OptimizeGradientDescent(int64 cpu_budget, int64 ram_budget, } VLOG(2) << "Starting optimization of tunable parameters with GradientDescent"; auto parameters = CollectTunableParameters(snapshot); - auto essential_parameters = CollectEssentialParallelism(snapshot, parameters); + // The maps of "essential" parallelism parameters and buffer size parameters. + absl::flat_hash_map> + parallelism_parameters, buffer_size_parameters; + CollectParameters(snapshot, parameters, ¶llelism_parameters, + &buffer_size_parameters); + + // Initialize the parameter values to minimal before tuning. for (auto& pair : parameters) { pair.second->value = pair.second->min; } - // Gradient descent step size. - constexpr double kDescentStep = 0.1L; // Optimization is stopped once the `OutputTime` improvement is smaller than // this value. @@ -1454,53 +1550,34 @@ void Model::OptimizeGradientDescent(int64 cpu_budget, int64 ram_budget, double output_time = 0; double new_output_time; - double new_value; - for (int i = 0; i < kMaxIterations; ++i) { + + // When the CPU budget is reached, the parallelism parameter values are fixed + // and we only increase the buffer size parameters. + bool cpu_budget_reached = false; + + for (int i = 0; + i < kMaxIterations && + !ShouldStop(cpu_budget, ram_budget, parameters, parallelism_parameters, + buffer_size_parameters, snapshot, &cpu_budget_reached); + ++i) { absl::flat_hash_map gradients; new_output_time = OutputTime(snapshot, model_input_time, &gradients); - int64 model_parallelism = 0; - for (auto& pair : essential_parameters) { - model_parallelism += std::round(pair.second->value); - } // We terminate once the improvement of the output latency is too small or // the essential transformations' parallelism reaches the CPU budget or the // worst-case total buffer size exceeds the memory budget. - if (std::abs(output_time - new_output_time) < kOptimizationPrecision || - model_parallelism > cpu_budget || - TotalMaximumBufferedBytes(snapshot) > ram_budget) { + if (std::abs(output_time - new_output_time) < kOptimizationPrecision) { break; } - double max_abs_derivative = 1.0; - for (auto& pair : parameters) { - if (pair.second->value != pair.second->max) { - max_abs_derivative = - std::max(max_abs_derivative, std::abs(gradients[pair.first])); - } - } - for (auto& pair : parameters) { - new_value = pair.second->value - - kDescentStep * gradients[pair.first] / max_abs_derivative; - // Projection on a feasible interval. - if (new_value > pair.second->max) { - pair.second->value = pair.second->max; - } else if (new_value < pair.second->min) { - pair.second->value = pair.second->min; - } else { - pair.second->value = new_value; - } - } + + UpdateParameterValues( + gradients, &(cpu_budget_reached ? 
buffer_size_parameters : parameters)); output_time = new_output_time; } - VLOG(2) << "Number of tunable parameters: " << parameters.size(); + for (auto& pair : parameters) { pair.second->value = std::round(pair.second->value); - auto& parameter = pair.second; - VLOG(2) << "Setting tunable parameter " << pair.first << " to " - << parameter->value; - mutex_lock l(*parameter->state->mu); - parameter->state->value = parameter->value; - parameter->state->cond_var->notify_all(); } + UpdateStateValues(¶meters); } void Model::OptimizeHillClimb(int64 cpu_budget, int64 ram_budget, @@ -1517,6 +1594,7 @@ void Model::OptimizeHillClimb(int64 cpu_budget, int64 ram_budget, // improvement is greater than this constant. constexpr double kBufferSizeMinDelta = 1.0L; + // Initialize the parameter values to minimal before tuning. for (auto& pair : parameters) { pair.second->value = pair.second->min; } @@ -1560,15 +1638,7 @@ void Model::OptimizeHillClimb(int64 cpu_budget, int64 ram_budget, } best_parameter->value++; } - VLOG(2) << "Number of tunable parameters: " << parameters.size(); - for (auto& pair : parameters) { - auto& parameter = pair.second; - VLOG(2) << "Setting tunable parameter " << pair.first << " to " - << parameter->value; - mutex_lock l(*parameter->state->mu); - parameter->state->value = parameter->value; - parameter->state->cond_var->notify_all(); - } + UpdateStateValues(¶meters); } double Model::OutputTime(std::shared_ptr node, double model_input_time, diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h index 53365d2b304..df8236c5d4f 100644 --- a/tensorflow/core/framework/model.h +++ b/tensorflow/core/framework/model.h @@ -644,16 +644,17 @@ class Model { absl::flat_hash_map> CollectTunableParameters(std::shared_ptr node); - // Collects "essential" parallelism parameters of transformations in the tree - // rooted in the given node. Which parameters are essential is determined by - // comparison the processing time spent in the corresponding transformation - // relative to other transformations. The collected parameters are returned - // as a mapping from a (unique) node name to a parallelism parameter. - absl::flat_hash_map> - CollectEssentialParallelism( - std::shared_ptr node, + // Determines if we should stop the gradient descent optimization iterations + // based on number of increasable parameters, CPU budget, RAM budget and + // current resource usage. + bool ShouldStop( + int64 cpu_budget, int64 ram_budget, + const absl::flat_hash_map>& parameters, const absl::flat_hash_map>& - parameters); + parallelism_parameters, + const absl::flat_hash_map>& + buffer_size_parameters, + std::shared_ptr snapshot, bool* cpu_budget_reached); // This optimization algorithm starts by setting all tunable parallelism // parameters to the minimum value. It then repeatedly identifies the diff --git a/tensorflow/core/framework/op_def.proto b/tensorflow/core/framework/op_def.proto index ad109a3b814..756c8e4e33e 100644 --- a/tensorflow/core/framework/op_def.proto +++ b/tensorflow/core/framework/op_def.proto @@ -8,6 +8,7 @@ option java_package = "org.tensorflow.framework"; option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/op_def_go_proto"; import "tensorflow/core/framework/attr_value.proto"; import "tensorflow/core/framework/types.proto"; +import "tensorflow/core/framework/resource_handle.proto"; // Defines an operation. A NodeDef in a GraphDef specifies an Op by // using the "op" field which should match the name of a OpDef. 
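The resource_handle.proto import added above exists to support the new ArgDef field introduced in the next hunk, which records dtype-and-shape information for resource arguments. As a hypothetical sketch (not part of this change, and assuming the proto C++ code has been regenerated), consuming code could read the field like this:

#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/platform/logging.h"

// Hypothetical helper: dumps the dtype/shape pairs recorded for a resource
// argument via the new OpDef::ArgDef handle_data field.
void LogHandleData(const tensorflow::OpDef::ArgDef& arg) {
  for (const auto& dtype_and_shape : arg.handle_data()) {
    LOG(INFO) << tensorflow::DataType_Name(dtype_and_shape.dtype()) << " "
              << dtype_and_shape.shape().DebugString();
  }
}
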
@@ -42,6 +43,9 @@ message OpDef { // type, type_attr, and number_attr may be specified. string type_list_attr = 6; + // The handle data for resource inputs. + repeated ResourceHandleProto.DtypeAndShape handle_data = 7; + // For inputs: if true, the inputs are required to be refs. // By default, inputs can be either refs or non-refs. // For outputs: if true, outputs are refs, otherwise they are not. diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 6fcbe371caa..40094932814 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -170,17 +170,22 @@ bool MemoryOptimizerEnabled( #define MK_OPT(NAME, VALUE) \ if (optimizer == NAME) return std::unique_ptr(VALUE) -bool MetaOptimizer::IsSingleThreadedExecutor() const { - return config_proto_.experimental().executor_type() == - "SINGLE_THREADED_EXECUTOR"; +bool MetaOptimizer::LowerControlFlow() const { + if (config_proto_.experimental().executor_type() == + "SINGLE_THREADED_EXECUTOR") + return false; + + if (config_proto_.experimental().use_tfrt()) return false; + + return true; } std::unique_ptr MetaOptimizer::MakeNewOptimizer( const string& optimizer) const { MK_OPT("pruning", new ModelPruner()); - MK_OPT("function", new FunctionOptimizer( - cfg_.function_optimization(), - /*lower_control_flow=*/!IsSingleThreadedExecutor())); + MK_OPT("function", + new FunctionOptimizer(cfg_.function_optimization(), + /*lower_control_flow=*/LowerControlFlow())); MK_OPT("constfold", new ConstantFolding( cpu_device_, @@ -235,7 +240,7 @@ Status MetaOptimizer::InitializeOptimizers( if (cfg_.function_optimization() != RewriterConfig::OFF) { optimizers->push_back(MakeUnique( cfg_.function_optimization(), - /*lower_control_flow=*/!IsSingleThreadedExecutor())); + /*lower_control_flow=*/LowerControlFlow())); } if (cfg_.common_subgraph_elimination() != RewriterConfig::OFF && cfg_.arithmetic_optimization() != RewriterConfig::OFF) { diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.h b/tensorflow/core/grappler/optimizers/meta_optimizer.h index b21ea68f720..d3b489b224b 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.h +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.h @@ -61,7 +61,8 @@ class MetaOptimizer : public GraphOptimizer { std::unique_ptr MakeNewOptimizer( const string& optimizer) const; - bool IsSingleThreadedExecutor() const; + // When grappler should lower control flow to V1 switch/merge style nodes. + bool LowerControlFlow() const; // Initialize active optimizers from RewriterConfig toggles. 
Status InitializeOptimizers( diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 91c703e3723..ce56654c883 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -7523,6 +7523,7 @@ exports_files([ "cwise_op_gpu_sigmoid.cu.cc", "cwise_op_gpu_sin.cu.cc", "cwise_op_gpu_sqrt.cu.cc", + "cwise_op_gpu_square.cu.cc", "cwise_op_gpu_squared_difference.cu.cc", "cwise_op_gpu_sub.cu.cc", "cwise_op_gpu_tanh.cu.cc", diff --git a/tensorflow/core/kernels/depthwise_conv_op.h b/tensorflow/core/kernels/depthwise_conv_op.h index 094e2cf9cf1..568e8ab6db0 100644 --- a/tensorflow/core/kernels/depthwise_conv_op.h +++ b/tensorflow/core/kernels/depthwise_conv_op.h @@ -193,27 +193,19 @@ struct DepthwiseInputCopyOp { const int64 padded_filter_inner_dim_size, const int64 out_r, const int64 out_c, const T* input, T* input_buffer) { typedef typename Eigen::internal::packet_traits::type Packet; - static const int64 kPacketSize = (sizeof(Packet) / sizeof(T)); + static const int64 kPacketSize = Eigen::internal::packet_traits::size; + const int64 kDepth = args.depth_multiplier; // Calculate vectorized and scalar (residual) lengths for 'in_depth'. const int64 input_vectorized_size = (args.in_depth / kPacketSize) * kPacketSize; - const int64 input_scalar_size = args.in_depth % kPacketSize; - - // Calculate vectorized and scalar (residual) lengths for - // 'depth_multiplier'. This is used to efficiently replicate data for - // when 'depth_multiplier' > kPacketSize. - const int64 dm_vectorized_size = - (args.depth_multiplier / kPacketSize) * kPacketSize; - const int64 dm_scalar_size = args.depth_multiplier % kPacketSize; + const int64 input_scalar_size = args.in_depth - input_vectorized_size; // Calculate output padding length. const int64 output_scalar_size = args.out_depth % kPacketSize; const int64 output_pad_size = output_scalar_size > 0 ? kPacketSize - output_scalar_size : 0; - const int64 replicated_packet_size = kPacketSize * args.depth_multiplier; - // Iterate through all rows x cols reading 'in_depth' from 'input' and // replicating by 'depth_multiplier' into 'input_buffer' (otherwise // zero-padding input buffer as needed). @@ -221,60 +213,126 @@ struct DepthwiseInputCopyOp { const int64 in_r_start = out_r * args.stride - args.pad_rows; const int64 in_c_start = out_c * args.stride - args.pad_cols; - for (int64 f_r = 0; f_r < args.filter_rows; ++f_r) { - const int64 in_r = in_r_start + f_r; + // TODO: add a ploaddup variant for depth == 2 if needed. + if (kDepth > 1 && kDepth <= kPacketSize) { + for (int64 f_r = 0; f_r < args.filter_rows; ++f_r) { + const int64 in_r = in_r_start + f_r; - for (int64 f_c = 0; f_c < args.filter_cols; ++f_c) { - const int64 in_c = in_c_start + f_c; + for (int64 f_c = 0; f_c < args.filter_cols; ++f_c) { + const int64 in_c = in_c_start + f_c; - if (in_r >= 0 && in_r < args.in_rows && in_c >= 0 && - in_c < args.in_cols) { - auto* in = input + (in_r * args.in_cols + in_c) * args.in_depth; - // Copy vectorized portion of inner dimension. 
- for (int64 d = 0; d < input_vectorized_size; d += kPacketSize) { - auto v = Eigen::internal::ploadu(in + d); - for (int dm = 0; dm < args.depth_multiplier; ++dm) { - Eigen::internal::pscatter(in_buf + dm, v, - args.depth_multiplier); + if (in_r >= 0 && in_r < args.in_rows && in_c >= 0 && + in_c < args.in_cols) { + const auto* in = + input + (in_r * args.in_cols + in_c) * args.in_depth; + int64 limit = args.in_depth; + // This will overwrite up to kPacketSize next elements, + // this is ok on all iterations except the last one, since + // we will write correct values on a next iteration. + if (f_c == args.filter_cols - 1) { + limit -= (kPacketSize - kDepth) / kDepth + 1; + if (limit < 0) { + limit = 0; + } + } + // Copy vectorized portion of inner dimension. + for (int64 d = 0; d < limit; d++) { + const auto p = Eigen::internal::pset1(in[d]); + Eigen::internal::pstoreu(in_buf, p); + in_buf += kDepth; } - in_buf += replicated_packet_size; - } - // Copy scalar portion of inner dimension. - for (int64 d = 0; d < input_scalar_size; ++d) { - T v = in[input_vectorized_size + d]; - const int64 base = d * args.depth_multiplier; - if (dm_vectorized_size > 0) { - // Copy vectorized portion of replicated output. - // This branch is only taken if 'args.depth_multiplier' is - // vectorizable (i.e. args.depth_multiplier >= register width). - auto p = Eigen::internal::pset1(v); + // Copy the scalar portion. + for (int64 d = limit; d < args.in_depth; d++) { + const auto value = in[d]; + for (int64 dm = 0; dm < kDepth; dm++) { + in_buf[dm] = value; + } + in_buf += kDepth; + } + + // Pad the remainder of the output to vector register boundary. + for (int64 d = 0; d < output_pad_size; ++d) { + in_buf[d] = static_cast(0); + } + in_buf += output_pad_size; + } else { + // Zero pad. + memset(in_buf, 0, sizeof(T) * padded_filter_inner_dim_size); + in_buf += padded_filter_inner_dim_size; + } + } + } + } else if (kDepth > kPacketSize) { + // Calculate vectorized and scalar (residual) lengths for + // 'depth_multiplier'. This is used to efficiently replicate data for + // when 'depth_multiplier' > kPacketSize. + const int64 dm_vectorized_size = (kDepth / kPacketSize) * kPacketSize; + + for (int64 f_r = 0; f_r < args.filter_rows; ++f_r) { + const int64 in_r = in_r_start + f_r; + + for (int64 f_c = 0; f_c < args.filter_cols; ++f_c) { + const int64 in_c = in_c_start + f_c; + + if (in_r >= 0 && in_r < args.in_rows && in_c >= 0 && + in_c < args.in_cols) { + const auto* in = + input + (in_r * args.in_cols + in_c) * args.in_depth; + // Copy vectorized portion of inner dimension. + for (int64 d = 0; d < args.in_depth; d++) { + const auto p = Eigen::internal::pset1(in[d]); for (int64 dm = 0; dm < dm_vectorized_size; dm += kPacketSize) { - Eigen::internal::pstoreu(in_buf + base + dm, p); - } - // Copy scalar portion of replicated output. - for (int64 dm = 0; dm < dm_scalar_size; ++dm) { - in_buf[base + dm_vectorized_size + dm] = v; - } - } else { - // Depth multiplier is less than one packet: scalar copy. - for (int dm = 0; dm < args.depth_multiplier; ++dm) { - in_buf[base + dm] = v; + Eigen::internal::pstoreu(in_buf + dm, p); } + // Overlapping store for the remainder. + Eigen::internal::pstoreu(in_buf + kDepth - kPacketSize, p); + in_buf += kDepth; } + // Pad the remainder of the output to vector register boundary. + for (int64 d = 0; d < output_pad_size; ++d) { + in_buf[d] = static_cast(0); + } + in_buf += output_pad_size; + } else { + // Zero pad. 
+ memset(in_buf, 0, sizeof(T) * padded_filter_inner_dim_size); + in_buf += padded_filter_inner_dim_size; } - in_buf += input_scalar_size * args.depth_multiplier; + } + } + } else if (kDepth == 1) { + for (int64 f_r = 0; f_r < args.filter_rows; ++f_r) { + const int64 in_r = in_r_start + f_r; - // Pad the remainder of the output to vector register boundary. - for (int64 d = 0; d < output_pad_size; ++d) { - in_buf[d] = static_cast(0); + for (int64 f_c = 0; f_c < args.filter_cols; ++f_c) { + const int64 in_c = in_c_start + f_c; + + if (in_r >= 0 && in_r < args.in_rows && in_c >= 0 && + in_c < args.in_cols) { + const auto* in = + input + (in_r * args.in_cols + in_c) * args.in_depth; + for (int64 d = 0; d < input_vectorized_size; d += kPacketSize) { + const auto p = Eigen::internal::ploadu(in + d); + Eigen::internal::pstoreu(in_buf, p); + in_buf += kPacketSize; + } + for (int64 d = 0; d < input_scalar_size; ++d) { + T v = in[input_vectorized_size + d]; + in_buf[d] = v; + } + in_buf += input_scalar_size; + + // Pad the remainder of the output to vector register boundary. + for (int64 d = 0; d < output_pad_size; ++d) { + in_buf[d] = static_cast(0); + } + in_buf += output_pad_size; + } else { + // Zero pad. + memset(in_buf, 0, sizeof(T) * padded_filter_inner_dim_size); + in_buf += padded_filter_inner_dim_size; } - in_buf += output_pad_size; - - } else { - // Zero pad. - memset(in_buf, 0, sizeof(T) * padded_filter_inner_dim_size); - in_buf += padded_filter_inner_dim_size; } } } diff --git a/tensorflow/core/kernels/mlir_generated/BUILD b/tensorflow/core/kernels/mlir_generated/BUILD index 047782f5ff9..0c0e51f6e2f 100644 --- a/tensorflow/core/kernels/mlir_generated/BUILD +++ b/tensorflow/core/kernels/mlir_generated/BUILD @@ -77,6 +77,20 @@ filegroup( compatible_with = get_compatible_with_cloud(), ) +filegroup( + name = "unary_kernel_srcs", + srcs = if_mlir_unranked_kernels_enabled( + if_false = [ + "cwise_op_gpu_abs.cc", + "cwise_op_gpu_base.cc", + "cwise_op_gpu_base.h", + "cwise_op_gpu_tanh.cc", + ], + if_true = [":unary_unranked_kernel_srcs"], + ), + compatible_with = get_compatible_with_cloud(), +) + cc_library( name = "unranked_op_gpu_base", srcs = ["unranked_op_gpu_base.cc"], @@ -96,45 +110,68 @@ cc_library( tf_kernel_library( name = "cwise_unary_op", - srcs = [":unary_unranked_kernel_srcs"], + srcs = [":unary_kernel_srcs"], tags = [ "manual", ], - deps = [ - # Technically we only need to depend on the kernel libraries for the - # unranked kernels which are enabled by default. But this would - # make our BUILD target structure uglier. We already need to make - # sure that those targets can be built, so it should not hurt to - # link them in even if they are currently not needed yet. 
- ":abs_unranked_kernels", - ":ceil_unranked_kernels", - ":conj_unranked_kernels", - ":cos_unranked_kernels", - ":exp_unranked_kernels", - ":floor_unranked_kernels", - ":imag_unranked_kernels", - ":is_inf_unranked_kernels", - ":log_unranked_kernels", - ":logical_not_unranked_kernels", - ":real_unranked_kernels", - ":rsqrt_unranked_kernels", - ":sign_unranked_kernels", - ":sin_unranked_kernels", - ":sqrt_unranked_kernels", - ":tanh_unranked_kernels", - ":unranked_op_gpu_base", - "//third_party/eigen3", - ], + deps = if_mlir_unranked_kernels_enabled( + if_false = [ + ":abs_kernels", + ":tanh_kernels", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "//third_party/eigen3", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core/platform:stream_executor", + ], + if_true = [ + # Technically we only need to depend on the kernel libraries for the + # unranked kernels which are enabled by default. But this would + # make our BUILD target structure uglier. We already need to make + # sure that those targets can be built, so it should not hurt to + # link them in even if they are currently not needed yet. + ":abs_unranked_kernels", + ":ceil_unranked_kernels", + ":conj_unranked_kernels", + ":cos_unranked_kernels", + ":exp_unranked_kernels", + ":floor_unranked_kernels", + ":imag_unranked_kernels", + ":is_inf_unranked_kernels", + ":log_unranked_kernels", + ":logical_not_unranked_kernels", + ":real_unranked_kernels", + ":rsqrt_unranked_kernels", + ":sign_unranked_kernels", + ":sin_unranked_kernels", + ":sqrt_unranked_kernels", + ":tanh_unranked_kernels", + ":unranked_op_gpu_base", + "//third_party/eigen3", + ], + ), ) tf_kernel_library( name = "cwise_binary_op", - srcs = ["unranked_gpu_add.cc"], + srcs = [ + "unranked_op_gpu_add.cc", + ], tags = [ "manual", ], deps = [ ":add_v2_unranked_kernels", + ":equal_unranked_kernels", + ":greater_equal_unranked_kernels", + ":greater_unranked_kernels", + ":less_equal_unranked_kernels", + ":less_unranked_kernels", + ":maximum_unranked_kernels", + ":minimum_unranked_kernels", + ":not_equal_unranked_kernels", ":unranked_op_gpu_base", "//third_party/eigen3", ], @@ -180,9 +217,9 @@ tf_cuda_cc_test( ) tf_cuda_cc_test( - name = "gpu_add_test", + name = "gpu_binary_ops_test", size = "small", - srcs = if_mlir_generated_gpu_kernels_enabled(["gpu_add_test.cc"]), + srcs = if_mlir_generated_gpu_kernels_enabled(["gpu_binary_ops_test.cc"]), tags = tf_cuda_tests_tags() + [ "no_cuda_asan", # b/173033461 ], @@ -195,8 +232,10 @@ tf_cuda_cc_test( "//tensorflow/core:testlib", "//tensorflow/core/common_runtime:device", "//tensorflow/core/common_runtime:device_factory", + "//tensorflow/core/framework:types_proto_cc", "//tensorflow/core/kernels:cwise_op", "//tensorflow/core/kernels:ops_testutil", + "@com_google_absl//absl/container:inlined_vector", ], ) @@ -244,7 +283,7 @@ gen_kernel_library( "f32", "f64", ], - unroll_factors = "4", + unroll_factors = "2", ) gen_kernel_library( @@ -289,7 +328,7 @@ gen_kernel_library( generate_unranked = True, tile_size = "256", types = ["i1"], - unroll_factors = "4", + unroll_factors = "1", ) gen_kernel_library( @@ -329,28 +368,83 @@ gen_kernel_library( "f64", "i64", ], - # TODO(b/174543802): Enable once fusion heursitics is better. + # TODO(b/174543802): Enable once fusion heuristics is better. 
# unroll_factors = "4", ) -gen_kernel_library( - name = "equal", - generate_ranked = False, - generate_unranked = True, - tile_size = "256,1,1", - types = [ - "f16", - "f32", - "f64", - "i1", - "i8", - "i16", - "i32", - "i64", - ], - # TODO(b/174543802): Enable once fusion heursitics is better. - # unroll_factors = "4", -) +[ + gen_kernel_library( + name = name, + generate_ranked = False, + generate_unranked = True, + tile_size = "256,1,1", + types = [ + "f16", + "f32", + "f64", + "i1", + "i8", + "i16", + "i32", + "i64", + ], + # TODO(b/174543802): Enable once fusion heuristics is better. + # unroll_factors = "4", + ) + for name in [ + "equal", + "not_equal", + ] +] + +[ + gen_kernel_library( + name = name, + generate_ranked = False, + generate_unranked = True, + tile_size = "256,1,1", + types = [ + "f16", + "f32", + "f64", + "i8", + "i16", + "i32", + "i64", + ], + # TODO(b/174543802): Enable once fusion heuristics is better. + # unroll_factors = "4", + ) + for name in [ + "less", + "less_equal", + "greater", + "greater_equal", + ] +] + +[ + gen_kernel_library( + name = name, + generate_ranked = False, + generate_unranked = True, + tile_size = "256,1,1", + types = [ + "f16", + "f32", + "f64", + "i16", + "i32", + "i64", + ], + # TODO(b/174543802): Enable once fusion heuristics is better. + # unroll_factors = "4", + ) + for name in [ + "maximum", + "minimum", + ] +] # Kernels that support all floating-point types. [ diff --git a/tensorflow/core/kernels/mlir_generated/build_test.sh b/tensorflow/core/kernels/mlir_generated/build_test.sh index a0748a9d0d8..0fcb8a3a130 100755 --- a/tensorflow/core/kernels/mlir_generated/build_test.sh +++ b/tensorflow/core/kernels/mlir_generated/build_test.sh @@ -24,7 +24,7 @@ OUTPUT_FILE="${TEST_TMPDIR}/output.mlir" INPUT="$2" # Do something -${TF_TO_KERNEL} --input=${INPUT} --output=${OUTPUT_FILE} --unroll_factors=4 --tile_sizes=256 --arch=sm_70,compute_75 || die "Failed to generate kernel" +${TF_TO_KERNEL} --input=${INPUT} --output=${OUTPUT_FILE} --unroll_factors=4 --tile_sizes=256 --arch=sm_70,compute_75 "${@:3}" || die "Failed to generate kernel" # Check something [ -s ${OUTPUT_FILE} ] || die "output file was empty" diff --git a/tensorflow/core/kernels/mlir_generated/gpu_add_test.cc b/tensorflow/core/kernels/mlir_generated/gpu_add_test.cc deleted file mode 100644 index b518aff7a03..00000000000 --- a/tensorflow/core/kernels/mlir_generated/gpu_add_test.cc +++ /dev/null @@ -1,270 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include -#include -#include - -#include "tensorflow/core/common_runtime/device.h" -#include "tensorflow/core/common_runtime/device_factory.h" -#include "tensorflow/core/framework/fake_input.h" -#include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/kernels/ops_testutil.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace { - -class GpuAddTest : public OpsTestBase { - protected: - void SetUp() override { - std::unique_ptr device_gpu( - tensorflow::DeviceFactory::NewDevice("GPU", {}, - "/job:a/replica:0/task:0")); - SetDevice(tensorflow::DEVICE_GPU, std::move(device_gpu)); - } - - template - void SetAddOp(std::vector input_1, TensorShape shape_1, - std::vector input_2, TensorShape shape_2) { - TF_ASSERT_OK(NodeDefBuilder("add_op", "AddV2") - .Input(FakeInput(DataTypeToEnum::v())) - .Input(FakeInput(DataTypeToEnum::v())) - .Attr("T", DataTypeToEnum::v()) - .Finalize(node_def())); - - TF_ASSERT_OK(InitOp()); - inputs_.clear(); - AddInputFromArray(shape_1, input_1); - AddInputFromArray(shape_2, input_2); - } - - template - void RunAndCompareAddOp(std::vector input_1, TensorShape shape_1, - std::vector input_2, TensorShape shape_2, - std::vector output, TensorShape output_shape) { - SetAddOp(input_1, shape_1, input_2, shape_2); - TF_ASSERT_OK(RunOpKernel()); - Tensor expected_tensor(allocator(), DataTypeToEnum::value, output_shape); - test::FillValues(&expected_tensor, output); - test::ExpectEqual(expected_tensor, *GetOutput(0)); - } - - template - void TestBroadcastingExpandAddOp() { - auto input_1 = {static_cast(10)}; - auto input_2 = {static_cast(1), static_cast(2), static_cast(3), - static_cast(4), static_cast(5), static_cast(6)}; - std::vector expected{ - static_cast(11), static_cast(12), static_cast(13), - static_cast(14), static_cast(15), static_cast(16), - }; - auto expected_shape = TensorShape({6}); - RunAndCompareAddOp(input_1, TensorShape({1}), input_2, - TensorShape({6}), expected, - expected_shape); - } - - template - void TestBroadcastingInDimAddOp() { - auto input_1 = {static_cast(10), static_cast(20), static_cast(30)}; - auto input_2 = {static_cast(1), static_cast(2), static_cast(3), - static_cast(4), static_cast(5), static_cast(6)}; - std::vector expected{ - static_cast(11), static_cast(22), static_cast(33), - static_cast(14), static_cast(25), static_cast(36), - }; - auto expected_shape = TensorShape({2, 3}); - RunAndCompareAddOp(input_1, TensorShape({3}), input_2, - TensorShape({2, 3}), expected, - expected_shape); - } - - template - void TestBroadcastingAddOp() { - auto input_1 = {static_cast(10), static_cast(20)}; - auto input_2 = {static_cast(1), static_cast(2), static_cast(3)}; - std::vector expected{ - static_cast(11), static_cast(12), static_cast(13), - static_cast(21), static_cast(22), static_cast(23), - }; - auto expected_shape = TensorShape({2, 3}); - RunAndCompareAddOp(input_1, TensorShape({2, 1}), input_2, - TensorShape({3}), expected, - expected_shape); - } - - template - void RunAddOp() { - auto input_1 = { - static_cast(-std::numeric_limits::infinity()), - static_cast(-0.1), - static_cast(-0.0), - static_cast(0.0), - static_cast(0.1), - static_cast(std::numeric_limits::infinity())}; - auto input_2 = { - static_cast(-std::numeric_limits::infinity()), - static_cast(-0.1), - static_cast(-0.0), - static_cast(0.0), 
- static_cast(0.1), - static_cast(std::numeric_limits::infinity())}; - std::vector expected; - for (const T& inp : input_2) { - expected.push_back(static_cast(static_cast(inp) + - static_cast(inp))); - } - RunAndCompareAddOp(input_1, TensorShape{2, 3}, input_2, - TensorShape{2, 3}, expected, - TensorShape{2, 3}); - } - - template - void TestEqualShapesAddOp() { - auto input_1 = { - static_cast(-std::numeric_limits::infinity()), - static_cast(-0.1), - static_cast(-0.0), - static_cast(0.0), - static_cast(0.1), - static_cast(std::numeric_limits::infinity())}; - auto input_2 = { - static_cast(-std::numeric_limits::infinity()), - static_cast(-0.1), - static_cast(-0.0), - static_cast(0.0), - static_cast(0.1), - static_cast(std::numeric_limits::infinity())}; - std::vector expected; - for (const T& inp : input_2) { - expected.push_back(static_cast(static_cast(inp) + - static_cast(inp))); - } - RunAndCompareAddOp(input_1, TensorShape{2, 3}, input_2, - TensorShape{2, 3}, expected, - TensorShape{2, 3}); - } - - template - void TestOneIsScalarAddOp() { - auto input_1 = static_cast(42); - auto input_2 = { - static_cast(-std::numeric_limits::infinity()), - static_cast(-0.1), - static_cast(-0.0), - static_cast(0.0), - static_cast(0.1), - static_cast(std::numeric_limits::infinity())}; - std::vector expected; - for (const T& inp : input_2) { - expected.push_back(static_cast(static_cast(input_1) + - static_cast(inp))); - } - RunAndCompareAddOp({input_1}, TensorShape{}, input_2, - TensorShape{2, 3}, expected, - TensorShape{2, 3}); - } - - template - void TestIncompatibleShapes() { - auto input_1 = {static_cast(-0.1), static_cast(-0.0), - static_cast(0.0)}; - auto input_2 = {static_cast(-0.1), static_cast(0.0)}; - - SetAddOp(input_1, TensorShape{3}, input_2, TensorShape{2}); - auto status = RunOpKernel(); - EXPECT_FALSE(status.ok()); - EXPECT_EQ(status.code(), error::INVALID_ARGUMENT); - } - - template - void TestEmptyShapeWithBroadcastingAddOp() { - TensorShape input_shape_a{2, 0, 1}; - TensorShape input_shape_b{2, 0, 5}; - TensorShape expected_shape{2, 0, 5}; - std::vector empty_input = {}; - RunAndCompareAddOp(empty_input, input_shape_a, empty_input, - input_shape_b, empty_input, - expected_shape); - RunAndCompareAddOp(empty_input, input_shape_b, empty_input, - input_shape_a, empty_input, - expected_shape); - } -}; - -TEST_F(GpuAddTest, AddFloat) { RunAddOp(); } -TEST_F(GpuAddTest, AddDouble) { RunAddOp(); } -TEST_F(GpuAddTest, AddHalf) { RunAddOp(); } -TEST_F(GpuAddTest, AddInt64) { RunAddOp(); } - -TEST_F(GpuAddTest, AddEqShapesFloat) { TestEqualShapesAddOp(); } -TEST_F(GpuAddTest, AddEqShapesDouble) { TestEqualShapesAddOp(); } -TEST_F(GpuAddTest, AddEqShapesHalf) { - TestEqualShapesAddOp(); -} -TEST_F(GpuAddTest, AddEqShapesInt64) { TestEqualShapesAddOp(); } - -TEST_F(GpuAddTest, AddScalarFloat) { TestOneIsScalarAddOp(); } -TEST_F(GpuAddTest, AddScalarDouble) { TestOneIsScalarAddOp(); } -TEST_F(GpuAddTest, AddScalarHalf) { - TestOneIsScalarAddOp(); -} -TEST_F(GpuAddTest, AddScalarInt64) { TestOneIsScalarAddOp(); } - -TEST_F(GpuAddTest, BCastExpandAddFloat) { - TestBroadcastingExpandAddOp(); -} -TEST_F(GpuAddTest, BCastExpandAddDouble) { - TestBroadcastingExpandAddOp(); -} -TEST_F(GpuAddTest, BCastExpandAddHalf) { - TestBroadcastingExpandAddOp(); -} -TEST_F(GpuAddTest, BCastExpandAddInt64) { - TestBroadcastingExpandAddOp(); -} - -TEST_F(GpuAddTest, BCastInDimAddFloat) { TestBroadcastingInDimAddOp(); } -TEST_F(GpuAddTest, BCastInDimAddDouble) { - TestBroadcastingInDimAddOp(); -} -TEST_F(GpuAddTest, 
BCastInDimAddHalf) { - TestBroadcastingInDimAddOp(); -} -TEST_F(GpuAddTest, BCastInDimAddInt64) { TestBroadcastingInDimAddOp(); } - -TEST_F(GpuAddTest, BCastAddFloat) { TestBroadcastingAddOp(); } -TEST_F(GpuAddTest, BCastAddDouble) { TestBroadcastingAddOp(); } -TEST_F(GpuAddTest, BCastAddHalf) { - TestBroadcastingAddOp(); -} -TEST_F(GpuAddTest, BCastAddInt64) { TestBroadcastingAddOp(); } - -TEST_F(GpuAddTest, IncompatibleShapes) { TestIncompatibleShapes(); } - -TEST_F(GpuAddTest, EmptyShapeBCastAddFloat) { - TestEmptyShapeWithBroadcastingAddOp(); -} -TEST_F(GpuAddTest, EmptyShapeBCastAddDouble) { - TestEmptyShapeWithBroadcastingAddOp(); -} - -// TEST_F(GpuAddTest, AddV2Half) { RunAddOp(); } -} // namespace -} // end namespace tensorflow diff --git a/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc b/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc new file mode 100644 index 00000000000..1448a86322a --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc @@ -0,0 +1,404 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +// Tests are parametrized with the kernel name, the input data type and the +// output data type. +struct BinaryTestParam { + std::string op_name; + DataType input_type; + DataType output_type; + BinaryTestParam(const std::string& name, DataType input, DataType output) + : op_name(name), input_type(input), output_type(output) {} +}; + +// To add additional tests for other kernels, search for PLACEHOLDER in this +// file. 
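As an illustration of those placeholders, extending the suite to another kernel, say a hypothetical "Equal" entry (this change only registers AddV2 in the test parameters), would need one new entry in GetBinaryTestParameters() and a matching branch in Expected(); a minimal sketch, mirroring the existing AddV2 handling:

// Sketch only: comparison ops produce DT_BOOL outputs, so the input and
// output data types differ.
// In GetBinaryTestParameters():
for (DataType dt :
     std::vector<DataType>{DT_FLOAT, DT_DOUBLE, DT_HALF, DT_INT64}) {
  parameters.emplace_back("Equal", dt, DT_BOOL);
}

// In Expected():
if (GetParam().op_name == "Equal") {
  return static_cast<BaselineOutT>(lhs == rhs);
}
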
+ +class ParametricGpuBinaryOpsTest + : public OpsTestBase, + public ::testing::WithParamInterface { + protected: + void SetUp() override { + std::unique_ptr device_gpu( + tensorflow::DeviceFactory::NewDevice("GPU", {}, + "/job:a/replica:0/task:0")); + SetDevice(tensorflow::DEVICE_GPU, std::move(device_gpu)); + } + + template + void SetOp(const absl::InlinedVector& input_1, + const TensorShape& shape_1, + const absl::InlinedVector& input_2, + const TensorShape& shape_2) { + TF_ASSERT_OK(NodeDefBuilder("some_name", GetParam().op_name) + .Input(FakeInput(DataTypeToEnum::v())) + .Input(FakeInput(DataTypeToEnum::v())) + .Attr("T", DataTypeToEnum::v()) + .Finalize(node_def())); + + TF_ASSERT_OK(InitOp()); + inputs_.clear(); + AddInputFromArray(shape_1, input_1); + AddInputFromArray(shape_2, input_2); + } + + template + void RunAndCompare(const absl::InlinedVector& input_1, + const TensorShape& shape_1, + const absl::InlinedVector& input_2, + const TensorShape& shape_2, + const absl::InlinedVector& output, + const TensorShape& output_shape) { + SetOp(input_1, shape_1, input_2, shape_2); + TF_ASSERT_OK(RunOpKernel()); + Tensor expected_tensor(allocator(), DataTypeToEnum::value, + output_shape); + test::FillValues(&expected_tensor, output); + test::ExpectEqual(expected_tensor, *GetOutput(0)); + } + + template + void TestBroadcastingExpand() { + auto input_1 = absl::InlinedVector{static_cast(10)}; + auto input_2 = absl::InlinedVector{ + static_cast(1), static_cast(2), static_cast(3), + static_cast(4), static_cast(5), static_cast(6)}; + absl::InlinedVector expected{ + static_cast(Expected( + static_cast(input_1[0]), + static_cast(input_2[0]))), + static_cast(Expected( + static_cast(input_1[0]), + static_cast(input_2[1]))), + static_cast(Expected( + static_cast(input_1[0]), + static_cast(input_2[2]))), + static_cast(Expected( + static_cast(input_1[0]), + static_cast(input_2[3]))), + static_cast(Expected( + static_cast(input_1[0]), + static_cast(input_2[4]))), + static_cast(Expected( + static_cast(input_1[0]), + static_cast(input_2[5]))), + }; + auto expected_shape = TensorShape({6}); + RunAndCompare(input_1, TensorShape({1}), input_2, + TensorShape({6}), expected, + expected_shape); + } + + template + void TestBroadcastingInDim() { + auto input_1 = absl::InlinedVector{ + static_cast(10), static_cast(20), static_cast(30)}; + auto input_2 = absl::InlinedVector{ + static_cast(1), static_cast(2), static_cast(3), + static_cast(4), static_cast(5), static_cast(6)}; + absl::InlinedVector expected{ + static_cast(Expected( + static_cast(input_1[0]), + static_cast(input_2[0]))), + static_cast(Expected( + static_cast(input_1[1]), + static_cast(input_2[1]))), + static_cast(Expected( + static_cast(input_1[2]), + static_cast(input_2[2]))), + static_cast(Expected( + static_cast(input_1[0]), + static_cast(input_2[3]))), + static_cast(Expected( + static_cast(input_1[1]), + static_cast(input_2[4]))), + static_cast(Expected( + static_cast(input_1[2]), + static_cast(input_2[5]))), + }; + auto expected_shape = TensorShape({2, 3}); + RunAndCompare(input_1, TensorShape({3}), input_2, + TensorShape({2, 3}), expected, + expected_shape); + } + + template + void TestBroadcasting() { + auto input_1 = + absl::InlinedVector{static_cast(10), static_cast(20)}; + auto input_2 = absl::InlinedVector{ + static_cast(1), static_cast(2), static_cast(3)}; + absl::InlinedVector expected{ + static_cast(Expected( + static_cast(input_1[0]), + static_cast(input_2[0]))), + static_cast(Expected( + static_cast(input_1[0]), + 
static_cast(input_2[1]))), + static_cast(Expected( + static_cast(input_1[0]), + static_cast(input_2[2]))), + static_cast(Expected( + static_cast(input_1[1]), + static_cast(input_2[0]))), + static_cast(Expected( + static_cast(input_1[1]), + static_cast(input_2[1]))), + static_cast(Expected( + static_cast(input_1[1]), + static_cast(input_2[2]))), + }; + auto expected_shape = TensorShape({2, 3}); + RunAndCompare(input_1, TensorShape({2, 1}), input_2, + TensorShape({3}), expected, + expected_shape); + } + + template + void RunOp() { + auto input_1 = { + static_cast(-std::numeric_limits::infinity()), + static_cast(-0.1), + static_cast(-0.0), + static_cast(0.0), + static_cast(0.1), + static_cast(std::numeric_limits::infinity())}; + auto input_2 = { + static_cast(-std::numeric_limits::infinity()), + static_cast(-0.1), + static_cast(-0.0), + static_cast(0.0), + static_cast(0.1), + static_cast(std::numeric_limits::infinity())}; + absl::InlinedVector expected; + for (const T& inp : input_2) { + expected.push_back(static_cast(Expected( + static_cast(inp), static_cast(inp)))); + } + RunAndCompare(input_1, TensorShape{2, 3}, input_2, + TensorShape{2, 3}, expected, + TensorShape{2, 3}); + } + + template + void TestEqualShapes() { + auto input_1 = { + static_cast(-std::numeric_limits::infinity()), + static_cast(-0.1), + static_cast(-0.0), + static_cast(0.0), + static_cast(0.1), + static_cast(std::numeric_limits::infinity())}; + auto input_2 = { + static_cast(-std::numeric_limits::infinity()), + static_cast(-0.1), + static_cast(-0.0), + static_cast(0.0), + static_cast(0.1), + static_cast(std::numeric_limits::infinity())}; + absl::InlinedVector expected; + for (const T& inp : input_2) { + expected.push_back(static_cast(Expected( + static_cast(inp), static_cast(inp)))); + } + RunAndCompare(input_1, TensorShape{2, 3}, input_2, + TensorShape{2, 3}, expected, + TensorShape{2, 3}); + } + + template + void TestOneIsScalar() { + auto input_1 = static_cast(42); + auto input_2 = { + static_cast(-std::numeric_limits::infinity()), + static_cast(-0.1), + static_cast(-0.0), + static_cast(0.0), + static_cast(0.1), + static_cast(std::numeric_limits::infinity())}; + absl::InlinedVector expected; + for (const T& inp : input_2) { + expected.push_back(static_cast(Expected( + static_cast(input_1), static_cast(inp)))); + } + RunAndCompare({input_1}, TensorShape{}, input_2, + TensorShape{2, 3}, expected, + TensorShape{2, 3}); + } + + template + void TestIncompatibleShapes() { + auto input_1 = {static_cast(-0.1), static_cast(-0.0), + static_cast(0.0)}; + auto input_2 = {static_cast(-0.1), static_cast(0.0)}; + + SetOp(input_1, TensorShape{3}, input_2, TensorShape{2}); + auto status = RunOpKernel(); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.code(), error::INVALID_ARGUMENT); + } + + template + void TestEmptyShapeWithBroadcasting() { + TensorShape input_shape_a{2, 0, 1}; + TensorShape input_shape_b{2, 0, 5}; + TensorShape expected_shape{2, 0, 5}; + absl::InlinedVector empty_input = {}; + absl::InlinedVector expected_result = {}; + RunAndCompare(empty_input, input_shape_a, + empty_input, input_shape_b, + expected_result, expected_shape); + RunAndCompare(empty_input, input_shape_b, + empty_input, input_shape_a, + expected_result, expected_shape); + } + + template + BaselineOutT Expected(BaselineType lhs, BaselineType rhs) { + if (GetParam().op_name == "AddV2") { + return static_cast(lhs + rhs); + } + // Add the logic for creating expected values for the kernel you want to + // test here. 
+ // + LOG(FATAL) << "Cannot generate expected result for op " + << GetParam().op_name; + return static_cast(lhs); + } +}; + +std::vector GetBinaryTestParameters() { + std::vector parameters; + for (DataType dt : + std::vector{DT_FLOAT, DT_DOUBLE, DT_HALF, DT_INT64}) { + parameters.emplace_back("AddV2", dt, dt); + } + // Add the parameters (kernel name and data types to test) here. + // + return parameters; +} + +#define GENERATE_DATA_TYPE_SWITCH_CASE(dt, nt, code) \ + switch (dt) { \ + case DT_FLOAT: { \ + using nt = EnumToDataType::Type; \ + code; \ + break; \ + } \ + case DT_DOUBLE: { \ + using nt = EnumToDataType::Type; \ + code; \ + break; \ + } \ + case DT_HALF: { \ + using nt = EnumToDataType::Type; \ + code; \ + break; \ + } \ + case DT_INT64: { \ + using nt = EnumToDataType::Type; \ + code; \ + break; \ + } \ + case DT_BOOL: { \ + using nt = EnumToDataType::Type; \ + code; \ + break; \ + } \ + default: \ + LOG(FATAL) << "Unsupported type: " << DataType_Name(dt); \ + } + +#define COMMA , + +#define GENERATE_TEST_CALL(test_fn) \ + GENERATE_DATA_TYPE_SWITCH_CASE( \ + GetParam().input_type, NativeInT, \ + GENERATE_DATA_TYPE_SWITCH_CASE( \ + GetParam().output_type, NativeOutT, \ + if (GetParam().input_type == DT_HALF) { \ + if (GetParam().output_type == DT_HALF) { \ + test_fn(); \ + } else { \ + test_fn< \ + NativeInT COMMA float COMMA NativeOutT COMMA NativeOutT>(); \ + } \ + } else { \ + test_fn(); \ + })) + +TEST_P(ParametricGpuBinaryOpsTest, RunOp) { GENERATE_TEST_CALL(RunOp); } + +TEST_P(ParametricGpuBinaryOpsTest, EqShapes) { + GENERATE_TEST_CALL(TestEqualShapes); +} + +TEST_P(ParametricGpuBinaryOpsTest, Scalar) { + GENERATE_TEST_CALL(TestOneIsScalar); +} + +TEST_P(ParametricGpuBinaryOpsTest, BCastExpand) { + GENERATE_TEST_CALL(TestBroadcastingExpand); +} + +TEST_P(ParametricGpuBinaryOpsTest, BCastInDim) { + GENERATE_TEST_CALL(TestBroadcastingInDim); +} + +TEST_P(ParametricGpuBinaryOpsTest, BCast) { + GENERATE_TEST_CALL(TestBroadcasting); +} + +TEST_P(ParametricGpuBinaryOpsTest, IncompatibleShapes) { + GENERATE_TEST_CALL(TestIncompatibleShapes); +} + +TEST_P(ParametricGpuBinaryOpsTest, EmptyShapeBCast) { + GENERATE_TEST_CALL(TestEmptyShapeWithBroadcasting); +} + +INSTANTIATE_TEST_SUITE_P(GpuBinaryOpsTests, ParametricGpuBinaryOpsTest, + ::testing::ValuesIn(GetBinaryTestParameters())); +} // namespace +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/greater.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/greater.mlir.tmpl new file mode 100644 index 00000000000..47010eec805 --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/greater.mlir.tmpl @@ -0,0 +1,6 @@ +func @Greater_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) + -> tensor<*xi1> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Greater"(%arg0, %arg1) {T = elem_type, device = ""} + : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xi1> + return %0 : tensor<*xi1> +} diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/greater_equal.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/greater_equal.mlir.tmpl new file mode 100644 index 00000000000..63c0ce9caa2 --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/greater_equal.mlir.tmpl @@ -0,0 +1,6 @@ +func @GreaterEqual_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) + -> tensor<*xi1> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.GreaterEqual"(%arg0, %arg1) {T = elem_type, device = 
""} + : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xi1> + return %0 : tensor<*xi1> +} diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/less.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/less.mlir.tmpl new file mode 100644 index 00000000000..59496dc7b16 --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/less.mlir.tmpl @@ -0,0 +1,6 @@ +func @Less_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) + -> tensor<*xi1> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Less"(%arg0, %arg1) {T = elem_type, device = ""} + : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xi1> + return %0 : tensor<*xi1> +} diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/less_equal.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/less_equal.mlir.tmpl new file mode 100644 index 00000000000..245f27abf9a --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/less_equal.mlir.tmpl @@ -0,0 +1,6 @@ +func @LessEqual_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) + -> tensor<*xi1> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.LessEqual"(%arg0, %arg1) {T = elem_type, device = ""} + : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xi1> + return %0 : tensor<*xi1> +} diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/maximum.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/maximum.mlir.tmpl new file mode 100644 index 00000000000..c917b9a6c0d --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/maximum.mlir.tmpl @@ -0,0 +1,6 @@ +func @Maximum_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Maximum"(%arg0, %arg1) {T = elem_type, device = ""} + : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> +} diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/minimum.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/minimum.mlir.tmpl new file mode 100644 index 00000000000..6d8987b0ce3 --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/minimum.mlir.tmpl @@ -0,0 +1,6 @@ +func @Minimum_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Minimum"(%arg0, %arg1) {T = elem_type, device = ""} + : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> +} diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/not_equal.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/not_equal.mlir.tmpl new file mode 100644 index 00000000000..8efef8bc2f2 --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/not_equal.mlir.tmpl @@ -0,0 +1,6 @@ +func @NotEqual_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) + -> tensor<*xi1> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.NotEqual"(%arg0, %arg1) {T = elem_type, device = ""} + : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xi1> + return %0 : tensor<*xi1> +} diff --git a/tensorflow/core/kernels/mlir_generated/unranked_gpu_add.cc b/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_add.cc similarity index 100% rename from tensorflow/core/kernels/mlir_generated/unranked_gpu_add.cc rename to tensorflow/core/kernels/mlir_generated/unranked_op_gpu_add.cc 
diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h b/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h index 81d94e0b4c8..3c10c2e4103 100644 --- a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h +++ b/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h @@ -25,6 +25,18 @@ limitations under the License. namespace tensorflow { +// A type-erased version of the UnrankedMemRefType to allow it to be used +// as the return type of an extern "C" function on windows. +struct UntypedUnrankedMemRefType { + int64_t rank; + void* descriptor; +}; + +template +UnrankedMemRefType ConvertToTyped(UntypedUnrankedMemRefType desc) { + return {desc.rank, desc.descriptor}; +} + // Returns a pointer to an allocated MlirTensorBuffer that takes ownership of // pre-allocated memory. TensorBuffer* GetMlirTensorBuffer(const void* ptr, size_t size, @@ -157,25 +169,26 @@ class MlirUnrankedOp : public OpKernel { GENERATE_BINARY_KERNEL(tf_op, mlir_type, tf_data_type, data_type) \ REGISTER_KERNEL(tf_op, mlir_type, data_type) -#define GENERATE_BINARY_KERNEL(tf_op, mlir_type, tf_data_type, data_type) \ - extern "C" ::UnrankedMemRefType MLIR_FUNCTION(tf_op, mlir_type)( \ - tensorflow::OpKernelContext * ctx, \ - const ::UnrankedMemRefType* arg1, \ - const ::UnrankedMemRefType* arg2); \ - \ - namespace { \ - class MlirUnranked##tf_op##mlir_type##Op \ - : public MlirUnrankedOp { \ - public: \ - using MlirUnrankedOp::MlirUnrankedOp; \ - \ - static ::UnrankedMemRefType Invoke( \ - OpKernelContext* ctx, \ - llvm::ArrayRef<::UnrankedMemRefType> args) { \ - return MLIR_FUNCTION(tf_op, mlir_type)(ctx, &args[0], &args[1]); \ - } \ - }; \ +#define GENERATE_BINARY_KERNEL(tf_op, mlir_type, tf_data_type, data_type) \ + extern "C" UntypedUnrankedMemRefType MLIR_FUNCTION(tf_op, mlir_type)( \ + tensorflow::OpKernelContext * ctx, \ + const ::UnrankedMemRefType* arg1, \ + const ::UnrankedMemRefType* arg2); \ + \ + namespace { \ + class MlirUnranked##tf_op##mlir_type##Op \ + : public MlirUnrankedOp { \ + public: \ + using MlirUnrankedOp::MlirUnrankedOp; \ + \ + static ::UnrankedMemRefType Invoke( \ + OpKernelContext* ctx, \ + llvm::ArrayRef<::UnrankedMemRefType> args) { \ + return ConvertToTyped( \ + MLIR_FUNCTION(tf_op, mlir_type)(ctx, &args[0], &args[1])); \ + } \ + }; \ } #define GENERATE_AND_REGISTER_UNARY_KERNEL(tf_op, mlir_type, tf_data_type, \ @@ -186,26 +199,27 @@ class MlirUnrankedOp : public OpKernel { #define GENERATE_UNARY_KERNEL(tf_op, mlir_type, tf_data_type, data_type) \ GENERATE_UNARY_KERNEL2(tf_op, mlir_type, tf_data_type, data_type, data_type) -#define GENERATE_UNARY_KERNEL2(tf_op, mlir_type, tf_data_type, data_type, \ - input_data_type) \ - extern "C" ::UnrankedMemRefType MLIR_FUNCTION(tf_op, mlir_type)( \ - tensorflow::OpKernelContext * ctx, \ - const ::UnrankedMemRefType* arg); \ - \ - namespace { \ - class MlirUnranked##tf_op##mlir_type##Op \ - : public MlirUnrankedOp { \ - public: \ - using MlirUnrankedOp::MlirUnrankedOp; \ - \ - static ::UnrankedMemRefType Invoke( \ - OpKernelContext* ctx, \ - llvm::ArrayRef<::UnrankedMemRefType> args) { \ - return MLIR_FUNCTION(tf_op, mlir_type)(ctx, &args[0]); \ - } \ - }; \ +#define GENERATE_UNARY_KERNEL2(tf_op, mlir_type, tf_data_type, data_type, \ + input_data_type) \ + extern "C" UntypedUnrankedMemRefType MLIR_FUNCTION(tf_op, mlir_type)( \ + tensorflow::OpKernelContext * ctx, \ + const ::UnrankedMemRefType* arg); \ + \ + namespace { \ + class MlirUnranked##tf_op##mlir_type##Op \ + : public MlirUnrankedOp { \ + public: \ + 
using MlirUnrankedOp::MlirUnrankedOp; \ + \ + static ::UnrankedMemRefType Invoke( \ + OpKernelContext* ctx, \ + llvm::ArrayRef<::UnrankedMemRefType> args) { \ + return ConvertToTyped( \ + MLIR_FUNCTION(tf_op, mlir_type)(ctx, &args[0])); \ + } \ + }; \ } } // namespace tensorflow diff --git a/tensorflow/core/kernels/sparse_xent_op.cc b/tensorflow/core/kernels/sparse_xent_op.cc index 44bcab497ee..6fab5f1f5ad 100644 --- a/tensorflow/core/kernels/sparse_xent_op.cc +++ b/tensorflow/core/kernels/sparse_xent_op.cc @@ -18,7 +18,6 @@ limitations under the License. #define EIGEN_USE_THREADS #include "tensorflow/core/kernels/sparse_xent_op.h" - #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -123,8 +122,6 @@ REGISTER(CPU, float, int32) REGISTER(CPU, float, int64) REGISTER(CPU, double, int32) REGISTER(CPU, double, int64) -REGISTER(CPU, bfloat16, int32) -REGISTER(CPU, bfloat16, int64) REGISTER(CPU, Eigen::half, int32) REGISTER(CPU, Eigen::half, int64) diff --git a/tensorflow/core/kernels/sparse_xent_op_test.cc b/tensorflow/core/kernels/sparse_xent_op_test.cc index f095f2e2cf7..85a5cd3befc 100644 --- a/tensorflow/core/kernels/sparse_xent_op_test.cc +++ b/tensorflow/core/kernels/sparse_xent_op_test.cc @@ -23,9 +23,9 @@ limitations under the License. namespace tensorflow { -static Graph* SparseXent(int batch_size, int num_classes, DataType value_type) { +static Graph* SparseXent(int batch_size, int num_classes) { Graph* g = new Graph(OpRegistry::Global()); - Tensor logits(value_type, TensorShape({batch_size, num_classes})); + Tensor logits(DT_FLOAT, TensorShape({batch_size, num_classes})); logits.flat().setRandom(); Tensor labels(DT_INT64, TensorShape({batch_size})); std::random_device rd; @@ -41,45 +41,44 @@ static Graph* SparseXent(int batch_size, int num_classes, DataType value_type) { return g; } -#define BM_SparseXentDev(BATCH, CLASS, DEVICE, DTYPE) \ - static void BM_SparseXent##_##BATCH##_##CLASS##_##DEVICE##_##DTYPE( \ +#define BM_SparseXentDev(BATCH, CLASS, DEVICE) \ + static void BM_SparseXent##_##BATCH##_##CLASS##_##DEVICE( \ ::testing::benchmark::State& state) { \ - test::Benchmark(#DEVICE, SparseXent(BATCH, CLASS, DTYPE), \ + test::Benchmark(#DEVICE, SparseXent(BATCH, CLASS), \ /*old_benchmark_api*/ false) \ .Run(state); \ state.SetItemsProcessed(static_cast(state.iterations()) * BATCH * \ CLASS); \ } \ - BENCHMARK(BM_SparseXent##_##BATCH##_##CLASS##_##DEVICE##_##DTYPE); - -#define BM_SPARSE_XENT_DEV_CPU(DTYPE) \ - BM_SparseXentDev(8, 1000000, cpu, DTYPE); \ - BM_SparseXentDev(16, 10000, cpu, DTYPE); \ - BM_SparseXentDev(16, 100000, cpu, DTYPE); \ - BM_SparseXentDev(32, 10000, cpu, DTYPE); \ - BM_SparseXentDev(32, 100000, cpu, DTYPE); \ - BM_SparseXentDev(64, 10000, cpu, DTYPE); \ - BM_SparseXentDev(64, 100000, cpu, DTYPE); - -// CPU -BM_SPARSE_XENT_DEV_CPU(DT_FLOAT); -BM_SPARSE_XENT_DEV_CPU(DT_BFLOAT16); + BENCHMARK(BM_SparseXent##_##BATCH##_##CLASS##_##DEVICE); /// The representative tests for ptb_word on GPU #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -BM_SparseXentDev(8, 1000000, gpu, DT_FLOAT); +BM_SparseXentDev(8, 1000000, gpu); -BM_SparseXentDev(16, 10000, gpu, DT_FLOAT); -BM_SparseXentDev(16, 30000, gpu, DT_FLOAT); -BM_SparseXentDev(16, 100000, gpu, DT_FLOAT); +BM_SparseXentDev(16, 10000, gpu); +BM_SparseXentDev(16, 30000, gpu); +BM_SparseXentDev(16, 100000, gpu); -BM_SparseXentDev(32, 10000, gpu, DT_FLOAT); -BM_SparseXentDev(32, 30000, gpu, DT_FLOAT); -BM_SparseXentDev(32, 
100000, gpu, DT_FLOAT); +BM_SparseXentDev(32, 10000, gpu); +BM_SparseXentDev(32, 30000, gpu); +BM_SparseXentDev(32, 100000, gpu); -BM_SparseXentDev(64, 10000, gpu, DT_FLOAT); -BM_SparseXentDev(64, 30000, gpu, DT_FLOAT); -BM_SparseXentDev(64, 100000, gpu, DT_FLOAT); +BM_SparseXentDev(64, 10000, gpu); +BM_SparseXentDev(64, 30000, gpu); +BM_SparseXentDev(64, 100000, gpu); #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +// CPU +BM_SparseXentDev(8, 1000000, cpu); + +BM_SparseXentDev(16, 10000, cpu); +BM_SparseXentDev(16, 100000, cpu); + +BM_SparseXentDev(32, 10000, cpu); +BM_SparseXentDev(32, 100000, cpu); + +BM_SparseXentDev(64, 10000, cpu); +BM_SparseXentDev(64, 100000, cpu); + } // end namespace tensorflow diff --git a/tensorflow/core/platform/default/logging.cc b/tensorflow/core/platform/default/logging.cc index b19c2630f23..e0d2e5b9341 100644 --- a/tensorflow/core/platform/default/logging.cc +++ b/tensorflow/core/platform/default/logging.cc @@ -485,7 +485,8 @@ void TFDefaultLogSink::Send(const TFLogEntry& entry) { __android_log_write(android_log_level, "native", ss.str().c_str()); // Also log to stderr (for standalone Android apps). - std::cerr << "native : " << ss.str() << std::endl; + // Don't use 'std::cerr' since it crashes on Android. + fprintf(stderr, "native : %s\n", ss.str().c_str()); // Android logging at level FATAL does not terminate execution, so abort() // is still required to stop the program. diff --git a/tensorflow/core/profiler/convert/post_process_single_host_xplane.cc b/tensorflow/core/profiler/convert/post_process_single_host_xplane.cc index 4fe3ed58366..accbe4ecfba 100644 --- a/tensorflow/core/profiler/convert/post_process_single_host_xplane.cc +++ b/tensorflow/core/profiler/convert/post_process_single_host_xplane.cc @@ -21,37 +21,28 @@ limitations under the License. namespace tensorflow { namespace profiler { +namespace { -void MergeHostPlanes(XSpace* space) { - const XPlane* cupti_driver_api_plane = - FindPlaneWithName(*space, kCuptiDriverApiPlaneName); - const XPlane* python_tracer_plane = - FindPlaneWithName(*space, kPythonTracerPlaneName); - if (cupti_driver_api_plane || python_tracer_plane) { - XPlane* host_plane = - FindOrAddMutablePlaneWithName(space, kHostThreadsPlaneName); - if (cupti_driver_api_plane) { - MergePlanes(*cupti_driver_api_plane, host_plane); - } - if (python_tracer_plane) { - MergePlanes(*python_tracer_plane, host_plane); - } - SortXLinesBy(host_plane, XLinesComparatorByName()); - if (cupti_driver_api_plane) { - RemovePlane(space, cupti_driver_api_plane); - } - if (python_tracer_plane) { - RemovePlane(space, python_tracer_plane); - } +// Merges XPlanes generated by TraceMe, CUPTI API trace and Python tracer. +void MergeHostPlanesAndSortLines(XSpace* space) { + XPlane* host_plane = + FindOrAddMutablePlaneWithName(space, kHostThreadsPlaneName); + std::vector additional_host_planes = FindPlanesWithNames( + *space, {kCuptiDriverApiPlaneName, kPythonTracerPlaneName}); + if (!additional_host_planes.empty()) { + MergePlanes(additional_host_planes, host_plane); + RemovePlanes(space, additional_host_planes); } + SortXLinesBy(host_plane, XLinesComparatorByName()); } +} // namespace + void PostProcessSingleHostXSpace(XSpace* space, uint64 start_time_ns) { VLOG(3) << "Post processing local profiler XSpace."; // Post processing the collected XSpace without hold profiler lock. - // 1. Merge plane of host events with plane of CUPTI driver api. - MergeHostPlanes(space); - + // 1. Merge all host planes and sorts lines by name. 
+ MergeHostPlanesAndSortLines(space); // 2. Normalize all timestamps by shifting timeline to profiling start time. // NOTE: this have to be done before sorting XSpace due to timestamp overflow. NormalizeTimestamps(space, start_time_ns); diff --git a/tensorflow/core/profiler/convert/post_process_single_host_xplane.h b/tensorflow/core/profiler/convert/post_process_single_host_xplane.h index 70c6785591b..31ebe28c48f 100644 --- a/tensorflow/core/profiler/convert/post_process_single_host_xplane.h +++ b/tensorflow/core/profiler/convert/post_process_single_host_xplane.h @@ -21,9 +21,6 @@ limitations under the License. namespace tensorflow { namespace profiler { -// Merges XPlanes generated by TraceMe, CUPTI API trace and Python tracer. -void MergeHostPlanes(XSpace* space); - // Post process XSpaces collected locally from multiple profilers. void PostProcessSingleHostXSpace(XSpace* space, uint64 start_time_ns); diff --git a/tensorflow/core/profiler/utils/xplane_utils.cc b/tensorflow/core/profiler/utils/xplane_utils.cc index 8249bb68128..5b7d22ce22f 100644 --- a/tensorflow/core/profiler/utils/xplane_utils.cc +++ b/tensorflow/core/profiler/utils/xplane_utils.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/match.h" #include "absl/strings/string_view.h" #include "tensorflow/core/platform/logging.h" @@ -38,45 +39,84 @@ namespace { // Returns the index of the first element in array for which pred is true. // Returns -1 if no such element is found. template -int FindIf(const protobuf::RepeatedPtrField& array, Pred&& pred) { +int Find(const protobuf::RepeatedPtrField& array, const Pred& pred) { for (int i = 0; i < array.size(); ++i) { if (pred(&array.Get(i))) return i; } return -1; } +// Returns the indices of all elements in array for which pred is true. +template +std::vector FindAll(const protobuf::RepeatedPtrField& array, + const Pred& pred) { + std::vector indices; + for (int i = 0; i < array.size(); ++i) { + if (pred(&array.Get(i))) indices.push_back(i); + } + return indices; +} + +template +void RemoveAt(protobuf::RepeatedPtrField* array, + const std::vector& indices) { + if (indices.empty()) return; + if (array->size() == indices.size()) { + // Assumes that 'indices' consists of [0 ... N-1]. + array->Clear(); + return; + } + auto remove_iter = indices.begin(); + int i = *(remove_iter++); + for (int j = i + 1; j < array->size(); ++j) { + if (remove_iter != indices.end() && *remove_iter == j) { + ++remove_iter; + } else { + array->SwapElements(j, i++); + } + } + array->DeleteSubrange(i, array->size() - i); +} + // Removes the given element from array. 
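
The new RemoveAt helper above deletes a batch of ascending indices from a RepeatedPtrField in a single pass: surviving elements are swapped toward the front and the tail is deleted once at the end. A self-contained sketch of the same compaction on a std::vector follows, with std::swap and erase standing in for SwapElements and DeleteSubrange.

#include <cassert>
#include <utility>
#include <vector>

// std::vector variant of RemoveAt, for illustration only. `indices` must be
// sorted in ascending order.
template <typename T>
void RemoveAt(std::vector<T>* array, const std::vector<int>& indices) {
  if (indices.empty()) return;
  if (array->size() == indices.size()) {
    array->clear();  // assumes indices are exactly [0 .. N-1]
    return;
  }
  auto remove_iter = indices.begin();
  int i = *(remove_iter++);  // first slot to overwrite
  for (int j = i + 1; j < static_cast<int>(array->size()); ++j) {
    if (remove_iter != indices.end() && *remove_iter == j) {
      ++remove_iter;  // j is being removed; leave it behind
    } else {
      std::swap((*array)[j], (*array)[i++]);  // keep j: compact it forward
    }
  }
  array->erase(array->begin() + i, array->end());
}

int main() {
  std::vector<int> v = {10, 20, 30, 40, 50};
  RemoveAt(&v, {1, 3});  // drop the elements at indices 1 and 3
  assert((v == std::vector<int>{10, 30, 50}));
  return 0;
}

As the code that follows shows, Remove and RemoveIf are then expressed in terms of RemoveAt, and the new RemovePlanes entry point uses the same path to drop several planes at once.
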
template void Remove(protobuf::RepeatedPtrField* array, const T* elem) { - int i = FindIf(*array, [elem](const T* e) { return elem == e; }); - if (i == -1) return; - for (; i < array->size() - 1; ++i) { - array->SwapElements(i + 1, i); - } - array->RemoveLast(); + int i = Find(*array, [elem](const T* e) { return elem == e; }); + RemoveAt(array, {i}); } template void RemoveIf(protobuf::RepeatedPtrField* array, Pred&& pred) { - int i = FindIf(*array, pred); - if (i == -1) return; - for (int j = i + 1; j < array->size(); ++j) { - if (!pred(&array->Get(j))) array->SwapElements(j, i++); - } - array->DeleteSubrange(i, array->size() - i); + std::vector indices = FindAll(*array, pred); + RemoveAt(array, indices); } } // namespace const XPlane* FindPlaneWithName(const XSpace& space, absl::string_view name) { - int i = FindIf(space.planes(), - [name](const XPlane* plane) { return plane->name() == name; }); + int i = Find(space.planes(), + [name](const XPlane* plane) { return plane->name() == name; }); return (i != -1) ? &space.planes(i) : nullptr; } +std::vector FindPlanesWithNames( + const XSpace& space, const std::vector& names) { + absl::flat_hash_set names_set(names.begin(), names.end()); + std::vector indices = + FindAll(space.planes(), [&names_set](const XPlane* plane) { + return names_set.contains(plane->name()); + }); + std::vector planes; + planes.reserve(indices.size()); + for (int i : indices) { + planes.push_back(&space.planes(i)); + } + return planes; +} + XPlane* FindMutablePlaneWithName(XSpace* space, absl::string_view name) { - int i = FindIf(space->planes(), - [name](const XPlane* plane) { return plane->name() == name; }); + int i = Find(space->planes(), + [name](const XPlane* plane) { return plane->name() == name; }); return (i != -1) ? space->mutable_planes(i) : nullptr; } @@ -108,8 +148,8 @@ std::vector FindMutablePlanesWithPrefix(XSpace* space, } const XLine* FindLineWithId(const XPlane& plane, int64 id) { - int i = FindIf(plane.lines(), - [id](const XLine* line) { return line->id() == id; }); + int i = + Find(plane.lines(), [id](const XLine* line) { return line->id() == id; }); return (i != -1) ? 
&plane.lines(i) : nullptr; } @@ -129,6 +169,13 @@ void RemovePlane(XSpace* space, const XPlane* plane) { Remove(space->mutable_planes(), plane); } +void RemovePlanes(XSpace* space, const std::vector& planes) { + absl::flat_hash_set planes_set(planes.begin(), planes.end()); + RemoveIf(space->mutable_planes(), [&planes_set](const XPlane* plane) { + return planes_set.contains(plane); + }); +} + void RemoveLine(XPlane* plane, const XLine* line) { DCHECK(line != nullptr); Remove(plane->mutable_lines(), line); @@ -245,6 +292,13 @@ void MergePlanes(const XPlane& src_plane, XPlane* dst_plane) { }); } +void MergePlanes(const std::vector& src_planes, + XPlane* dst_plane) { + for (const XPlane* src_plane : src_planes) { + MergePlanes(*src_plane, dst_plane); + } +} + uint64 GetStartTimestampNs(const XPlane& plane) { int64 plane_timestamp = 0; for (const auto& line : plane.lines()) { diff --git a/tensorflow/core/profiler/utils/xplane_utils.h b/tensorflow/core/profiler/utils/xplane_utils.h index b6331e46452..8358c5cba8a 100644 --- a/tensorflow/core/profiler/utils/xplane_utils.h +++ b/tensorflow/core/profiler/utils/xplane_utils.h @@ -38,6 +38,10 @@ inline Timespan XEventTimespan(const XEvent& event) { const XPlane* FindPlaneWithName(const XSpace& space, absl::string_view name); XPlane* FindMutablePlaneWithName(XSpace* space, absl::string_view name); +// Returns the planes with the given names, if found. +std::vector FindPlanesWithNames( + const XSpace& space, const std::vector& names); + // Returns the plane with the given name in the container. If necessary, adds a // new plane to the container. XPlane* FindOrAddMutablePlaneWithName(XSpace* space, absl::string_view name); @@ -54,6 +58,7 @@ const XLine* FindLineWithId(const XPlane& plane, int64 id); XStat* FindOrAddMutableStat(const XStatMetadata& stat_metadata, XEvent* event); void RemovePlane(XSpace* space, const XPlane* plane); +void RemovePlanes(XSpace* space, const std::vector& planes); void RemoveLine(XPlane* plane, const XLine* line); void RemoveEvents(XLine* line, const absl::flat_hash_set& events); @@ -107,12 +112,16 @@ std::vector GetSortedEvents(XPlane* plane, Compare comp, void NormalizeTimestamps(XPlane* plane, uint64 start_time_ns); void NormalizeTimestamps(XSpace* space, uint64 start_time_ns); -// Merge Xplane src_plane into Xplane dst_plane, both plane level stats, lines, -// events and event level stats are merged; If src_plane and dst_plane both have -// the same line, which have different start timestamps, we will normalize the -// events offset timestamp correspondingly. +// Merges src_plane into dst_plane. Both plane level stats, lines, events and +// event level stats are merged. If src_plane and dst_plane both have the same +// line, which have different start timestamps, we will normalize the events +// offset timestamp correspondingly. void MergePlanes(const XPlane& src_plane, XPlane* dst_plane); +// Merges each plane with a src_planes, into the dst_plane. +void MergePlanes(const std::vector& src_planes, + XPlane* dst_plane); + // Plane's start timestamp is defined as the minimum of all lines' start // timestamps. 
If zero line exists, return 0; uint64 GetStartTimestampNs(const XPlane& plane); diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto index 29e3e8a4ce3..59ffa50cf4b 100644 --- a/tensorflow/core/protobuf/config.proto +++ b/tensorflow/core/protobuf/config.proto @@ -620,6 +620,11 @@ message ConfigProto { // The XLA fusion autotuner can improve performance by executing a heuristic // search on the compiler parameters. int64 xla_fusion_autotuner_thresh = 15; + + // Whether runtime execution uses TFRT. + bool use_tfrt = 18; + + // Next: 19 } Experimental experimental = 16; diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 66572fab066..f7f89b452ac 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 608 // Updated: 2020/12/7 +#define TF_GRAPH_DEF_VERSION 609 // Updated: 2020/12/8 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // diff --git a/tensorflow/core/tpu/tpu_on_demand_compiler.cc b/tensorflow/core/tpu/tpu_on_demand_compiler.cc index c99808f5ee7..069cf0d37e1 100644 --- a/tensorflow/core/tpu/tpu_on_demand_compiler.cc +++ b/tensorflow/core/tpu/tpu_on_demand_compiler.cc @@ -251,7 +251,7 @@ class TpuCompiler : public Compiler { StatusOr> RunHloPasses( std::unique_ptr module, stream_executor::StreamExecutor* executor, - stream_executor::DeviceMemoryAllocator* device_allocator) override { + const CompileOptions& options) override { XLA_HloModule hlo_module; XLA_HloModule result; auto cleanup = xla::MakeCleanup([&hlo_module, &result]() { @@ -261,7 +261,7 @@ class TpuCompiler : public Compiler { }); hlo_module.module_config = HloModuleConfigToC(module->config()); hlo_module.proto = stream_executor::tpu::SerializeProto(module->ToProto()); - auto allocator = ApiConverter::ToC(device_allocator); + auto allocator = ApiConverter::ToC(options.device_allocator); StatusHelper status; ExecutorApiFn()->TpuCompiler_RunHloPassesFn( compiler_, &hlo_module, @@ -279,7 +279,7 @@ class TpuCompiler : public Compiler { StatusOr> RunBackend( std::unique_ptr module, stream_executor::StreamExecutor* executor, - stream_executor::DeviceMemoryAllocator* device_allocator) override { + const CompileOptions& options) override { XLA_HloModule hlo_module; auto cleanup = xla::MakeCleanup([&hlo_module]() { stream_executor::tpu::SerializedProto_Free(hlo_module.proto); @@ -288,7 +288,7 @@ class TpuCompiler : public Compiler { SE_Executable* result; hlo_module.module_config = HloModuleConfigToC(module->config()); hlo_module.proto = stream_executor::tpu::SerializeProto(module->ToProto()); - auto allocator = ApiConverter::ToC(device_allocator); + auto allocator = ApiConverter::ToC(options.device_allocator); StatusHelper status; ExecutorApiFn()->TpuCompiler_RunBackendFn( @@ -308,7 +308,7 @@ class TpuCompiler : public Compiler { StatusOr>> Compile( std::unique_ptr module_group, std::vector> stream_exec, - stream_executor::DeviceMemoryAllocator* device_allocator) override { + const CompileOptions& options) override { XLA_HloModuleGroup se_module_group; se_module_group.proto = stream_executor::tpu::SerializeProto(module_group->ToProto()); @@ -339,7 +339,8 @@ class TpuCompiler : public Compiler { } } - SE_DeviceMemoryAllocator allocator = ApiConverter::ToC(device_allocator); + SE_DeviceMemoryAllocator allocator = + 
ApiConverter::ToC(options.device_allocator); SE_Executable** se_executables = new SE_Executable*[module_group->size()]; diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 25eb88991a1..9db9c335812 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -20469,12 +20469,13 @@ func ShapeN(scope *Scope, input []tf.Output, optional ...ShapeNAttr) (output []t return output } -// Returns the TopK values in the array in sorted order. This is a combination +// Returns the TopK values in the array in sorted order. // -// of MakeUnique and TopKUnique. The returned top-K will have its lower bits -// replaced by iota, thus it will be close to the original value but not exactly -// the same. The running time is proportional to the product of K and the input -// size. NaNs are never returned. Subnormal numbers are flushed to zero. +// This is a combination of MakeUnique and TopKUnique. The returned top-K will +// have its lower bits replaced by iota, thus it will be close to the original +// value but not exactly the same. The running time is proportional to the product +// of K and the input size. NaNs are never returned. Subnormal numbers are flushed +// to zero. func TopKWithUnique(scope *Scope, input tf.Output, k int64) (topk tf.Output, topk_indices tf.Output) { if scope.Err() != nil { return @@ -29883,7 +29884,7 @@ func QuantizedAvgPool(scope *Scope, input tf.Output, min_input tf.Output, max_in return op.Output(0), op.Output(1), op.Output(2) } -// Adds v into specified rows of x. +// Adds v into specified rows of x. // // Computes y = x; y[i, :] += v; return y. // @@ -34871,9 +34872,9 @@ func TPUReplicateMetadata(scope *Scope, num_replicas int64, optional ...TPURepli return scope.AddOperation(opspec) } -// Returns the TopK unique values in the array in sorted order. The +// Returns the TopK unique values in the array in sorted order. // -// running time is proportional to the product of K and the input +// The running time is proportional to the product of K and the input // size. Sorting the whole array is more efficient for sufficiently large // values of K. The median-of-medians algorithm is probably faster, but // difficult to implement efficiently in XLA. If there are fewer than K diff --git a/tensorflow/go/tensor.go b/tensorflow/go/tensor.go index 6d884f32f83..df5b34cec89 100644 --- a/tensorflow/go/tensor.go +++ b/tensorflow/go/tensor.go @@ -215,25 +215,29 @@ func (t *Tensor) DataType() DataType { return DataType(C.TF_TensorType(t.c)) } func (t *Tensor) Shape() []int64 { return t.shape } // Reshape updates tensor's shape in place if this is possible or returns an error otherwise. 
-func (t *Tensor) Reshape(new_shape []int64) error { - old_shape_size := numElements(t.shape) - new_shape_size := numElements(new_shape) +func (t *Tensor) Reshape(newShape []int64) error { + oldShapeSize := numElements(t.shape) + newShapeSize := numElements(newShape) - if old_shape_size != new_shape_size { - return fmt.Errorf("unable to convert shape %v (num_elements: %d) into shape %v (num_elements: %d)", t.shape, old_shape_size, new_shape, new_shape_size) + if oldShapeSize != newShapeSize { + return fmt.Errorf("unable to convert shape %v (num_elements: %d) into shape %v (num_elements: %d)", t.shape, oldShapeSize, newShape, newShapeSize) } - if len(new_shape) == 0 { + if len(newShape) == 0 { return nil } var shapePtr *C.int64_t - shapePtr = (*C.int64_t)(unsafe.Pointer(&new_shape[0])) + shapePtr = (*C.int64_t)(unsafe.Pointer(&newShape[0])) status := newStatus() - C.TF_TensorBitcastFrom(t.c, C.TF_TensorType(t.c), t.c, shapePtr, C.int(len(new_shape)), status.c) + C.TF_TensorBitcastFrom(t.c, C.TF_TensorType(t.c), t.c, shapePtr, C.int(len(newShape)), status.c) - return status.Err() + if err := status.Err(); err != nil { + return err + } + t.shape = newShape + return nil } // Value converts the Tensor to a Go value. For now, not all Tensor types are diff --git a/tensorflow/go/tensor_test.go b/tensorflow/go/tensor_test.go index 15b2ea55ad8..8aa710669a0 100644 --- a/tensorflow/go/tensor_test.go +++ b/tensorflow/go/tensor_test.go @@ -358,3 +358,31 @@ func BenchmarkTensor(b *testing.B) { }) } + +func TestReshape(t *testing.T) { + tensor, err := NewTensor([]int64{1, 2}) + if err != nil { + t.Fatalf("Unable to create new tensor: %v", err) + } + + if got, want := len(tensor.Shape()), 1; got != want { + t.Fatalf("len(tensor.Shape()): got %d, want %d", got, want) + } + if got, want := tensor.Shape()[0], int64(2); got != want { + t.Errorf("tensor.Shape()[0]: got %d, want %d", got, want) + } + + if err := tensor.Reshape([]int64{1, 2}); err != nil { + t.Fatalf("tensor.Reshape([1, 2]) failed: %v", err) + } + + if got, want := len(tensor.Shape()), 2; got != want { + t.Fatalf("After reshape, len(tensor.Shape()): got %d, want %d", got, want) + } + if got, want := tensor.Shape()[0], int64(1); got != want { + t.Errorf("After reshape, tensor.Shape()[0]: got %d, want %d", got, want) + } + if got, want := tensor.Shape()[1], int64(2); got != want { + t.Errorf("After reshape, tensor.Shape()[1]: got %d, want %d", got, want) + } +} diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index a07ded54372..8e3bcc27567 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -14,7 +14,9 @@ exports_files(glob([ "testdata/*.tflite", "testdata/*.csv", "models/testdata/*", -])) +]) + [ + "create_op_resolver.h", +]) config_setting( name = "gemmlowp_profiling", @@ -707,6 +709,34 @@ cc_library( ], ) +# Defines CreateOpResolver with all builtin ops. +cc_library( + name = "create_op_resolver_with_builtin_ops", + srcs = ["create_op_resolver_with_builtin_ops.cc"], + hdrs = ["create_op_resolver.h"], + copts = tflite_copts(), + deps = [ + "//tensorflow/lite:op_resolver", + "//tensorflow/lite/core/api", + "//tensorflow/lite/kernels:builtin_ops", + ], +) + +# This target is created for tflite_custom_cc_library build rule. It requires +# the header file generated from gen_selected_ops so should not be depended on +# directly. +# TODO(b/174972014): Generate this target to give RegisterSelectedOps a custom namespace. 
+cc_library( + name = "create_op_resolver_with_selected_ops", + srcs = ["create_op_resolver_with_selected_ops.cc"], + hdrs = ["create_op_resolver.h"], + copts = tflite_copts(), + deps = [ + "//tensorflow/lite:mutable_op_resolver", + "//tensorflow/lite:op_resolver", + ], +) + cc_test( name = "util_test", size = "small", diff --git a/tensorflow/lite/CMakeLists.txt b/tensorflow/lite/CMakeLists.txt index aeb271d5dd0..92810014709 100644 --- a/tensorflow/lite/CMakeLists.txt +++ b/tensorflow/lite/CMakeLists.txt @@ -216,23 +216,6 @@ if(TFLITE_ENABLE_GPU) "delegates/gpu/cl/kernels" TFLITE_DELEGATES_GPU_CL_KERNELS_SRCS FILTER "(_test)\\.(cc|h)$" ) - populate_tflite_source_vars( - "delegates/gpu/cl/kernels/special" - TFLITE_DELEGATES_GPU_CL_KERNELS_SPECIAL_SRCS - FILTER "(_test)\\.(cc|h)$" - ) - populate_tflite_source_vars( - "delegates/gpu/cl/selectors" TFLITE_DELEGATES_GPU_CL_SELECTORS_SRCS - FILTER "(_test)\\.(cc|h)$" - ) - populate_tflite_source_vars( - "delegates/gpu/cl/selectors/default" TFLITE_DELEGATES_GPU_CL_SELECTORS_DEFAULT_SRCS - FILTER "(_test)\\.(cc|h)$" - ) - populate_tflite_source_vars( - "delegates/gpu/common" TFLITE_DELEGATES_GPU_COMMON_SRCS - FILTER "(_test)\\.(cc|h)$" - ) populate_tflite_source_vars( "delegates/gpu/common/default" TFLITE_DELEGATES_GPU_COMMON_DEFAULT_SRCS FILTER "(_test)\\.(cc|h)$" @@ -242,6 +225,18 @@ if(TFLITE_ENABLE_GPU) TFLITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_SRCS FILTER "(_test)\\.(cc|h)$" ) + populate_tflite_source_vars( + "delegates/gpu/common/selectors" TFLITE_DELEGATES_GPU_COMMON_SELECTORS_SRCS + FILTER "(_test)\\.(cc|h)$" + ) + populate_tflite_source_vars( + "delegates/gpu/common/selectors/default" TFLITE_DELEGATES_GPU_COMMON_SELECTORS_DEFAULT_SRCS + FILTER "(_test)\\.(cc|h)$" + ) + populate_tflite_source_vars( + "delegates/gpu/common" TFLITE_DELEGATES_GPU_COMMON_SRCS + FILTER "(_test)\\.(cc|h)$" + ) populate_tflite_source_vars( "delegates/gpu/common/task" TFLITE_DELEGATES_GPU_COMMON_TASK_SRCS @@ -267,12 +262,11 @@ if(TFLITE_ENABLE_GPU) ${TFLITE_SOURCE_DIR}/delegates/gpu/delegate.cc ${TFLITE_DELEGATES_GPU_CL_SRCS} ${TFLITE_DELEGATES_GPU_CL_KERNELS_SRCS} - ${TFLITE_DELEGATES_GPU_CL_KERNELS_SPECIAL_SRCS} - ${TFLITE_DELEGATES_GPU_CL_SELECTORS_SRCS} - ${TFLITE_DELEGATES_GPU_CL_SELECTORS_DEFAULT_SRCS} - ${TFLITE_DELEGATES_GPU_COMMON_SRCS} ${TFLITE_DELEGATES_GPU_COMMON_DEFAULT_SRCS} ${TFLITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_SRCS} + ${TFLITE_DELEGATES_GPU_COMMON_SELECTORS_SRCS} + ${TFLITE_DELEGATES_GPU_COMMON_SELECTORS_DEFAULT_SRCS} + ${TFLITE_DELEGATES_GPU_COMMON_SRCS} ${TFLITE_DELEGATES_GPU_COMMON_TASK_SRCS} ${TFLITE_DELEGATES_GPU_COMMON_TASKS_SRCS} ${TFLITE_DELEGATES_GPU_COMMON_TASKS_SPECIAL_SRCS} diff --git a/tensorflow/lite/build_def.bzl b/tensorflow/lite/build_def.bzl index 9f729d99994..5895b7808de 100644 --- a/tensorflow/lite/build_def.bzl +++ b/tensorflow/lite/build_def.bzl @@ -800,7 +800,7 @@ def tflite_custom_cc_library( model = models, ) real_srcs.append(":%s_registration" % name) - real_deps.append("//tensorflow/lite/java/src/main/native:selected_ops_jni") + real_deps.append("//tensorflow/lite:create_op_resolver_with_selected_ops") else: # Support all operators if `models` not specified. real_deps.append("//tensorflow/lite/java/src/main/native") @@ -810,7 +810,7 @@ def tflite_custom_cc_library( srcs = real_srcs, hdrs = [ # TODO(b/161323860) replace this by generated header. 
- "//tensorflow/lite/java/src/main/native:op_resolver.h", + "//tensorflow/lite:create_op_resolver.h", ], copts = tflite_copts(), linkopts = select({ diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index a4bfc77a614..14000f93cd1 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -201,6 +201,10 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type, return ParseDequantize(op, error_reporter, allocator, builtin_data); } + case BuiltinOperator_EXP: { + return ParseExp(op, error_reporter, allocator, builtin_data); + } + case BuiltinOperator_FILL: { return ParseFill(op, error_reporter, allocator, builtin_data); } @@ -796,7 +800,6 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_ELU: case BuiltinOperator_EMBEDDING_LOOKUP: case BuiltinOperator_EQUAL: - case BuiltinOperator_EXP: case BuiltinOperator_EXPAND_DIMS: case BuiltinOperator_LOG_SOFTMAX: case BuiltinOperator_MATRIX_DIAG: @@ -1097,6 +1100,14 @@ TfLiteStatus ParseEqual(const Operator*, ErrorReporter*, BuiltinDataAllocator*, return kTfLiteOk; } +// We have this parse function instead of directly returning kTfLiteOk from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +TfLiteStatus ParseExp(const Operator*, ErrorReporter*, BuiltinDataAllocator*, + void**) { + return kTfLiteOk; +} + // We have this parse function instead of directly returning kTfLiteOk from the // switch-case in ParseOpData because this function is used as part of the // selective registration for the OpResolver implementation in micro. diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.h b/tensorflow/lite/core/api/flatbuffer_conversions.h index 2540d3524aa..33956624373 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.h +++ b/tensorflow/lite/core/api/flatbuffer_conversions.h @@ -107,6 +107,9 @@ TfLiteStatus ParseDequantize(const Operator* op, ErrorReporter* error_reporter, TfLiteStatus ParseEqual(const Operator* op, ErrorReporter* error_reporter, BuiltinDataAllocator* allocator, void** builtin_data); +TfLiteStatus ParseExp(const Operator* op, ErrorReporter* error_reporter, + BuiltinDataAllocator* allocator, void** builtin_data); + TfLiteStatus ParseFill(const Operator* op, ErrorReporter* error_reporter, BuiltinDataAllocator* allocator, void** builtin_data); diff --git a/tensorflow/lite/java/src/main/native/op_resolver.h b/tensorflow/lite/create_op_resolver.h similarity index 81% rename from tensorflow/lite/java/src/main/native/op_resolver.h rename to tensorflow/lite/create_op_resolver.h index 08ff0cec29e..b8736e4c6bc 100644 --- a/tensorflow/lite/java/src/main/native/op_resolver.h +++ b/tensorflow/lite/create_op_resolver.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_OP_RESOLVER_H_ -#define TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_OP_RESOLVER_H_ +#ifndef TENSORFLOW_LITE_CREATE_OP_RESOLVER_H_ +#define TENSORFLOW_LITE_CREATE_OP_RESOLVER_H_ #include @@ -25,4 +25,4 @@ std::unique_ptr CreateOpResolver(); } -#endif // TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_OP_RESOLVER_H_ +#endif // TENSORFLOW_LITE_CREATE_OP_RESOLVER_H_ diff --git a/tensorflow/lite/java/src/main/native/builtin_ops_jni.cc b/tensorflow/lite/create_op_resolver_with_builtin_ops.cc similarity index 68% rename from tensorflow/lite/java/src/main/native/builtin_ops_jni.cc rename to tensorflow/lite/create_op_resolver_with_builtin_ops.cc index eb17fdcf2a5..5801fad369b 100644 --- a/tensorflow/lite/java/src/main/native/builtin_ops_jni.cc +++ b/tensorflow/lite/create_op_resolver_with_builtin_ops.cc @@ -15,17 +15,16 @@ limitations under the License. #include -#include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow/lite/create_op_resolver.h" #include "tensorflow/lite/kernels/register.h" namespace tflite { -// The JNI code in interpreter_jni.cc expects a CreateOpResolver() function in -// the tflite namespace. This one instantiates a -// BuiltinOpResolverWithoutDefaultDelegates, with all the builtin ops but -// without applying any TfLite delegates by default (like the XNNPACK delegate). -// For smaller binary sizes users should avoid linking this in, and should -// provide a custom make CreateOpResolver() instead. +// This function instantiates a BuiltinOpResolverWithoutDefaultDelegates, with +// all the builtin ops but without applying any TfLite delegates by default +// (like the XNNPACK delegate). For smaller binary sizes users should avoid +// linking this in, and should provide a CreateOpResolver() with selected ops +// instead. std::unique_ptr CreateOpResolver() { // NOLINT return std::unique_ptr( new tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates()); diff --git a/tensorflow/lite/java/src/main/native/selected_ops_jni.cc b/tensorflow/lite/create_op_resolver_with_selected_ops.cc similarity index 91% rename from tensorflow/lite/java/src/main/native/selected_ops_jni.cc rename to tensorflow/lite/create_op_resolver_with_selected_ops.cc index d8eb233f90a..5c7d95c6382 100644 --- a/tensorflow/lite/java/src/main/native/selected_ops_jni.cc +++ b/tensorflow/lite/create_op_resolver_with_selected_ops.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/java/src/main/native/op_resolver.h" +#include "tensorflow/lite/create_op_resolver.h" #include "tensorflow/lite/mutable_op_resolver.h" // This method is generated by `gen_selected_ops`. -// TODO(b/153652701): Instead of relying on a global method, make +// TODO(b/174972014): Instead of relying on a global method, make // `gen_selected_ops` generating a header file with custom namespace. 
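
The renamed create_op_resolver_with_selected_ops.cc keeps the CreateOpResolver() factory but backs it with the generated RegisterSelectedOps() hook rather than the full builtin registry. Its body is not shown in this hunk; the sketch below is an assumption of its minimal shape (including the exact return type, which follows whatever create_op_resolver.h declares).

#include <memory>

#include "tensorflow/lite/mutable_op_resolver.h"

// Hook generated by `gen_selected_ops`; its definition registers only the ops
// used by the models passed to tflite_custom_cc_library.
void RegisterSelectedOps(::tflite::MutableOpResolver* resolver);

namespace tflite {

// Assumed shape of the factory declared in create_op_resolver.h: wrap the
// generated registration hook in a MutableOpResolver so callers get the same
// CreateOpResolver() entry point as the builtin-ops variant.
std::unique_ptr<MutableOpResolver> CreateOpResolver() {  // NOLINT
  std::unique_ptr<MutableOpResolver> resolver(new MutableOpResolver());
  RegisterSelectedOps(resolver.get());
  return resolver;
}

}  // namespace tflite

Either variant satisfies the same link-time contract, so client code (such as the Java interpreter JNI mentioned above) keeps calling tflite::CreateOpResolver(), and the chosen cc_library dependency alone decides which ops end up registered.
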
void RegisterSelectedOps(::tflite::MutableOpResolver* resolver); diff --git a/tensorflow/lite/delegates/gpu/BUILD b/tensorflow/lite/delegates/gpu/BUILD index 069230ebcf6..2c0080c0bfc 100644 --- a/tensorflow/lite/delegates/gpu/BUILD +++ b/tensorflow/lite/delegates/gpu/BUILD @@ -92,6 +92,7 @@ objc_library( "//tensorflow/lite/delegates/gpu/common:model", "//tensorflow/lite/delegates/gpu/common:model_builder", "//tensorflow/lite/delegates/gpu/common:model_transformer", + "//tensorflow/lite/delegates/gpu/common:precision", "//tensorflow/lite/delegates/gpu/common:quantization_util", "//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common:status", diff --git a/tensorflow/lite/delegates/gpu/cl/BUILD b/tensorflow/lite/delegates/gpu/cl/BUILD index 989a1632536..65ea74e4ef3 100644 --- a/tensorflow/lite/delegates/gpu/cl/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/BUILD @@ -395,8 +395,6 @@ cc_library( ":opencl_wrapper", ":serialization_cc_fbs", ":tensor", - "//tensorflow/lite/delegates/gpu/cl/selectors:operation_selector", - "//tensorflow/lite/delegates/gpu/cl/selectors:special_selector", "//tensorflow/lite/delegates/gpu/common:data_type", "//tensorflow/lite/delegates/gpu/common:memory_management", "//tensorflow/lite/delegates/gpu/common:model", @@ -409,6 +407,8 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:tensor", "//tensorflow/lite/delegates/gpu/common:types", "//tensorflow/lite/delegates/gpu/common:util", + "//tensorflow/lite/delegates/gpu/common/selectors:operation_selector", + "//tensorflow/lite/delegates/gpu/common/selectors:special_selector", "//tensorflow/lite/delegates/gpu/common/task:arguments", "//tensorflow/lite/delegates/gpu/common/task:buffer_desc", "//tensorflow/lite/delegates/gpu/common/task:gpu_object_desc", diff --git a/tensorflow/lite/delegates/gpu/cl/inference_context.cc b/tensorflow/lite/delegates/gpu/cl/inference_context.cc index f24eeb27b8a..2b966bedc9c 100644 --- a/tensorflow/lite/delegates/gpu/cl/inference_context.cc +++ b/tensorflow/lite/delegates/gpu/cl/inference_context.cc @@ -27,14 +27,14 @@ limitations under the License. 
#include "absl/container/flat_hash_set.h" #include "tensorflow/lite/delegates/gpu/cl/buffer.h" #include "tensorflow/lite/delegates/gpu/cl/cl_device.h" -#include "tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h" -#include "tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/memory_management.h" #include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/model_transformer.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/precision.h" +#include "tensorflow/lite/delegates/gpu/common/selectors/operation_selector.h" +#include "tensorflow/lite/delegates/gpu/common/selectors/special_selector.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/task/storage_type_util.h" diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc index 97eb0751328..5fca6297171 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc @@ -1534,10 +1534,10 @@ class ReduceOperationParser : public TFLiteOperationParser { ReduceAttributes attr; Tensor axes; RETURN_IF_ERROR(reader->ReadTensor(1, &axes)); - const TfLiteTensor* output = reader->GetOutputTensor(0); + const TfLiteTensor* input = reader->GetInputTensor(0); for (int i = 0; i < axes.data.size(); i++) { Axis axis; - RETURN_IF_ERROR(ExtractAxisFromIndex(*output, axes.data[i], &axis)); + RETURN_IF_ERROR(ExtractAxisFromIndex(*input, axes.data[i], &axis)); attr.dims.insert(axis); } node->operation.attributes = attr; @@ -2615,10 +2615,10 @@ class MeanOperationParser : public TFLiteOperationParser { MeanAttributes attr; Tensor axes; RETURN_IF_ERROR(reader->ReadTensor(1, &axes)); - const TfLiteTensor* output = reader->GetOutputTensor(0); + const TfLiteTensor* input = reader->GetInputTensor(0); for (int i = 0; i < axes.data.size(); i++) { Axis axis; - RETURN_IF_ERROR(ExtractAxisFromIndex(*output, axes.data[i], &axis)); + RETURN_IF_ERROR(ExtractAxisFromIndex(*input, axes.data[i], &axis)); attr.dims.insert(axis); } node->operation.attributes = attr; diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/BUILD b/tensorflow/lite/delegates/gpu/common/selectors/BUILD similarity index 89% rename from tensorflow/lite/delegates/gpu/cl/selectors/BUILD rename to tensorflow/lite/delegates/gpu/common/selectors/BUILD index 8be72284312..0198d621839 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/BUILD +++ b/tensorflow/lite/delegates/gpu/common/selectors/BUILD @@ -7,11 +7,11 @@ cc_library( name = "convolution_selector", hdrs = ["convolution_selector.h"], deps = [ - "//tensorflow/lite/delegates/gpu/cl/selectors/default:convolution_selector", # buildcleaner: keep "//tensorflow/lite/delegates/gpu/common:model_hints", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common/selectors/default:convolution_selector", # buildcleaner: keep "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", "//tensorflow/lite/delegates/gpu/common/task:weights_layout", ], @@ -21,9 +21,9 @@ cc_library( name = "convolution_transposed_selector", hdrs = 
["convolution_transposed_selector.h"], deps = [ - "//tensorflow/lite/delegates/gpu/cl/selectors/default:convolution_transposed_selector", # buildcleaner: keep "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common/selectors/default:convolution_transposed_selector", # buildcleaner: keep "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", "//tensorflow/lite/delegates/gpu/common/task:weights_layout", "@com_google_absl//absl/memory", @@ -35,10 +35,10 @@ cc_library( hdrs = ["default_selector.h"], deps = [ ":subgraph", - "//tensorflow/lite/delegates/gpu/cl/selectors/default:default_selector", # buildcleaner: keep "//tensorflow/lite/delegates/gpu/common:model", "//tensorflow/lite/delegates/gpu/common:model_hints", "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common/selectors/default:default_selector", # buildcleaner: keep "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", "//tensorflow/lite/delegates/gpu/common/task:tensor_desc", ], @@ -48,9 +48,9 @@ cc_library( name = "dw_convolution_selector", hdrs = ["dw_convolution_selector.h"], deps = [ - "//tensorflow/lite/delegates/gpu/cl/selectors/default:dw_convolution_selector", # buildcleaner: keep "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common/selectors/default:dw_convolution_selector", # buildcleaner: keep "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", "@com_google_absl//absl/memory", ], @@ -60,9 +60,9 @@ cc_library( name = "fully_connected_selector", hdrs = ["fully_connected_selector.h"], deps = [ - "//tensorflow/lite/delegates/gpu/cl/selectors/default:fully_connected_selector", # buildcleaner: keep "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common/selectors/default:fully_connected_selector", # buildcleaner: keep "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", "@com_google_absl//absl/memory", ], @@ -80,8 +80,8 @@ cc_library( ":fully_connected_selector", ":simple_selectors", ":subgraph", - "//tensorflow/lite/delegates/gpu/cl:cl_device", "//tensorflow/lite/delegates/gpu/common:data_type", + "//tensorflow/lite/delegates/gpu/common:gpu_info", "//tensorflow/lite/delegates/gpu/common:model", "//tensorflow/lite/delegates/gpu/common:model_hints", "//tensorflow/lite/delegates/gpu/common:operations", @@ -105,7 +105,7 @@ cc_library( srcs = ["simple_selectors.cc"], hdrs = ["simple_selectors.h"], deps = [ - "//tensorflow/lite/delegates/gpu/cl:cl_device", + "//tensorflow/lite/delegates/gpu/common:gpu_info", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common:status", @@ -141,8 +141,8 @@ cc_library( hdrs = ["special_selector.h"], deps = [ ":subgraph", - "//tensorflow/lite/delegates/gpu/cl:cl_device", "//tensorflow/lite/delegates/gpu/common:data_type", + "//tensorflow/lite/delegates/gpu/common:gpu_info", "//tensorflow/lite/delegates/gpu/common:model", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:shape", diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h b/tensorflow/lite/delegates/gpu/common/selectors/convolution_selector.h similarity index 88% rename from tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h rename to 
tensorflow/lite/delegates/gpu/common/selectors/convolution_selector.h index cef1e014217..0fa11f53d13 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h +++ b/tensorflow/lite/delegates/gpu/common/selectors/convolution_selector.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_CONVOLUTION_SELECTOR_H_ -#define TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_CONVOLUTION_SELECTOR_H_ +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_CONVOLUTION_SELECTOR_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_CONVOLUTION_SELECTOR_H_ #include @@ -27,7 +27,6 @@ limitations under the License. namespace tflite { namespace gpu { -namespace cl { std::unique_ptr SelectConvolution( const Convolution2DAttributes& attr, const BHWC& dst_shape, @@ -46,8 +45,7 @@ std::unique_ptr SelectConverterToConvWeights( const WeightsDescription& weights_desc, const OperationDef& op_def, ModelHints hints); -} // namespace cl } // namespace gpu } // namespace tflite -#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_CONVOLUTION_SELECTOR_H_ +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_CONVOLUTION_SELECTOR_H_ diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.h b/tensorflow/lite/delegates/gpu/common/selectors/convolution_transposed_selector.h similarity index 82% rename from tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.h rename to tensorflow/lite/delegates/gpu/common/selectors/convolution_transposed_selector.h index 4a2a6d9645f..5c94b898848 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.h +++ b/tensorflow/lite/delegates/gpu/common/selectors/convolution_transposed_selector.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_CONVOLUTION_TRANSPOSED_SELECTOR_H_ -#define TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_CONVOLUTION_TRANSPOSED_SELECTOR_H_ +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_CONVOLUTION_TRANSPOSED_SELECTOR_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_CONVOLUTION_TRANSPOSED_SELECTOR_H_ #include @@ -25,7 +25,6 @@ limitations under the License. 
namespace tflite { namespace gpu { -namespace cl { std::unique_ptr SelectConvolutionTransposed( const ConvolutionTransposedAttributes& attr, const GpuInfo& gpu_info, @@ -35,8 +34,7 @@ std::unique_ptr SelectConvolutionTransposedWithDynamicWeights( const ConvolutionTransposedAttributes& attr, const GpuInfo& gpu_info, const OperationDef& op_def, WeightsDescription* weights_desc); -} // namespace cl } // namespace gpu } // namespace tflite -#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_CONVOLUTION_TRANSPOSED_SELECTOR_H_ +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_CONVOLUTION_TRANSPOSED_SELECTOR_H_ diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/default/BUILD b/tensorflow/lite/delegates/gpu/common/selectors/default/BUILD similarity index 96% rename from tensorflow/lite/delegates/gpu/cl/selectors/default/BUILD rename to tensorflow/lite/delegates/gpu/common/selectors/default/BUILD index 33edadf1900..0bcc41c5da2 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/default/BUILD +++ b/tensorflow/lite/delegates/gpu/common/selectors/default/BUILD @@ -46,11 +46,11 @@ cc_library( name = "default_selector", srcs = ["default_selector.cc"], deps = [ - "//tensorflow/lite/delegates/gpu/cl/selectors:subgraph", "//tensorflow/lite/delegates/gpu/common:model", "//tensorflow/lite/delegates/gpu/common:model_hints", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common/selectors:subgraph", "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", "@com_google_absl//absl/strings", ], @@ -60,7 +60,7 @@ cc_library( name = "dw_convolution_selector", srcs = ["dw_convolution_selector.cc"], deps = [ - "//tensorflow/lite/delegates/gpu/cl:cl_device", + "//tensorflow/lite/delegates/gpu/common:gpu_info", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/default/convolution_selector.cc b/tensorflow/lite/delegates/gpu/common/selectors/default/convolution_selector.cc similarity index 99% rename from tensorflow/lite/delegates/gpu/cl/selectors/default/convolution_selector.cc rename to tensorflow/lite/delegates/gpu/common/selectors/default/convolution_selector.cc index f76fc563608..9f0fdb56a69 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/default/convolution_selector.cc +++ b/tensorflow/lite/delegates/gpu/common/selectors/default/convolution_selector.cc @@ -30,7 +30,6 @@ limitations under the License. 
namespace tflite { namespace gpu { -namespace cl { namespace { std::unique_ptr SelectConvolutionAdreno( @@ -201,6 +200,5 @@ std::unique_ptr SelectConverterToConvWeights( return absl::make_unique(std::move(converter)); } -} // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/default/convolution_transposed_selector.cc b/tensorflow/lite/delegates/gpu/common/selectors/default/convolution_transposed_selector.cc similarity index 99% rename from tensorflow/lite/delegates/gpu/cl/selectors/default/convolution_transposed_selector.cc rename to tensorflow/lite/delegates/gpu/common/selectors/default/convolution_transposed_selector.cc index e33d8488320..d4205ed1c1b 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/default/convolution_transposed_selector.cc +++ b/tensorflow/lite/delegates/gpu/common/selectors/default/convolution_transposed_selector.cc @@ -23,7 +23,6 @@ limitations under the License. namespace tflite { namespace gpu { -namespace cl { namespace { std::unique_ptr SelectConvolutionTransposedAdreno( @@ -142,6 +141,5 @@ std::unique_ptr SelectConvolutionTransposedWithDynamicWeights( } } -} // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/default/default_selector.cc b/tensorflow/lite/delegates/gpu/common/selectors/default/default_selector.cc similarity index 93% rename from tensorflow/lite/delegates/gpu/cl/selectors/default/default_selector.cc rename to tensorflow/lite/delegates/gpu/common/selectors/default/default_selector.cc index a7d94fabf43..222393872eb 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/default/default_selector.cc +++ b/tensorflow/lite/delegates/gpu/common/selectors/default/default_selector.cc @@ -16,16 +16,15 @@ limitations under the License. #include #include "absl/strings/str_cat.h" -#include "tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h" #include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/model_hints.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/selectors/subgraph.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" namespace tflite { namespace gpu { -namespace cl { absl::Status SelectDefault(const GpuInfo& gpu_info, const OperationDef& op_def, ModelHints hints, const std::vector& inputs, @@ -35,6 +34,5 @@ absl::Status SelectDefault(const GpuInfo& gpu_info, const OperationDef& op_def, absl::StrCat("No selector for ", node.operation.type)); } -} // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/default/dw_convolution_selector.cc b/tensorflow/lite/delegates/gpu/common/selectors/default/dw_convolution_selector.cc similarity index 97% rename from tensorflow/lite/delegates/gpu/cl/selectors/default/dw_convolution_selector.cc rename to tensorflow/lite/delegates/gpu/common/selectors/default/dw_convolution_selector.cc index 968d0614fd8..07b11134618 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/default/dw_convolution_selector.cc +++ b/tensorflow/lite/delegates/gpu/common/selectors/default/dw_convolution_selector.cc @@ -14,13 +14,12 @@ limitations under the License. 
==============================================================================*/ #include "absl/memory/memory.h" -#include "tensorflow/lite/delegates/gpu/cl/cl_device.h" +#include "tensorflow/lite/delegates/gpu/common/gpu_info.h" #include "tensorflow/lite/delegates/gpu/common/tasks/depthwise_conv.h" #include "tensorflow/lite/delegates/gpu/common/tasks/depthwise_conv_3x3.h" namespace tflite { namespace gpu { -namespace cl { namespace { std::unique_ptr SelectDWConvolutionAdreno( @@ -79,6 +78,5 @@ std::unique_ptr SelectDWConvolution( } } -} // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/default/fully_connected_selector.cc b/tensorflow/lite/delegates/gpu/common/selectors/default/fully_connected_selector.cc similarity index 99% rename from tensorflow/lite/delegates/gpu/cl/selectors/default/fully_connected_selector.cc rename to tensorflow/lite/delegates/gpu/common/selectors/default/fully_connected_selector.cc index 43f6a2613d2..863409393f2 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/default/fully_connected_selector.cc +++ b/tensorflow/lite/delegates/gpu/common/selectors/default/fully_connected_selector.cc @@ -22,7 +22,6 @@ limitations under the License. namespace tflite { namespace gpu { -namespace cl { std::unique_ptr SelectFullyConnectedGeneric( const FullyConnectedAttributes& attr, const GpuInfo& gpu_info, @@ -96,6 +95,5 @@ std::unique_ptr SelectFullyConnected( } } -} // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/default_selector.h b/tensorflow/lite/delegates/gpu/common/selectors/default_selector.h similarity index 81% rename from tensorflow/lite/delegates/gpu/cl/selectors/default_selector.h rename to tensorflow/lite/delegates/gpu/common/selectors/default_selector.h index 1efa215e602..c6f7758cdc2 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/default_selector.h +++ b/tensorflow/lite/delegates/gpu/common/selectors/default_selector.h @@ -13,29 +13,27 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_DEFAULT_SELECTOR_H_ -#define TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_DEFAULT_SELECTOR_H_ +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_DEFAULT_SELECTOR_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_DEFAULT_SELECTOR_H_ #include -#include "tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h" #include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/model_hints.h" +#include "tensorflow/lite/delegates/gpu/common/selectors/subgraph.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h" namespace tflite { namespace gpu { -namespace cl { absl::Status SelectDefault(const GpuInfo& gpu_info, const OperationDef& op_def, ModelHints hints, const std::vector& inputs, const std::vector& outputs, const Node& node, GPUOperationsSubgraph* gpu_subgraph); -} // namespace cl } // namespace gpu } // namespace tflite -#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_DEFAULT_SELECTOR_H_ +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_DEFAULT_SELECTOR_H_ diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.h b/tensorflow/lite/delegates/gpu/common/selectors/dw_convolution_selector.h similarity index 80% rename from tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.h rename to tensorflow/lite/delegates/gpu/common/selectors/dw_convolution_selector.h index 0c920984bc1..f3e50a9c665 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.h +++ b/tensorflow/lite/delegates/gpu/common/selectors/dw_convolution_selector.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_DW_CONVOLUTION_SELECTOR_H_ -#define TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_DW_CONVOLUTION_SELECTOR_H_ +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_DW_CONVOLUTION_SELECTOR_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_DW_CONVOLUTION_SELECTOR_H_ #include @@ -24,14 +24,12 @@ limitations under the License. namespace tflite { namespace gpu { -namespace cl { std::unique_ptr SelectDWConvolution( const DepthwiseConvolution2DAttributes& attr, const GpuInfo& gpu_info, const OperationDef& op_def); -} // namespace cl } // namespace gpu } // namespace tflite -#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_DW_CONVOLUTION_SELECTOR_H_ +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_DW_CONVOLUTION_SELECTOR_H_ diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.h b/tensorflow/lite/delegates/gpu/common/selectors/fully_connected_selector.h similarity index 80% rename from tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.h rename to tensorflow/lite/delegates/gpu/common/selectors/fully_connected_selector.h index 5b1563a9351..e2e910e4d12 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.h +++ b/tensorflow/lite/delegates/gpu/common/selectors/fully_connected_selector.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_FULLY_CONNECTED_SELECTOR_H_ -#define TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_FULLY_CONNECTED_SELECTOR_H_ +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_FULLY_CONNECTED_SELECTOR_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_FULLY_CONNECTED_SELECTOR_H_ #include @@ -24,14 +24,12 @@ limitations under the License. namespace tflite { namespace gpu { -namespace cl { std::unique_ptr SelectFullyConnected( const FullyConnectedAttributes& attr, const GpuInfo& gpu_info, const OperationDef& op_def, int batch_size); -} // namespace cl } // namespace gpu } // namespace tflite -#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_FULLY_CONNECTED_SELECTOR_H_ +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_FULLY_CONNECTED_SELECTOR_H_ diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc b/tensorflow/lite/delegates/gpu/common/selectors/operation_selector.cc similarity index 97% rename from tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc rename to tensorflow/lite/delegates/gpu/common/selectors/operation_selector.cc index 3fca1c64f17..41c6937fb2b 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc +++ b/tensorflow/lite/delegates/gpu/common/selectors/operation_selector.cc @@ -13,19 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h" +#include "tensorflow/lite/delegates/gpu/common/selectors/operation_selector.h" #include "absl/strings/str_cat.h" #include "absl/types/any.h" -#include "tensorflow/lite/delegates/gpu/cl/cl_device.h" -#include "tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h" -#include "tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.h" -#include "tensorflow/lite/delegates/gpu/cl/selectors/default_selector.h" -#include "tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.h" -#include "tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.h" -#include "tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/gpu_info.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/selectors/convolution_selector.h" +#include "tensorflow/lite/delegates/gpu/common/selectors/convolution_transposed_selector.h" +#include "tensorflow/lite/delegates/gpu/common/selectors/default_selector.h" +#include "tensorflow/lite/delegates/gpu/common/selectors/dw_convolution_selector.h" +#include "tensorflow/lite/delegates/gpu/common/selectors/fully_connected_selector.h" +#include "tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/task/storage_type_util.h" @@ -38,7 +38,6 @@ limitations under the License. 
namespace tflite { namespace gpu { -namespace cl { namespace { bool IsRecommendedForWinograd4x4To6x6(const Convolution2DAttributes& attr, const GpuInfo& gpu_info, @@ -530,6 +529,5 @@ absl::Status GPUOperationFromNode(const GpuInfo& gpu_info, } } -} // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h b/tensorflow/lite/delegates/gpu/common/selectors/operation_selector.h similarity index 82% rename from tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h rename to tensorflow/lite/delegates/gpu/common/selectors/operation_selector.h index b81bdaa0506..dfffa9be0a3 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h +++ b/tensorflow/lite/delegates/gpu/common/selectors/operation_selector.h @@ -13,21 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_OPERATION_SELECTOR_H_ -#define TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_OPERATION_SELECTOR_H_ +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_OPERATION_SELECTOR_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_OPERATION_SELECTOR_H_ #include -#include "tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h" #include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/model_hints.h" +#include "tensorflow/lite/delegates/gpu/common/selectors/subgraph.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h" namespace tflite { namespace gpu { -namespace cl { absl::Status GPUOperationFromNode(const GpuInfo& gpu_info, const OperationDef& op_def, ModelHints hints, @@ -36,8 +35,7 @@ absl::Status GPUOperationFromNode(const GpuInfo& gpu_info, const Node& node, GPUOperationsSubgraph* gpu_subgraph); -} // namespace cl } // namespace gpu } // namespace tflite -#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_OPERATION_SELECTOR_H_ +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_OPERATION_SELECTOR_H_ diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc b/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.cc similarity index 98% rename from tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc rename to tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.cc index 48bb8fc167e..6f7baefab92 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc +++ b/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h" +#include "tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.h" #include #include @@ -44,7 +44,6 @@ limitations under the License. 
namespace tflite { namespace gpu { -namespace cl { std::unique_ptr SelectLSTM(const OperationDef& op_def, const GpuInfo& gpu_info) { @@ -194,6 +193,5 @@ std::unique_ptr SelectQuantizeAndDequantize( CreateQuantizeAndDequantize(op_def, attr)); } -} // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h b/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.h similarity index 93% rename from tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h rename to tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.h index 52c23102dd9..4f757a88477 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h +++ b/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_SIMPLE_SELECTORS_H_ -#define TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_SIMPLE_SELECTORS_H_ +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_SIMPLE_SELECTORS_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_SIMPLE_SELECTORS_H_ #include -#include "tensorflow/lite/delegates/gpu/cl/cl_device.h" +#include "tensorflow/lite/delegates/gpu/common/gpu_info.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" @@ -26,7 +26,6 @@ limitations under the License. namespace tflite { namespace gpu { -namespace cl { std::unique_ptr SelectLSTM(const OperationDef& op_def, const GpuInfo& gpu_info); @@ -98,8 +97,7 @@ std::unique_ptr SelectWinograd36To4x4( std::unique_ptr SelectQuantizeAndDequantize( const QuantizeAndDequantizeAttributes& attr, const OperationDef& op_def); -} // namespace cl } // namespace gpu } // namespace tflite -#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_SIMPLE_SELECTORS_H_ +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_SIMPLE_SELECTORS_H_ diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.cc b/tensorflow/lite/delegates/gpu/common/selectors/special_selector.cc similarity index 98% rename from tensorflow/lite/delegates/gpu/cl/selectors/special_selector.cc rename to tensorflow/lite/delegates/gpu/common/selectors/special_selector.cc index 6d5300da48b..14160018886 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.cc +++ b/tensorflow/lite/delegates/gpu/common/selectors/special_selector.cc @@ -13,10 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h" +#include "tensorflow/lite/delegates/gpu/common/selectors/special_selector.h" #include "absl/types/any.h" -#include "tensorflow/lite/delegates/gpu/cl/cl_device.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" @@ -28,7 +27,6 @@ limitations under the License. 
namespace tflite { namespace gpu { -namespace cl { namespace { absl::Status TryDepthwiseConvPlus1x1Conv( CalculationsPrecision precision, const GraphFloat32& graph, @@ -208,6 +206,5 @@ absl::Status GPUSubgraphFromGraph( return absl::NotFoundError("No special combination."); } -} // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h b/tensorflow/lite/delegates/gpu/common/selectors/special_selector.h similarity index 79% rename from tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h rename to tensorflow/lite/delegates/gpu/common/selectors/special_selector.h index aecd0a0a519..fc33d51058e 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h +++ b/tensorflow/lite/delegates/gpu/common/selectors/special_selector.h @@ -13,22 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_SPECIAL_SELECTOR_H_ -#define TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_SPECIAL_SELECTOR_H_ +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_SPECIAL_SELECTOR_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_SPECIAL_SELECTOR_H_ #include #include #include -#include "tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h" +#include "tensorflow/lite/delegates/gpu/common/gpu_info.h" #include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/selectors/subgraph.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h" namespace tflite { namespace gpu { -namespace cl { absl::Status GPUSubgraphFromGraph( const GpuInfo& gpu_info, CalculationsPrecision precision, @@ -37,8 +37,7 @@ absl::Status GPUSubgraphFromGraph( std::set* consumed_nodes, GPUOperationsSubgraph* gpu_subgraph, std::string* name); -} // namespace cl } // namespace gpu } // namespace tflite -#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_SPECIAL_SELECTOR_H_ +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_SPECIAL_SELECTOR_H_ diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/subgraph.cc b/tensorflow/lite/delegates/gpu/common/selectors/subgraph.cc similarity index 93% rename from tensorflow/lite/delegates/gpu/cl/selectors/subgraph.cc rename to tensorflow/lite/delegates/gpu/common/selectors/subgraph.cc index cd3c987ccaf..5bb11b6081f 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/subgraph.cc +++ b/tensorflow/lite/delegates/gpu/common/selectors/subgraph.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h" +#include "tensorflow/lite/delegates/gpu/common/selectors/subgraph.h" #include @@ -23,7 +23,6 @@ limitations under the License. 
namespace tflite { namespace gpu { -namespace cl { std::unique_ptr* InitSingleOpSubgraph( const std::vector& inputs, const std::vector& outputs, @@ -41,6 +40,5 @@ std::unique_ptr* InitSingleOpSubgraph( return &gpu_subgraph->operations[0].operation; } -} // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h b/tensorflow/lite/delegates/gpu/common/selectors/subgraph.h similarity index 87% rename from tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h rename to tensorflow/lite/delegates/gpu/common/selectors/subgraph.h index f94e0c430a3..243e2d4b496 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h +++ b/tensorflow/lite/delegates/gpu/common/selectors/subgraph.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_SUBGRAPH_H_ -#define TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_SUBGRAPH_H_ +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_SUBGRAPH_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_SUBGRAPH_H_ #include #include @@ -25,7 +25,6 @@ limitations under the License. namespace tflite { namespace gpu { -namespace cl { struct GPUOperationWithRefs { std::unique_ptr operation; @@ -46,8 +45,7 @@ std::unique_ptr* InitSingleOpSubgraph( const std::vector& inputs, const std::vector& outputs, GPUOperationsSubgraph* gpu_subgraph); -} // namespace cl } // namespace gpu } // namespace tflite -#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_SUBGRAPH_H_ +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SELECTORS_SUBGRAPH_H_ diff --git a/tensorflow/lite/delegates/gpu/common/tasks/special/depthwise_conv_plus_1x1_conv.cc b/tensorflow/lite/delegates/gpu/common/tasks/special/depthwise_conv_plus_1x1_conv.cc index 7352e7fb64b..32424a96f72 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/special/depthwise_conv_plus_1x1_conv.cc +++ b/tensorflow/lite/delegates/gpu/common/tasks/special/depthwise_conv_plus_1x1_conv.cc @@ -23,7 +23,6 @@ limitations under the License. namespace tflite { namespace gpu { -namespace cl { namespace { void UploadWeights(const DepthwiseConvolution2DAttributes& dw_attr, const Convolution2DAttributes& conv_attr, @@ -264,6 +263,5 @@ GPUOperation CreateDepthwiseConvPlus1x1Conv( return result; } -} // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/tasks/special/depthwise_conv_plus_1x1_conv.h b/tensorflow/lite/delegates/gpu/common/tasks/special/depthwise_conv_plus_1x1_conv.h index 93ef127bd52..c8912618a6b 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/special/depthwise_conv_plus_1x1_conv.h +++ b/tensorflow/lite/delegates/gpu/common/tasks/special/depthwise_conv_plus_1x1_conv.h @@ -30,7 +30,6 @@ limitations under the License. 
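Note on the selector move above: for code that calls these selectors, the change reduces to a new include path (cl/selectors/ becomes common/selectors/) and the removal of the cl namespace layer. A minimal sketch of an updated caller follows; the function BuildLstm and its body are illustrative only and not part of this patch.

// Illustrative caller (not in this patch); only the include path and namespace change.
#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
#include "tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.h"  // was .../cl/selectors/simple_selectors.h

void BuildLstm(const tflite::gpu::OperationDef& op_def,
               const tflite::gpu::GpuInfo& gpu_info) {
  // Previously spelled tflite::gpu::cl::SelectLSTM(...); the cl:: layer is gone.
  auto lstm = tflite::gpu::SelectLSTM(op_def, gpu_info);
  (void)lstm;  // the real callers place the operation into a GPUOperationsSubgraph
}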
namespace tflite { namespace gpu { -namespace cl { bool IsDepthwiseConvPlus1x1ConvSupported( const OperationDef& definition, @@ -42,7 +41,6 @@ GPUOperation CreateDepthwiseConvPlus1x1Conv( const DepthwiseConvolution2DAttributes& dw_attr, const Convolution2DAttributes& conv_attr); -} // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/tasks/special/fc_fc_add.cc b/tensorflow/lite/delegates/gpu/common/tasks/special/fc_fc_add.cc index 7a10f49edfd..a632dff0d29 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/special/fc_fc_add.cc +++ b/tensorflow/lite/delegates/gpu/common/tasks/special/fc_fc_add.cc @@ -27,7 +27,6 @@ limitations under the License. namespace tflite { namespace gpu { -namespace cl { namespace { bool UseBufferForWeights(const GpuInfo& gpu_info) { return gpu_info.IsAdreno() || gpu_info.IsAMD() || gpu_info.IsMali(); @@ -195,6 +194,5 @@ FCFCAdd CreateFCFCAdd(const GpuInfo& gpu_info, const OperationDef& definition, return result; } -} // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/tasks/special/fc_fc_add.h b/tensorflow/lite/delegates/gpu/common/tasks/special/fc_fc_add.h index c447c802c96..5c14729e74c 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/special/fc_fc_add.h +++ b/tensorflow/lite/delegates/gpu/common/tasks/special/fc_fc_add.h @@ -35,7 +35,6 @@ limitations under the License. namespace tflite { namespace gpu { -namespace cl { template void RearrangeFCWeightsToIOO4I4(const tflite::gpu::Tensor& weights, @@ -176,7 +175,6 @@ FCFCAdd CreateFCFCAdd(const GpuInfo& gpu_info, const OperationDef& definition, const FullyConnectedAttributes& attr0, const FullyConnectedAttributes& attr1); -} // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/BUILD b/tensorflow/lite/delegates/gpu/metal/BUILD index 6dcde34a62f..609bf4f63c2 100644 --- a/tensorflow/lite/delegates/gpu/metal/BUILD +++ b/tensorflow/lite/delegates/gpu/metal/BUILD @@ -26,10 +26,10 @@ cc_library( deps = [ ":compiled_model", ":compute_task_descriptor", - ":runtime_options", "//tensorflow/lite/delegates/gpu/common:gpu_info", "//tensorflow/lite/delegates/gpu/common:model", "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:precision", "//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:util", @@ -163,8 +163,8 @@ objc_library( ":common", ":compute_task_descriptor", ":metal_arguments", - ":runtime_options", "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:precision", "//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:types", @@ -211,9 +211,9 @@ objc_library( ":compiled_model", ":compute_task", ":compute_task_descriptor", - ":runtime_options", "//tensorflow/lite/delegates/gpu/common:memory_management", "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:precision", "//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:util", @@ -292,11 +292,6 @@ objc_library( ], ) -cc_library( - name = "runtime_options", - hdrs = ["runtime_options.h"], -) - objc_library( name = "TestBinary", testonly = 1, @@ -342,7 +337,6 @@ objc_library( "//tensorflow/lite/delegates/gpu/metal:common", 
"//tensorflow/lite/delegates/gpu/metal:inference_context", "//tensorflow/lite/delegates/gpu/metal:metal_spatial_tensor", - "//tensorflow/lite/delegates/gpu/metal:runtime_options", "//tensorflow/lite/delegates/gpu/metal/kernels:test_util", "@com_google_absl//absl/memory", ], diff --git a/tensorflow/lite/delegates/gpu/metal/api.cc b/tensorflow/lite/delegates/gpu/metal/api.cc index b04632b9c91..ff8bedf6f89 100644 --- a/tensorflow/lite/delegates/gpu/metal/api.cc +++ b/tensorflow/lite/delegates/gpu/metal/api.cc @@ -48,7 +48,6 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/metal/kernels/space_to_depth.h" #include "tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.h" #include "tensorflow/lite/delegates/gpu/metal/kernels/winograd.h" -#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" namespace tflite { namespace gpu { @@ -183,7 +182,7 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node, const std::vector& inputs, const std::vector& outputs, const GpuInfo& gpu_info, - const RuntimeOptions& options, + CalculationsPrecision precision, int* last_value_id, std::map* tensor_shapes, std::vector* nodes) { @@ -199,15 +198,7 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node, node_desc.src_tensors_ids = inputs; node_desc.dst_tensors_ids = outputs; OperationDef op_def; - if (options.storage_precision == RuntimeOptions::Precision::FP32) { - op_def.precision = CalculationsPrecision::F32; - } else { - if (options.accumulator_precision == RuntimeOptions::Precision::FP32) { - op_def.precision = CalculationsPrecision::F32_F16; - } else { - op_def.precision = CalculationsPrecision::F16; - } - } + op_def.precision = precision; DataType data_type = DeduceDataTypeFromPrecision(op_def.precision); TensorDescriptor tensor_descriptor = TensorDescriptor{data_type, TensorStorageType::BUFFER, Layout::HWC}; @@ -536,7 +527,7 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node, } // namespace absl::Status Compile(const GraphFloat32& graph, const GpuInfo& gpu_info, - const RuntimeOptions& options, + CalculationsPrecision precision, CompiledModel* compiled_model) { int last_value_id = 0; for (const auto& value : graph.values()) { @@ -555,11 +546,11 @@ absl::Status Compile(const GraphFloat32& graph, const GpuInfo& gpu_info, } std::vector node_descs; std::vector custom_tasks; - auto custom_status = - RegisterCustomOps(graph, node, inputs, outputs, options, &custom_tasks); + auto custom_status = RegisterCustomOps(graph, node, inputs, outputs, + precision, &custom_tasks); if (!custom_status.ok()) { auto primary_status = RegisterPrimaryOps( - graph, node, inputs, outputs, gpu_info, options, &last_value_id, + graph, node, inputs, outputs, gpu_info, precision, &last_value_id, &compiled_model->tensor_shapes, &node_descs); if (!primary_status.ok()) { return absl::UnimplementedError( diff --git a/tensorflow/lite/delegates/gpu/metal/api.h b/tensorflow/lite/delegates/gpu/metal/api.h index f7cdfa4245a..a2ef5c2f05f 100644 --- a/tensorflow/lite/delegates/gpu/metal/api.h +++ b/tensorflow/lite/delegates/gpu/metal/api.h @@ -18,9 +18,9 @@ limitations under the License. 
#include "tensorflow/lite/delegates/gpu/common/gpu_info.h" #include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/precision.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/metal/compiled_model.h" -#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" namespace tflite { namespace gpu { @@ -28,7 +28,7 @@ namespace metal { // Builds CompiledModel out of GraphFloat32 graph using provided RuntimeOptions. absl::Status Compile(const GraphFloat32& graph, const GpuInfo& gpu_info, - const RuntimeOptions& options, + CalculationsPrecision precision, CompiledModel* compiled_model); } // namespace metal diff --git a/tensorflow/lite/delegates/gpu/metal/compute_task.h b/tensorflow/lite/delegates/gpu/metal/compute_task.h index b3c32f49b31..73e9d81c76f 100644 --- a/tensorflow/lite/delegates/gpu/metal/compute_task.h +++ b/tensorflow/lite/delegates/gpu/metal/compute_task.h @@ -24,17 +24,17 @@ limitations under the License. #include #include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/precision.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" -#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" @interface TFLComputeTask : NSObject /// Returns empty string or error if shader can't be compiled. - (absl::Status)compileWithDevice:(id)device taskDescriptor:(const tflite::gpu::metal::NodeDescriptor&)desc - runtimeOptions:(const ::tflite::gpu::metal::RuntimeOptions&)options; + precision:(tflite::gpu::CalculationsPrecision)precision; /// Updates parameters for inputs/outputs/intermediate tensors - (absl::Status)updateParamsWithDevice:(id)device diff --git a/tensorflow/lite/delegates/gpu/metal/compute_task.mm b/tensorflow/lite/delegates/gpu/metal/compute_task.mm index 62a6a617a42..388ca951d4d 100644 --- a/tensorflow/lite/delegates/gpu/metal/compute_task.mm +++ b/tensorflow/lite/delegates/gpu/metal/compute_task.mm @@ -26,7 +26,6 @@ limitations under the License. 
#include "tensorflow/lite/delegates/gpu/common/types.h" #include "tensorflow/lite/delegates/gpu/common/util.h" #include "tensorflow/lite/delegates/gpu/metal/common.h" -#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" using ::tflite::gpu::AlignByN; using ::tflite::gpu::BHWC; @@ -34,7 +33,7 @@ using ::tflite::gpu::HalfBits; using ::tflite::gpu::metal::ComputeTaskDescriptorPtr; using ::tflite::gpu::metal::CreateComputeProgram; using ::tflite::gpu::metal::DispatchParamsFunction; -using ::tflite::gpu::metal::RuntimeOptions; +using ::tflite::gpu::CalculationsPrecision; using ::tflite::gpu::metal::UniformsFunction; using ::tflite::gpu::uint3; using ::tflite::gpu::ValueId; @@ -73,7 +72,7 @@ struct UniformBuffer { - (absl::Status)compileWithDevice:(id)device taskDescriptor:(const tflite::gpu::metal::NodeDescriptor&)desc - runtimeOptions:(const RuntimeOptions&)options { + precision:(CalculationsPrecision)precision; { size_t offset = desc.task->src_tensors_names.size() + desc.task->uniform_buffers.size() + desc.task->immutable_buffers.size() + 1; RETURN_IF_ERROR(_metal_args.Init(device, offset, &desc.task->args, &desc.task->shader_source)); @@ -90,13 +89,13 @@ struct UniformBuffer { NSString* toAccumulatorType2 = @""; NSString* toAccumulatorType3 = @""; NSString* toAccumulatorType4 = @""; - if (options.storage_precision == RuntimeOptions::Precision::FP32) { + if (precision == CalculationsPrecision::F32) { storageType = @"float"; accumulatorType = @"float"; } else { // FP16 storageType = @"half"; - if (options.accumulator_precision == RuntimeOptions::Precision::FP32) { + if (precision == CalculationsPrecision::F32_F16) { accumulatorType = @"float"; toAccumulatorType = @"float"; toAccumulatorType2 = @"float2"; @@ -136,10 +135,9 @@ struct UniformBuffer { _uniformBuffers.emplace_back(UniformBuffer{{}, uniform.data_function}); } _outputBuffers.emplace_back(OutputBuffer{desc.dst_tensors_ids[0], nil}); + const bool f32_storage = precision == CalculationsPrecision::F32; for (auto& immutable : desc.task->immutable_buffers) { - int padding = - 4 * (options.storage_precision == RuntimeOptions::Precision::FP32 ? sizeof(float) - : sizeof(HalfBits)); + int padding = 4 * (f32_storage ? sizeof(float) : sizeof(HalfBits)); int paddedSize = AlignByN(immutable.data.size(), padding); immutable.data.resize(paddedSize); id metalBuffer = [device newBufferWithBytes:immutable.data.data() diff --git a/tensorflow/lite/delegates/gpu/metal/inference_context.h b/tensorflow/lite/delegates/gpu/metal/inference_context.h index c215a9123b0..54cee468728 100644 --- a/tensorflow/lite/delegates/gpu/metal/inference_context.h +++ b/tensorflow/lite/delegates/gpu/metal/inference_context.h @@ -23,11 +23,11 @@ limitations under the License. #include #include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/precision.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/metal/compiled_model.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" -#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" /// Stages of model preprocessing: /// 1. Operations' initialization. All operations are initialized and added into @@ -56,20 +56,34 @@ limitations under the License. 
model:(const tflite::gpu::metal::CompiledModel&)compiledModel inputBufferIDs:(const std::vector&)inputBufferIDs outputBufferIDs:(const std::vector&)outputBufferIDs - runtimeOptions:(const tflite::gpu::metal::RuntimeOptions&)options; + precision:(tflite::gpu::CalculationsPrecision)precision; /// Inserts all GPU compute tasks into the command encoder. /// @param inputOutputBuffers Must be created and passed into the method with pairs ID:buffer -/// @param encoderBlock User-defined block to take control over command encoder. Can be nil. -/// The block can be used, for example, for fine-grained benchmarking where end encoding -/// is performed and command buffer is committed with completion block. A new command -/// buffer must be created and new command encoder must be returned by the block. -/// The block is called after every dispatch encoding. /// @discussion No GPU synchronization functions are used inside. All GPU resources must be created /// with the same device which has been used in compileModelWithDevice() method. - (void)encodeWithEncoder:(id)commandEncoder - inputOutputBuffers:(const std::map<::tflite::gpu::ValueId, id>&)inputOutputBuffers - encoderBlock:(id (^)(bool isLast))encoderBlock; + inputOutputBuffers: + (const std::map<::tflite::gpu::ValueId, id>&)inputOutputBuffers; + +/// Inserts all GPU compute tasks into the command buffer. A separate encoder is used for +/// every task. +/// @param inputOutputBuffers Must be created and passed into the method with pairs ID:buffer +/// @discussion No GPU synchronization functions are used inside. All GPU resources must be created +/// with the same device which has been used in compileModelWithDevice() method. +- (void)encodeWithCommandBuffer:(id)commandBuffer + inputOutputBuffers: + (const std::map<::tflite::gpu::ValueId, id>&)inputOutputBuffers; + +/// Adds all GPU compute tasks to the command queue. A separate encoder is used for every +/// task; every flushPeriod encoders are batched into one command buffer and committed for execution. +/// @param inputOutputBuffers Must be created and passed into the method with pairs ID:buffer +/// @discussion No GPU synchronization functions are used inside. All GPU resources must be created +/// with the same device which has been used in compileModelWithDevice() method. +- (void)encodeWithCommandQueue:(id)commandQueue + inputOutputBuffers: + (const std::map<::tflite::gpu::ValueId, id>&)inputOutputBuffers + flushPeriodically:(int)flushPeriod; @end diff --git a/tensorflow/lite/delegates/gpu/metal/inference_context.mm b/tensorflow/lite/delegates/gpu/metal/inference_context.mm index 84322a4ee6f..1f8a2a9b109 100644 --- a/tensorflow/lite/delegates/gpu/metal/inference_context.mm +++ b/tensorflow/lite/delegates/gpu/metal/inference_context.mm @@ -22,16 +22,16 @@ limitations under the License.
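The encoderBlock callback on encodeWithEncoder is removed; batching is now expressed through the two new entry points declared above. A short usage sketch of the queue-based variant, assuming an already compiled TFLInferenceContext; the helper name RunBatchedInference, the flush period of 8, and the id<MTLBuffer> map value type are assumptions for illustration, not part of this patch.

#import <Metal/Metal.h>
#include <map>
#include "tensorflow/lite/delegates/gpu/metal/inference_context.h"

// Sketch: one compute encoder per task; a command buffer is committed after
// every 8 encoded tasks, plus a final commit for the remainder.
static void RunBatchedInference(
    TFLInferenceContext* context, id<MTLDevice> device,
    const std::map<tflite::gpu::ValueId, id<MTLBuffer>>& in_out_buffers) {
  id<MTLCommandQueue> queue = [device newCommandQueue];
  [context encodeWithCommandQueue:queue
               inputOutputBuffers:in_out_buffers
                flushPeriodically:8];
  // Single-command-buffer alternative:
  //   id<MTLCommandBuffer> cb = [queue commandBuffer];
  //   [context encodeWithCommandBuffer:cb inputOutputBuffers:in_out_buffers];
  //   [cb commit];
  //   [cb waitUntilCompleted];
}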
#include "tensorflow/lite/delegates/gpu/common/memory_management.h" #include "tensorflow/lite/delegates/gpu/common/memory_management/types.h" #include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/precision.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/util.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" -#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" using ::tflite::gpu::BHWC; using ::tflite::gpu::metal::ComputeTaskDescriptorPtr; -using ::tflite::gpu::metal::RuntimeOptions; +using ::tflite::gpu::CalculationsPrecision; using ::tflite::gpu::ValueId; using ::tflite::gpu::AlignByN; using ::tflite::gpu::HalfBits; @@ -45,7 +45,7 @@ using ::tflite::gpu::TensorUsageRecord; std::vector _inputIds; std::vector _outputIds; id _device; - RuntimeOptions _options; + CalculationsPrecision _precision; std::map _tensorShapes; } @@ -53,17 +53,17 @@ using ::tflite::gpu::TensorUsageRecord; model:(const tflite::gpu::metal::CompiledModel&) compiledModel inputBufferIDs:(const std::vector&)inputBufferIDs outputBufferIDs:(const std::vector&)outputBufferIDs - runtimeOptions:(const RuntimeOptions&)options { + precision:(tflite::gpu::CalculationsPrecision)precision { _device = device; _inputIds = inputBufferIDs; _outputIds = outputBufferIDs; - _options = options; + _precision = precision; // Metal resources are created here. for (const auto& node : compiledModel.nodes) { TFLComputeTask* task = [[TFLComputeTask alloc] init]; RETURN_IF_ERROR([task compileWithDevice:_device taskDescriptor:node - runtimeOptions:_options]); + precision:_precision]); [task setDescription:node.description]; _computeTasks.emplace_back(task); } @@ -119,9 +119,8 @@ using ::tflite::gpu::TensorUsageRecord; RETURN_IF_ERROR(AssignObjectsToTensors(usageRecords, MemoryStrategy::GREEDY_BEST, &assignment)); auto objectsCount = assignment.object_sizes.size(); std::vector> sharedBuffers(objectsCount); - size_t dataTypeSize = _options.storage_precision == RuntimeOptions::Precision::FP32 - ? sizeof(float) - : sizeof(HalfBits); + const bool f32_storage = _precision == CalculationsPrecision::F32; + size_t dataTypeSize = f32_storage ? 
sizeof(float) : sizeof(HalfBits); // allocate buffers for each shared object for (size_t i = 0; i < objectsCount; ++i) { @@ -165,8 +164,8 @@ using ::tflite::gpu::TensorUsageRecord; } - (void)encodeWithEncoder:(id)commandEncoder - inputOutputBuffers:(const std::map>&)inputOutputBuffers - encoderBlock:(id (^)(bool isLast))encoderBlock { + inputOutputBuffers: + (const std::map<::tflite::gpu::ValueId, id>&)inputOutputBuffers { for (auto& task_index : _taskIdsWithInOutBuffers) { auto& task = _computeTasks[task_index]; [task updateBuffers:inputOutputBuffers]; @@ -174,10 +173,44 @@ using ::tflite::gpu::TensorUsageRecord; for (int i = 0; i < _computeTasks.size(); ++i) { auto& task = _computeTasks[i]; [task encodeWithEncoder:commandEncoder]; - if (encoderBlock != nil) { - commandEncoder = encoderBlock(i == _computeTasks.size() - 1); - } } } +- (void)encodeWithCommandBuffer:(id)commandBuffer + inputOutputBuffers: + (const std::map<::tflite::gpu::ValueId, id>&)inputOutputBuffers { + for (auto& task_index : _taskIdsWithInOutBuffers) { + auto& task = _computeTasks[task_index]; + [task updateBuffers:inputOutputBuffers]; + } + for (int i = 0; i < _computeTasks.size(); ++i) { + id encoder = [commandBuffer computeCommandEncoder]; + auto& task = _computeTasks[i]; + [task encodeWithEncoder:encoder]; + [encoder endEncoding]; + } +} + +- (void)encodeWithCommandQueue:(id)commandQueue + inputOutputBuffers: + (const std::map<::tflite::gpu::ValueId, id>&)inputOutputBuffers + flushPeriodically:(int)flushPeriod { + for (auto& task_index : _taskIdsWithInOutBuffers) { + auto& task = _computeTasks[task_index]; + [task updateBuffers:inputOutputBuffers]; + } + id commandBuffer = [commandQueue commandBuffer]; + for (int i = 0; i < _computeTasks.size(); ++i) { + id encoder = [commandBuffer computeCommandEncoder]; + auto& task = _computeTasks[i]; + [task encodeWithEncoder:encoder]; + [encoder endEncoding]; + if (i % flushPeriod == (flushPeriod - 1)) { + [commandBuffer commit]; + commandBuffer = [commandQueue commandBuffer]; + } + } + [commandBuffer commit]; +} + @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD index 4dd0ed166ab..20c59e2fbf9 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD +++ b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD @@ -174,9 +174,9 @@ cc_library( hdrs = ["custom_registry.h"], deps = [ "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:precision", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor", - "//tensorflow/lite/delegates/gpu/metal:runtime_options", ], ) @@ -814,6 +814,7 @@ objc_library( "//tensorflow/lite/delegates/gpu/common:gpu_info", "//tensorflow/lite/delegates/gpu/common:model", "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:precision", "//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:tensor", @@ -823,7 +824,6 @@ objc_library( "//tensorflow/lite/delegates/gpu/metal:common", "//tensorflow/lite/delegates/gpu/metal:compiled_model", "//tensorflow/lite/delegates/gpu/metal:inference_context", - "//tensorflow/lite/delegates/gpu/metal:runtime_options", "@FP16", "@com_google_absl//absl/memory", ], @@ -897,12 +897,12 @@ objc_library( deps = [ ":test_util", "//tensorflow/lite/delegates/gpu/common:gpu_info", + "//tensorflow/lite/delegates/gpu/common:precision", 
"//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common:types", "//tensorflow/lite/delegates/gpu/common:util", "//tensorflow/lite/delegates/gpu/metal:common", "//tensorflow/lite/delegates/gpu/metal:inference_context", - "//tensorflow/lite/delegates/gpu/metal:runtime_options", ], ) diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/add_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/add_test.mm index 22a798c59cc..3facbc4a7bf 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/add_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/add_test.mm @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/util.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h" -#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" using ::tflite::gpu::ElementwiseAttributes; using ::tflite::gpu::BHWC; diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/concat_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/concat_test.mm index 195a2986628..6ac084cf8bb 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/concat_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/concat_test.mm @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/util.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h" -#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" using ::tflite::gpu::Axis; using ::tflite::gpu::BHWC; diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/conv_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/conv_test.mm index 71ea6f9ede8..6775fd37135 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/conv_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/conv_test.mm @@ -28,7 +28,6 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/util.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h" -#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" using ::tflite::gpu::Axis; using ::tflite::gpu::BHWC; @@ -286,9 +285,6 @@ using ::tflite::gpu::metal::SingleOpModel; } id device = MTLCreateSystemDefaultDevice(); - tflite::gpu::metal::RuntimeOptions options; - options.storage_precision = tflite::gpu::metal::RuntimeOptions::Precision::FP32; - options.accumulator_precision = tflite::gpu::metal::RuntimeOptions::Precision::FP32; std::map inputs_v0; inputs_v0[0] = src_tensor; diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.cc b/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.cc index 620a4581c52..fa97160b3a2 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.cc @@ -18,9 +18,9 @@ limitations under the License. 
#include #include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/precision.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" -#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" namespace tflite { namespace gpu { @@ -29,7 +29,7 @@ namespace metal { absl::Status RegisterCustomOps(const GraphFloat32& graph, const Node* node, const std::vector& inputs, const std::vector& outputs, - const RuntimeOptions& options, + CalculationsPrecision precision, std::vector* tasks) { return absl::UnimplementedError("Unsupported op: " + node->operation.type); } diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.h b/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.h index eee1632a644..2f08b74051c 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.h +++ b/tensorflow/lite/delegates/gpu/metal/kernels/custom_registry.h @@ -19,9 +19,9 @@ limitations under the License. #include #include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/precision.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" -#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" namespace tflite { namespace gpu { @@ -31,7 +31,7 @@ namespace metal { absl::Status RegisterCustomOps(const GraphFloat32& graph, const Node* node, const std::vector& inputs, const std::vector& outputs, - const RuntimeOptions& options, + CalculationsPrecision precision, std::vector* tasks); } // namespace metal diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv_test.mm index dcf550f7868..817a3713ceb 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv_test.mm @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/util.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h" -#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" using ::tflite::gpu::Axis; using ::tflite::gpu::BHWC; diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/elementwise_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/elementwise_test.mm index 867ed596ed8..5826e2b6443 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/elementwise_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/elementwise_test.mm @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/util.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h" -#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" using ::tflite::gpu::DataType; using ::tflite::gpu::HWC; diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected_test.mm index e57f9aa84e2..b6e4cb9d961 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected_test.mm @@ -27,7 +27,6 @@ limitations under the License. 
#include "tensorflow/lite/delegates/gpu/common/util.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h" -#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" using ::tflite::gpu::BHWC; using ::tflite::gpu::DataType; diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling_test.mm index cf4aacf724f..5ee3603f9da 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling_test.mm @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/util.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h" -#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" using ::tflite::gpu::BHWC; using ::tflite::gpu::DataType; diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/mean_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/mean_test.mm index 67325c1adb7..e4fa301101c 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/mean_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/mean_test.mm @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/util.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h" -#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" using ::tflite::gpu::Axis; using ::tflite::gpu::BHWC; diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/padding_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/padding_test.mm index 9c55cfc45b0..e8c0ef60707 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/padding_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/padding_test.mm @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/util.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h" -#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" using ::tflite::gpu::BHWC; using ::tflite::gpu::DataType; diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/pooling_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/pooling_test.mm index d2d95b30af2..a28dd642124 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/pooling_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/pooling_test.mm @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/util.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h" -#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" using ::tflite::gpu::BHWC; using ::tflite::gpu::DataType; diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/prelu_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/prelu_test.mm index 1df08be61db..3a01ca29b31 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/prelu_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/prelu_test.mm @@ -27,7 +27,6 @@ limitations under the License. 
#include "tensorflow/lite/delegates/gpu/common/util.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h" -#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" using ::tflite::gpu::BHWC; using ::tflite::gpu::DataType; diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/quantize_and_dequantize_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/quantize_and_dequantize_test.mm index 7a16f1d25b8..7eb71bf8e3b 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/quantize_and_dequantize_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/quantize_and_dequantize_test.mm @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/util.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h" -#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" using ::tflite::NudgeQuantizationRange; diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/relu_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/relu_test.mm index 52de77e0ee4..d685a8c211c 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/relu_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/relu_test.mm @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/util.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h" -#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" using ::tflite::gpu::BHWC; using ::tflite::gpu::DataType; diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/reshape_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/reshape_test.mm index 684e83b2db1..9a64ef52110 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/reshape_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/reshape_test.mm @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/util.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h" -#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" using ::tflite::gpu::BHWC; using ::tflite::gpu::DataType; diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/resize_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/resize_test.mm index 082f2c80125..f087777fa51 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/resize_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/resize_test.mm @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/util.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h" -#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" using ::tflite::gpu::BHWC; using ::tflite::gpu::DataType; diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/slice_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/slice_test.mm index e0c29561f9b..25b45d4d921 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/slice_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/slice_test.mm @@ -27,7 +27,6 @@ limitations under the License. 
#include "tensorflow/lite/delegates/gpu/common/util.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h" -#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" using ::tflite::gpu::BHWC; using ::tflite::gpu::DataType; diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/softmax_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/softmax_test.mm index 9196e9fe094..c5b2fd00212 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/softmax_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/softmax_test.mm @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/util.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h" -#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" using ::tflite::gpu::Axis; using ::tflite::gpu::BHWC; diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/space_to_depth_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/space_to_depth_test.mm index 17e398817b2..b7c474e5128 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/space_to_depth_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/space_to_depth_test.mm @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/tensor.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h" -#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" using ::tflite::gpu::BHWC; using ::tflite::gpu::DataType; diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/test_util.h b/tensorflow/lite/delegates/gpu/metal/kernels/test_util.h index 14b64d37d26..bf8cbc3020a 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/test_util.h +++ b/tensorflow/lite/delegates/gpu/metal/kernels/test_util.h @@ -26,7 +26,6 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/metal/compiled_model.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" #include "tensorflow/lite/delegates/gpu/metal/inference_context.h" -#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/test_util.mm b/tensorflow/lite/delegates/gpu/metal/kernels/test_util.mm index 912910ce28f..1533b74f45b 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/test_util.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/test_util.mm @@ -33,7 +33,7 @@ limitations under the License. 
#include "tensorflow/lite/delegates/gpu/metal/compiled_model.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" #include "tensorflow/lite/delegates/gpu/metal/inference_context.h" -#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" +#include "tensorflow/lite/delegates/gpu/common/precision.h" #include "tensorflow/lite/delegates/gpu/common/gpu_info.h" namespace tflite { @@ -84,11 +84,9 @@ absl::Status SingleOpModel::Invoke() { std::string device_name = std::string([[device name] UTF8String]); GpuInfo gpu_info; GetGpuInfoFromDeviceDescription(device_name, GpuApi::kMetal, &gpu_info); - RuntimeOptions options; - options.storage_precision = RuntimeOptions::Precision::FP32; - options.accumulator_precision = RuntimeOptions::Precision::FP32; + CalculationsPrecision precision = CalculationsPrecision::F32; CompiledModel compiled_model; - RETURN_IF_ERROR(Compile(graph_, gpu_info, options, &compiled_model)); + RETURN_IF_ERROR(Compile(graph_, gpu_info, precision, &compiled_model)); CompiledModel optimized_model; RETURN_IF_ERROR(ValidateOptimizeModel(input_ids, output_ids, compiled_model, &optimized_model)); @@ -97,7 +95,7 @@ absl::Status SingleOpModel::Invoke() { model:optimized_model inputBufferIDs:input_ids outputBufferIDs:output_ids - runtimeOptions:options]); + precision:precision]); std::map input_dimensions; std::map> input_buffers; for (auto& input : inputs_) { @@ -129,7 +127,7 @@ absl::Status SingleOpModel::Invoke() { id command_queue = [device newCommandQueue]; id command_buffer = [command_queue commandBuffer]; id command_encoder = [command_buffer computeCommandEncoder]; - [graph encodeWithEncoder:command_encoder inputOutputBuffers:inout_buffers encoderBlock:nil]; + [graph encodeWithEncoder:command_encoder inputOutputBuffers:inout_buffers]; [command_encoder endEncoding]; [command_buffer commit]; [command_buffer waitUntilCompleted]; @@ -193,16 +191,14 @@ absl::Status RunGraph(const std::vector& nodes, id de RETURN_IF_ERROR( ValidateOptimizeModel(inputBufferIDs, outputBufferIDs, raw_model, &optimized_model)); - RuntimeOptions options; - options.storage_precision = RuntimeOptions::Precision::FP32; - options.accumulator_precision = RuntimeOptions::Precision::FP32; + CalculationsPrecision precision = CalculationsPrecision::F32; TFLInferenceContext* graph = [[TFLInferenceContext alloc] init]; RETURN_IF_ERROR([graph compileModelWithDevice:device model:optimized_model inputBufferIDs:inputBufferIDs outputBufferIDs:outputBufferIDs - runtimeOptions:options]); + precision:precision]); std::map inputDimensions; std::map> inputBuffersCPU; std::map> inputBuffersGPU; @@ -239,7 +235,7 @@ absl::Status RunGraph(const std::vector& nodes, id de id commandQueue = [device newCommandQueue]; id commandBuffer = [commandQueue commandBuffer]; id commandEncoder = [commandBuffer computeCommandEncoder]; - [graph encodeWithEncoder:commandEncoder inputOutputBuffers:inputOutputBuffers encoderBlock:nil]; + [graph encodeWithEncoder:commandEncoder inputOutputBuffers:inputOutputBuffers]; [commandEncoder endEncoding]; [commandBuffer commit]; [commandBuffer waitUntilCompleted]; diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv_test.mm index 3d716ec217f..dd5f412d2fe 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv_test.mm @@ -27,7 +27,6 @@ limitations under the License. 
#include "tensorflow/lite/delegates/gpu/common/util.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h" -#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" using ::tflite::gpu::ConvolutionTransposedAttributes; using ::tflite::gpu::BHWC; diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/winograd_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/winograd_test.mm index 90d6c2e6d46..7f138a30124 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/winograd_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/winograd_test.mm @@ -26,7 +26,6 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/util.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h" -#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" #include "tensorflow/lite/delegates/gpu/common/winograd_util.h" using ::tflite::gpu::BHWC; @@ -151,10 +150,6 @@ using ::tflite::gpu::metal::CompareVectors; } } - tflite::gpu::metal::RuntimeOptions options; - options.storage_precision = tflite::gpu::metal::RuntimeOptions::Precision::FP32; - options.accumulator_precision = tflite::gpu::metal::RuntimeOptions::Precision::FP32; - tflite::gpu::metal::Winograd4x4To36Attributes attr; attr.padding.prepended = tflite::gpu::HW(1, 1); attr.padding.appended = tflite::gpu::HW(1, 1); @@ -229,10 +224,6 @@ using ::tflite::gpu::metal::CompareVectors; attr.biases.shape = tflite::gpu::Linear(1); attr.biases.data.resize(1, 0.0f); - tflite::gpu::metal::RuntimeOptions options; - options.storage_precision = tflite::gpu::metal::RuntimeOptions::Precision::FP32; - options.accumulator_precision = tflite::gpu::metal::RuntimeOptions::Precision::FP32; - tflite::gpu::OperationDef op_def; op_def.precision = tflite::gpu::CalculationsPrecision::F32; tflite::gpu::TensorDescriptor tensor_descriptor = tflite::gpu::TensorDescriptor{ @@ -304,10 +295,6 @@ using ::tflite::gpu::metal::CompareVectors; attr.biases.shape = tflite::gpu::Linear(1); attr.biases.data.resize(1, 0.0f); - tflite::gpu::metal::RuntimeOptions options; - options.storage_precision = tflite::gpu::metal::RuntimeOptions::Precision::FP32; - options.accumulator_precision = tflite::gpu::metal::RuntimeOptions::Precision::FP32; - tflite::gpu::OperationDef op_def; op_def.precision = tflite::gpu::CalculationsPrecision::F32; tflite::gpu::TensorDescriptor tensor_descriptor = tflite::gpu::TensorDescriptor{ diff --git a/tensorflow/lite/delegates/gpu/metal/runtime_options.h b/tensorflow/lite/delegates/gpu/metal/runtime_options.h deleted file mode 100644 index d8e8fe3dd92..00000000000 --- a/tensorflow/lite/delegates/gpu/metal/runtime_options.h +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/
-
-#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_RUNTIME_OPTIONS_H_
-#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_RUNTIME_OPTIONS_H_
-
-namespace tflite {
-namespace gpu {
-namespace metal {
-
-struct RuntimeOptions {
-  enum class Precision {
-    FP16,
-    FP32,
-  };
-  // Buffer storage format. If FP32 then accumulator must be FP32.
-  Precision storage_precision = Precision::FP32;
-  // Accumulator precision. Defines the precision for convolutions.
-  Precision accumulator_precision = Precision::FP32;
-};
-
-}  // namespace metal
-}  // namespace gpu
-}  // namespace tflite
-
-#endif  // TENSORFLOW_LITE_DELEGATES_GPU_METAL_RUNTIME_OPTIONS_H_
diff --git a/tensorflow/lite/delegates/gpu/metal_delegate.mm b/tensorflow/lite/delegates/gpu/metal_delegate.mm
index 229f898a0d9..60f8f361b35 100644
--- a/tensorflow/lite/delegates/gpu/metal_delegate.mm
+++ b/tensorflow/lite/delegates/gpu/metal_delegate.mm
@@ -45,10 +45,11 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/metal/compiled_model.h"
 #include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/metal/inference_context.h"
-#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
+#include "tensorflow/lite/delegates/gpu/common/precision.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/minimal_logging.h"
 
+
 namespace tflite {
 namespace gpu {
 namespace metal {
@@ -338,19 +339,17 @@ class Delegate {
   GpuInfo gpu_info;
   GetGpuInfoFromDeviceDescription(device_name, GpuApi::kMetal, &gpu_info);
   size_t storage_type_size;
-  RuntimeOptions runtime_options;
+  CalculationsPrecision precision;
   if (options_.allow_precision_loss) {
     storage_type_size = sizeof(HalfBits);
-    runtime_options.storage_precision = RuntimeOptions::Precision::FP16;
     if (gpu_info.IsRoundToNearestSupported()) {
-      runtime_options.accumulator_precision = RuntimeOptions::Precision::FP16;
+      precision = CalculationsPrecision::F16;
    } else {
-      runtime_options.accumulator_precision = RuntimeOptions::Precision::FP32;
+      precision = CalculationsPrecision::F32_F16;
    }
  } else {
     storage_type_size = sizeof(float);
-    runtime_options.storage_precision = RuntimeOptions::Precision::FP32;
-    runtime_options.accumulator_precision = RuntimeOptions::Precision::FP32;
+    precision = CalculationsPrecision::F32;
  }
 
   // TODO(impjdi): Merge logic with above.
@@ -435,7 +434,7 @@ class Delegate {
 
   // TODO(impjdi): Merge these.
   CompiledModel compiled_model;
-  RETURN_IF_ERROR(Compile(graph, gpu_info, runtime_options, &compiled_model));
+  RETURN_IF_ERROR(Compile(graph, gpu_info, precision, &compiled_model));
   CompiledModel optimized_model;
   RETURN_IF_ERROR(ValidateOptimizeModel(input_ids, output_ids, compiled_model, &optimized_model));
@@ -444,7 +443,7 @@ class Delegate {
                                        model:optimized_model
                               inputBufferIDs:input_ids
                              outputBufferIDs:output_ids
-                              runtimeOptions:runtime_options]);
+                                   precision:precision]);
   return absl::OkStatus();
 }
@@ -454,12 +453,16 @@ class Delegate {
   // We need only synchronization so volatile works better than atomic which reads from global
   // memory each time.
   __block volatile bool buffer_completed = false;
-  __block id command_buffer;
-  __block id encoder = external_command_encoder_;
+  id command_buffer;
+  id encoder = external_command_encoder_;
   if (external_command_encoder_ == nil) {
     command_buffer = [command_queue_ commandBuffer];
     encoder = [command_buffer computeCommandEncoder];
   }
+  const bool flush = external_command_encoder_ == nil &&
+      (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeActive ||
+       options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeAggressive);
+  const int flush_period = 8;
 
   const bool is_quantized_model = !quant_conversion_map_.empty();
   if (is_quantized_model) {
@@ -479,36 +482,32 @@ class Delegate {
                            shape:input.shape
                     sourceBuffer:input_output_buffers_[input.id]
                  convertedBuffer:bphwc4_buffers_[input.id]];
-      if (external_command_encoder_ == nil) {
-        [encoder endEncoding];
-        [command_buffer commit];
+    }
+    if (flush) {
+      [encoder endEncoding];
+      [command_buffer commit];
+    }
+
+    if (external_command_encoder_ != nil ||
+        options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypePassive) {
+      // encoder == external_command_encoder_
+      [inference_context_ encodeWithEncoder:encoder
+                         inputOutputBuffers:bphwc4_buffers_];
+    } else {
+      if (flush) {
+        [inference_context_ encodeWithCommandQueue:command_queue_
+                                inputOutputBuffers:bphwc4_buffers_
+                                 flushPeriodically:flush_period];
         command_buffer = [command_queue_ commandBuffer];
        encoder = [command_buffer computeCommandEncoder];
+      } else {
+        [encoder endEncoding];
+        [inference_context_ encodeWithCommandBuffer:command_buffer
+                                 inputOutputBuffers:bphwc4_buffers_];
+        encoder = [command_buffer computeCommandEncoder];
      }
    }
-    [inference_context_
-        encodeWithEncoder:encoder
-       inputOutputBuffers:bphwc4_buffers_
-             encoderBlock:^(bool isLast) {
-               if (external_command_encoder_ != nil ||
-                   options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypePassive) {
-                 return encoder;
-               }
-               if (isLast) {
-                 if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeActive) {
-                   [command_buffer addCompletedHandler:^(id) {
-                     buffer_completed = true;
-                   }];
-                 }
-               } else {
-                 [encoder endEncoding];
-                 [command_buffer commit];
-                 command_buffer = [command_queue_ commandBuffer];
-                 encoder = [command_buffer computeCommandEncoder];
-               }
-               return encoder;
-             }];
   for (const auto& output : graph_outputs_) {
     if (output.set_externally) continue;
     if (bphwc4_buffers_[output.id] == input_output_buffers_[output.id]) continue;
@@ -520,6 +519,11 @@ class Delegate {
 
   if (external_command_encoder_ == nil) {
     [encoder endEncoding];
+    if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeActive) {
+      [command_buffer addCompletedHandler:^(id) {
+        buffer_completed = true;
+      }];
+    }
     [command_buffer commit];
     if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeActive) {
       while (!buffer_completed) {
@@ -531,16 +535,16 @@ class Delegate {
       // passive wait: this thread sleeps until GPU finishes.
       [command_buffer waitUntilCompleted];
     } else if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeAggressive) {
-      command_buffer = [command_queue_ commandBuffer];
-      encoder = [command_buffer computeCommandEncoder];
-      [encoder setComputePipelineState:signal_program_];
-      [encoder setBuffer:signal_buffer_ offset:0 atIndex:0];
+      id signal_cb = [command_queue_ commandBuffer];
+      id signal_encoder = [signal_cb computeCommandEncoder];
+      [signal_encoder setComputePipelineState:signal_program_];
+      [signal_encoder setBuffer:signal_buffer_ offset:0 atIndex:0];
       signal_value_++;
-      [encoder setBytes:&signal_value_ length:sizeof(int) atIndex:1];
-      [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1)
+      [signal_encoder setBytes:&signal_value_ length:sizeof(int) atIndex:1];
+      [signal_encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1)
               threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-      [encoder endEncoding];
-      [command_buffer commit];
+      [signal_encoder endEncoding];
+      [signal_cb commit];
       gpu_alarm_clock_->Start();
       const int* signal_ptr = reinterpret_cast([signal_buffer_ contents]);
       while (signal_ptr[0] != signal_value_) {
@@ -712,14 +716,6 @@ bool TFLGpuDelegateBindMetalBufferToTensor(TfLiteDelegate* delegate, int tensor_
 
 // Note: This function is not exposed in `metal_delegate.h`, but it's exposed in
 // `metal_delegate_internal.h`.
-bool TFLGpuDelegateSetCommandEncoder(
-    TfLiteDelegate* delegate, id encoder) {
-  auto* metal_delegate = ::tflite::gpu::metal::GetMetalDelegate(delegate);
-  if (!metal_delegate) return false;
-  metal_delegate->SetCommandEncoder(encoder);
-  return true;
-}
-
 bool TFLGpuDelegateSetCommandBuffer(TfLiteDelegate* delegate,
                                     id command_buffer) {
   auto* metal_delegate = ::tflite::gpu::metal::GetMetalDelegate(delegate);
diff --git a/tensorflow/lite/delegates/gpu/metal_delegate_internal.h b/tensorflow/lite/delegates/gpu/metal_delegate_internal.h
index 1a33d046103..121caef450d 100644
--- a/tensorflow/lite/delegates/gpu/metal_delegate_internal.h
+++ b/tensorflow/lite/delegates/gpu/metal_delegate_internal.h
@@ -33,11 +33,6 @@ bool TFLGpuDelegateBindMetalBufferToTensor(TfLiteDelegate* delegate,
                                            int tensor_index,
                                            id metal_buffer);
 
-// Binds user-defined MTLComputeCommandEncoder. The delegate puts all GPU tasks
-// into this encoder instead of the internal encoder.
-bool TFLGpuDelegateSetCommandEncoder(TfLiteDelegate* delegate,
-                                     id encoder);
-
 // Binds user-defined MTLCommandBuffer. The delegate puts all GPU tasks
 // into this buffer instead of the internal command buffer.
 bool TFLGpuDelegateSetCommandBuffer(TfLiteDelegate* delegate,
diff --git a/tensorflow/lite/g3doc/performance/post_training_quantization.md b/tensorflow/lite/g3doc/performance/post_training_quantization.md
index 2fd4f078c4c..da48e8ff18e 100644
--- a/tensorflow/lite/g3doc/performance/post_training_quantization.md
+++ b/tensorflow/lite/g3doc/performance/post_training_quantization.md
@@ -56,9 +56,29 @@ You can get further latency improvements, reductions in peak memory usage, and
 compatibility with integer only hardware devices or accelerators by making sure
 all model math is integer quantized.
 
-For full integer quantization, you need to measure the dynamic range of
-activations and inputs by supplying sample input data to the converter. Refer to
-the `representative_dataset_gen()` function used in the following code.
+For full integer quantization, you need to calibrate or estimate the range, i.e.,
+(min, max) of all floating-point tensors in the model. Unlike constant tensors
+such as weights and biases, variable tensors such as the model input, activations
+(outputs of intermediate layers) and the model output cannot be calibrated unless
+we run a few inference cycles. As a result, the converter requires a
+representative dataset to calibrate them. This dataset can be a small subset
+(around 100-500 samples) of the training or validation data. Refer to the
+`representative_dataset()` function below.
+
+<pre>
+def representative_dataset():
+  for data in tf.data.Dataset.from_tensor_slices((images)).batch(1).take(100):
+    yield [tf.dtypes.cast(data, tf.float32)]
+</pre>
+
+For testing purposes, you can use a dummy dataset as follows:
+
+<pre>
+def representative_dataset():
+    for _ in range(100):
+      data = np.random.rand(1, 244, 244, 3)
+      yield [data.astype(np.float32)]
+ </pre>
 
 #### Integer with float fallback (using default float input/output)
 
@@ -70,11 +90,7 @@ the following steps:
 import tensorflow as tf
 converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
 converter.optimizations = [tf.lite.Optimize.DEFAULT]
-def representative_dataset_gen():
-  for _ in range(num_calibration_steps):
-    # Get sample input data as a numpy array in a method of your choosing.
-    yield [input]
-converter.representative_dataset = representative_dataset_gen
+converter.representative_dataset = representative_dataset
 tflite_quant_model = converter.convert()
@@ -101,11 +117,7 @@ the following steps:
 import tensorflow as tf
 converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
 converter.optimizations = [tf.lite.Optimize.DEFAULT]
-def representative_dataset_gen():
-  for _ in range(num_calibration_steps):
-    # Get sample input data as a numpy array in a method of your choosing.
-    yield [input]
-converter.representative_dataset = representative_dataset_gen
+converter.representative_dataset = representative_dataset
 converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
 converter.inference_input_type = tf.int8  # or tf.uint8
 converter.inference_output_type = tf.int8  # or tf.uint8
@@ -158,11 +170,7 @@ significantly, but only slightly increase model size.
 import tensorflow as tf
 converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
-def representative_dataset_gen():
-  for _ in range(num_calibration_steps):
-    # Get sample input data as a numpy array in a method of your choosing.
-    yield [input]
-converter.representative_dataset = representative_dataset_gen
+converter.representative_dataset = representative_dataset
 converter.optimizations = [tf.lite.Optimize.DEFAULT]
 converter.target_spec.supported_ops = [tf.lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8]
 tflite_quant_model = converter.convert()
@@ -174,11 +182,7 @@ The following option should be added to the target_spec to allow this.
 <pre>
 import tensorflow as tf
 converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
-def representative_dataset_gen():
-  for _ in range(num_calibration_steps):
-    # Get sample input data as a numpy array in a method of your choosing.
-    yield [input]
-converter.representative_dataset = representative_dataset_gen
+converter.representative_dataset = representative_dataset
 converter.optimizations = [tf.lite.Optimize.DEFAULT]
 converter.target_spec.supported_ops = [tf.lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
 tf.lite.OpsSet.TFLITE_BUILTINS]
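
The documentation hunks above replace the per-snippet `representative_dataset_gen()` definitions with the shared `representative_dataset()` shown earlier. As a rough end-to-end sketch assembled from those snippets (the saved-model path is a placeholder and the input shape is the dummy one from the doc):

```python
import numpy as np
import tensorflow as tf

def representative_dataset():
  # Dummy calibration data; in practice use roughly 100-500 real samples.
  for _ in range(100):
    yield [np.random.rand(1, 244, 244, 3).astype(np.float32)]

converter = tf.lite.TFLiteConverter.from_saved_model("/path/to/saved_model")  # placeholder path
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
# Enforce integer-only ops and integer input/output, as in the snippets above.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_quant_model = converter.convert()
```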
diff --git a/tensorflow/lite/interpreter_builder.cc b/tensorflow/lite/interpreter_builder.cc
index 4249c85238e..dcf6705a892 100644
--- a/tensorflow/lite/interpreter_builder.cc
+++ b/tensorflow/lite/interpreter_builder.cc
@@ -163,6 +163,12 @@ TFLITE_ATTRIBUTE_WEAK Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() {
 #endif
   void* lib_tf_internal =
       SharedLibrary::LoadLibrary(filename_pywrap_tensorflow_internal);
+#if defined(_WIN32)
+  if (lib_tf_internal == nullptr) {
+    lib_tf_internal = SharedLibrary::LoadLibrary(
+        "_pywrap_tensorflow_interpreter_wrapper.pyd");
+  }
+#endif
   if (lib_tf_internal) {
     acquire_flex_delegate_func =
         reinterpret_cast(
diff --git a/tensorflow/lite/java/src/main/native/BUILD b/tensorflow/lite/java/src/main/native/BUILD
index 8b6bad53206..fa3195062e6 100644
--- a/tensorflow/lite/java/src/main/native/BUILD
+++ b/tensorflow/lite/java/src/main/native/BUILD
@@ -37,40 +37,22 @@ cc_library(
     alwayslink = 1,
 )
 
-# This includes all ops. If you want a smaller binary, you should copy and
-# modify builtin_ops_jni.cc.  You should then link your binary against both
-# ":native_framework_only" and your own version of ":native_builtin_ops".
+# This includes all ops. If you want a smaller binary, you should use
+# tflite_custom_cc_library or tflite_custom_android_library rules.
 cc_library(
     name = "native",
-    srcs = [
-        "builtin_ops_jni.cc",
-    ],
-    hdrs = ["op_resolver.h"],
     copts = tflite_copts(),
     deps = [
         ":native_framework_only",
+        "//tensorflow/lite:create_op_resolver_with_builtin_ops",
         "//tensorflow/lite:framework",
-        "//tensorflow/lite/core/api",
-        "//tensorflow/lite/kernels:builtin_ops",
     ],
     alwayslink = 1,
 )
 
-# TODO(b/153652701): Generate this target to give CreateOpResolver a custom namespace.
-cc_library(
-    name = "selected_ops_jni",
-    srcs = ["selected_ops_jni.cc"],
-    hdrs = ["op_resolver.h"],
-    copts = tflite_copts(),
-    deps = [
-        "//tensorflow/lite:framework",
-    ],
-)
-
 exports_files(
     [
         "exported_symbols.lds",
         "version_script.lds",
-        "op_resolver.h",
     ],
 )
diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD
index 67fc33b3cfb..e0ab2724c5f 100644
--- a/tensorflow/lite/kernels/BUILD
+++ b/tensorflow/lite/kernels/BUILD
@@ -1141,14 +1141,17 @@ cc_test(
     srcs = ["numeric_verify_test.cc"],
     tags = ["tflite_nnapi"],
     deps = [
+        ":kernel_util",
         ":test_main",
         ":test_util",
         "//tensorflow/lite:framework",
+        "//tensorflow/lite/kernels/internal:reference",
         "//tensorflow/lite/kernels/internal:types",
         "//tensorflow/lite/schema:schema_fbs",
         "//third_party/eigen3",
         "@com_google_absl//absl/memory",
         "@com_google_googletest//:gtest",
+        "@flatbuffers",
     ],
 )
 
diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
index cbe62516a52..41cc2ee63a9 100644
--- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
@@ -1275,8 +1275,6 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
     gemm_input_data = im2col_data;
     gemm_input_shape = &im2col_shape;
   } else {
-    // TODO(aselle): We need to make sure to not send im2col if it is not
-    // needed.
     TFLITE_DCHECK(!im2col_data);
     gemm_input_data = input_data;
     gemm_input_shape = &input_shape;
@@ -7830,7 +7828,7 @@ inline void Transpose2D(const RuntimeShape& input_shape,
   }
 }
 
-// TODO(alanchiao): see if we can reduce the number
+// TODO(b/173718660): see if we can reduce the number
 // of lines of code in branching without affecting latency.
 template <typename T>
 inline void Transpose3D(const TransposeParams& params,
diff --git a/tensorflow/lite/kernels/numeric_verify.cc b/tensorflow/lite/kernels/numeric_verify.cc
index 5b4011f6649..ce1e491b1d0 100644
--- a/tensorflow/lite/kernels/numeric_verify.cc
+++ b/tensorflow/lite/kernels/numeric_verify.cc
@@ -21,6 +21,7 @@ limitations under the License.
 #include 
 #include 
 
+#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/dequantize.h"
 #include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
@@ -36,13 +37,19 @@ namespace ops {
 namespace custom {
 namespace numeric_verify {
 
+static constexpr const char kToleranceStr[] = "tolerance";
+static constexpr const char kDebugModeStr[] = "debug_mode";
+static constexpr const int kTemporaryDequantizedTensor = 0;
+
 struct OpContext {
   OpContext(TfLiteContext* context, TfLiteNode* node) {
     input = GetInput(context, node, 0);
     ref = GetInput(context, node, 1);
+    output = GetOutput(context, node, 0);
   }
   const TfLiteTensor* input;
   const TfLiteTensor* ref;
+  TfLiteTensor* output;
 };
 
 const int kTensorNotAllocated = -1;
@@ -50,21 +57,23 @@ const int kTensorNotAllocated = -1;
 struct OpData {
   // The percentage of the tensor value range. Must be a number less than 1.0.
   float tolerance;
-  // The abstract value allowed for the floating-point value difference.
-  float max_diff;
   // This boolean value is only used when the input tensor is constant.
   bool float_input_initialized;
   int cache_tensor_id = kTensorNotAllocated;
+  // This boolean value controls the behavior of the numeric verify op.
+  bool debug_mode;
 };
 
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   auto* op_data = new OpData();
   op_data->float_input_initialized = false;
 
-  // Get the tolerance parameter from the buffer. Use flexbuffers asMap if there
-  // multiple custom options.
-  const float* buffer_t = reinterpret_cast(buffer);
-  op_data->tolerance = *buffer_t;
+  const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
+  const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
+  const float tolerance = m[kToleranceStr].AsFloat();
+  const bool debug_mode = m[kDebugModeStr].AsBool();
+  op_data->tolerance = tolerance;
+  op_data->debug_mode = debug_mode;
 
   return op_data;
 }
@@ -75,30 +84,19 @@ void Free(TfLiteContext* context, void* buffer) {
 
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
-  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 0);
   OpData* op_data = reinterpret_cast(node->user_data);
 
   OpContext op_context(context, node);
 
+  const int num_output = (op_data->debug_mode) ? 1 : 0;
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), num_output);
+
   TF_LITE_ENSURE(context, op_context.input->type == kTfLiteUInt8 ||
                               op_context.input->type == kTfLiteInt8 ||
                               op_context.input->type == kTfLiteInt16 ||
                               op_context.input->type == kTfLiteFloat16);
   TF_LITE_ENSURE(context, op_context.ref->type == kTfLiteFloat32);
 
-  op_data->max_diff = op_data->tolerance * op_context.input->params.scale;
-  switch (op_context.input->type) {
-    case kTfLiteUInt8:
-    case kTfLiteInt8:
-      op_data->max_diff *= (1 << 8);
-      break;
-    case kTfLiteInt16:
-      op_data->max_diff *= (1 << 16);
-      break;
-    default:
-      break;
-  }
-
   // Allocate tensor to store the dequantized inputs.
   if (op_data->cache_tensor_id == kTensorNotAllocated) {
     TF_LITE_ENSURE_OK(
@@ -111,7 +109,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 
   TfLiteTensor* dequantized;
   TF_LITE_ENSURE_OK(context,
-                    GetTemporarySafe(context, node, /*index=*/0, &dequantized));
+                    GetTemporarySafe(context, node, kTemporaryDequantizedTensor,
+                                     &dequantized));
   dequantized->type = op_context.ref->type;
   dequantized->allocation_type = kTfLiteDynamic;
 
@@ -119,6 +118,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                                  context, dequantized,
                                  TfLiteIntArrayCopy(op_context.input->dims)));
 
+  if (op_data->debug_mode) {
+    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, num_output - 1,
+                                             &op_context.output));
+    op_context.output->type = kTfLiteFloat32;
+    op_context.output->allocation_type = kTfLiteArenaRwPersistent;
+    return context->ResizeTensor(context, op_context.output,
+                                 TfLiteIntArrayCopy(op_context.input->dims));
+  }
   return kTfLiteOk;
 }
 
@@ -146,7 +153,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   // Dequantize the input
   TfLiteTensor* dequantized;
   TF_LITE_ENSURE_OK(context,
-                    GetTemporarySafe(context, node, /*index=*/0, &dequantized));
+                    GetTemporarySafe(context, node, kTemporaryDequantizedTensor,
+                                     &dequantized));
   auto status = builtin::dequantize::DequantizeImpl(
       context, node, op_context.input, dequantized);
   if (status != kTfLiteOk) {
@@ -157,15 +165,32 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
     op_data->float_input_initialized = true;
   }
 
-  // If the tolerance is very small, we only display the stats of the diff.
-  if (op_data->tolerance < 0.1) {
+  // If debug_mode is on, we don't throw any errors.
+  // We just calculate the difference between the float and quantized values,
+  // letting the Python debugger deal with the information.
+  if (op_data->debug_mode || op_data->tolerance < 0.1) {
+    const int num_output = (op_data->debug_mode) ? 1 : 0;
+    const int n = NumElements(dequantized);
+    if (op_data->debug_mode) {
+      TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, num_output - 1,
+                                               &op_context.output));
+      auto output_data = GetTensorData<float>(op_context.output);
+      for (int i = 0; i < n; ++i) {
+        float dequant = GetTensorData<float>(dequantized)[i];
+        float reference = GetTensorData<float>(op_context.ref)[i];
+        output_data[i] = dequant - reference;
+      }
+    }
+    // This statistics logging was added to help identify errors seen in practice.
     std::vector diffs, temp;
-    diffs.reserve(NumElements(dequantized));
-    temp.reserve(NumElements(dequantized));
-    for (int i = 0; i < NumElements(op_context.ref); ++i) {
+    diffs.reserve(n);
+    temp.reserve(n);
+    diffs.resize(n);
+    temp.resize(n);
+    for (int i = 0; i < n; ++i) {
       float dequant = GetTensorData(dequantized)[i];
       float reference = GetTensorData(op_context.ref)[i];
-      diffs.push_back(dequant - reference);
+      diffs[i] = static_cast(dequant - reference);
     }
     double mean =
         std::accumulate(diffs.begin(), diffs.end(), 0.0) / diffs.size();
@@ -184,24 +209,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
         mean, max_diff, op_context.input->params.scale,
         op_context.input->params.zero_point);
     return kTfLiteOk;
-  }
-
-  // Verify the dequantized output.
-  auto max_diff = op_data->tolerance * op_context.input->params.scale;
-  for (int i = 0; i < NumElements(op_context.ref); ++i) {
-    int32_t value = GetQuantizedValue(op_context, i);
-    float dequant = GetTensorData(dequantized)[i];
-    float reference = GetTensorData(op_context.ref)[i];
-    float diff = std::abs(reference - dequant);
-    if (diff > max_diff) {
-      TF_LITE_KERNEL_LOG(
-          context,
-          "Mismatch: %f is quantized to %d with (%f, %d). "
-          "abs(%f - %f) = %f > %f (tolerance) range percentage %f.\n",
-          reference, value, op_context.input->params.scale,
-          op_context.input->params.zero_point, reference, dequant, diff,
-          max_diff, op_data->tolerance);
-      return kTfLiteError;
+  } else {
+    // Verify the dequantized output.
+    auto max_diff = op_data->tolerance * op_context.input->params.scale;
+    for (int i = 0; i < NumElements(op_context.ref); ++i) {
+      int32_t value = GetQuantizedValue(op_context, i);
+      float dequant = GetTensorData<float>(dequantized)[i];
+      float reference = GetTensorData<float>(op_context.ref)[i];
+      float diff = std::abs(reference - dequant);
+      if (diff > max_diff) {
+        TF_LITE_KERNEL_LOG(
+            context,
+            "Mismatch: %f is quantized to %d with (%f, %d). "
+            "abs(%f - %f) = %f > %f (tolerance) range percentage %f.\n",
+            reference, value, op_context.input->params.scale,
+            op_context.input->params.zero_point, reference, dequant, diff,
+            max_diff, op_data->tolerance);
+        return kTfLiteError;
+      }
     }
   }
   return kTfLiteOk;
diff --git a/tensorflow/lite/kernels/numeric_verify_test.cc b/tensorflow/lite/kernels/numeric_verify_test.cc
index 9fb2e559c37..e26f5607bb7 100644
--- a/tensorflow/lite/kernels/numeric_verify_test.cc
+++ b/tensorflow/lite/kernels/numeric_verify_test.cc
@@ -21,8 +21,11 @@ limitations under the License.
 #include 
 #include "absl/memory/memory.h"
 #include "third_party/eigen3/Eigen/Core"
+#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
 #include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/internal/types.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/kernels/test_util.h"
 #include "tensorflow/lite/schema/schema_generated.h"
 
@@ -42,15 +45,25 @@ class NumericVerifyOpModel : public SingleOpModel {
  public:
   NumericVerifyOpModel(TensorType type, std::initializer_list shape,
                        float scale, int32_t zero_point, int version,
-                       float tolerance = 5.0) {
+                       float tolerance = 5.0, bool debug_mode = false) {
     const TensorData input_tensor_data = {type, shape, 0, 0, scale, zero_point};
     input_ = AddInput(input_tensor_data);
     ref_ = AddInput({TensorType_FLOAT32, shape});
+    if (debug_mode) {
+      // The output tensor has the same shape as the input tensor.
+      output_ = AddOutput({TensorType_FLOAT32, shape});
+    }
 
     std::vector custom_options(sizeof(float));
-    memcpy(custom_options.data(), &tolerance, sizeof(float));
 
-    SetCustomOp("NUMERIC_VERIFY", custom_options,
+    flexbuffers::Builder fbb;
+    fbb.Map([&]() {
+      fbb.Float("tolerance", tolerance);
+      fbb.Bool("debug_mode", debug_mode);
+    });
+    fbb.Finish();
+
+    SetCustomOp("NUMERIC_VERIFY", fbb.GetBuffer(),
                 ops::custom::Register_NUMERIC_VERIFY);
 
     BuildInterpreter({GetShape(input_), GetShape(ref_)});
@@ -63,9 +76,12 @@ class NumericVerifyOpModel : public SingleOpModel {
     PopulateTensor(ref_, ref_data);
   }
 
+  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
  private:
   int input_;
   int ref_;
+  int output_;
 };
 
 TEST(NumericVerifyOpTest, Uint8) {
@@ -117,5 +133,18 @@ TEST(NumericVerifyOpFailedTest, Int8) {
   EXPECT_EQ(m.InvokeUnchecked(), kTfLiteError);
 }
 
+TEST(NumericVerifyOpDebugModeTest, Int8) {
+  // [-63.5, 64] -> scale=0.5, zero_point=-1 for INT8
+  NumericVerifyOpModel m(TensorType_INT8, {2, 5}, 0.5, -1, 2, 5.0, true);
+
+  // The 5th element is set to 0.
+  m.SetInputs({-128, -127, -126, -125, -124, 0, 124, 125, 126, 127},
+                      {-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64});
+  EXPECT_EQ(m.InvokeUnchecked(), kTfLiteOk);
+  // The 5th element has discrepancy -61.5 (= dequantized - reference = 0.5 - 62).
+  EXPECT_THAT(
+      m.GetOutput(),
+      ElementsAreArray(ArrayFloatNear({0, 0, 0, 0, 0, -61.5, 0, 0, 0, 0})));
+}
 }  // namespace
 }  // namespace tflite
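
For reference, the per-element difference and the summary statistics produced by the debug-mode kernel above can be reproduced with plain NumPy. This is only an illustrative sketch using the values from `NumericVerifyOpDebugModeTest.Int8`, not part of the change itself:

```python
import numpy as np

# Test values above: scale=0.5, zero_point=-1 for INT8.
scale, zero_point = 0.5, -1
quantized = np.array([-128, -127, -126, -125, -124, 0, 124, 125, 126, 127], np.int32)
reference = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64], np.float32)

dequantized = (quantized - zero_point) * scale
diffs = dequantized - reference            # the debug_mode output tensor
print(diffs)                               # only element 5 differs: 0.5 - 62 = -61.5
print(diffs.mean(), np.abs(diffs).max())   # the statistics the kernel logs
```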
diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc
index 511c7cbfb8e..9a6c28fd633 100644
--- a/tensorflow/lite/kernels/register.cc
+++ b/tensorflow/lite/kernels/register.cc
@@ -130,7 +130,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_SPACE_TO_DEPTH, Register_SPACE_TO_DEPTH(),
              /* min_version = */ 1,
              /* max_version = */ 2);
-  AddBuiltin(BuiltinOperator_DEPTH_TO_SPACE, Register_DEPTH_TO_SPACE());
+  AddBuiltin(BuiltinOperator_DEPTH_TO_SPACE, Register_DEPTH_TO_SPACE(),
+             /* min_version = */ 1,
+             /* max_version = */ 2);
   AddBuiltin(BuiltinOperator_GATHER, Register_GATHER(),
              /* min_version = */ 1,
              /* max_version = */ 4);
diff --git a/tensorflow/lite/micro/kernels/quantize_test.cc b/tensorflow/lite/micro/kernels/quantize_test.cc
index 5803376fb29..f8ac5a43951 100644
--- a/tensorflow/lite/micro/kernels/quantize_test.cc
+++ b/tensorflow/lite/micro/kernels/quantize_test.cc
@@ -309,4 +309,23 @@ TF_LITE_MICRO_TEST(QuantizeOpTestInt16toInt32) {
                                   output_zero_point, output_quantized);
 }
 
+TF_LITE_MICRO_TEST(QuantizeOpTestInt16toInt8) {
+  constexpr int length = 10;
+  const int dims[] = {2, 2, 5};
+  const float values[] = {-32, -31, -30, -29, -28, 27, 28, 29, 30, 31};
+  // TODO(b/155682734): Input scale must be smaller than output scale for
+  // xtensa.
+  const float input_scale = 0.4f;
+  const int input_zero_point = 0;
+  const float output_scale = 1.0f;
+  const int output_zero_point = 0;
+  int8_t output_quantized[length];
+  int8_t values_quantized[length];
+  int16_t input_quantized[length];
+  tflite::testing::TestRequantize(dims, values, input_quantized, input_scale,
+                                  input_zero_point, dims, values,
+                                  values_quantized, output_scale,
+                                  output_zero_point, output_quantized);
+}
+
 TF_LITE_MICRO_TESTS_END
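
As a rough illustration of what the new int16-to-int8 test exercises: values quantized at `input_scale` are requantized onto `output_scale`. A NumPy sketch of the reference math follows (the kernel itself uses fixed-point multipliers, so treat this as an approximation):

```python
import numpy as np

values = np.array([-32, -31, -30, -29, -28, 27, 28, 29, 30, 31], np.float32)
input_scale, input_zero_point = 0.4, 0
output_scale, output_zero_point = 1.0, 0

# Quantize the float values to int16, then requantize the result to int8.
q_in = (np.round(values / input_scale) + input_zero_point).astype(np.int16)
q_out = (np.round(q_in * input_scale / output_scale) + output_zero_point).astype(np.int8)
print(q_in[:3])   # [-80 -78 -75]: e.g. -32 / 0.4 = -80
print(q_out[:3])  # [-32 -31 -30]: back to the original values at scale 1.0
```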
diff --git a/tensorflow/lite/micro/tools/ci_build/test_all.sh b/tensorflow/lite/micro/tools/ci_build/test_all.sh
index c7a53cc8f09..fd6b384eb9b 100755
--- a/tensorflow/lite/micro/tools/ci_build/test_all.sh
+++ b/tensorflow/lite/micro/tools/ci_build/test_all.sh
@@ -24,8 +24,15 @@ ROOT_DIR=${SCRIPT_DIR}/../../../../..
 cd "${ROOT_DIR}"
 pwd
 
-make -f tensorflow/lite/micro/tools/make/Makefile \
-  clean clean_downloads
+make -f tensorflow/lite/micro/tools/make/Makefile clean_downloads DISABLE_DOWNLOADS=true
+
+
+make -f tensorflow/lite/micro/tools/make/Makefile TAGS=cmsis-nn clean DISABLE_DOWNLOADS=true
+if [ -d tensorflow/lite/micro/tools/make/downloads ]; then
+  echo "ERROR: Downloads directory should not exist, but it does."
+  exit 1
+fi
+
 
 # We are moving away from having the downloads and installations be part of the
 # Makefile. As a result, we need to manually add the downloads in this script.
diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile
index 3c7d472f0bd..d2b9adb9a8f 100644
--- a/tensorflow/lite/micro/tools/make/Makefile
+++ b/tensorflow/lite/micro/tools/make/Makefile
@@ -479,24 +479,33 @@ ALL_PROJECT_TARGETS :=
 ARDUINO_LIBRARY_TARGETS :=
 ARDUINO_LIBRARY_ZIPS :=
 
-# The download scripts require that the downloads directory already exist for
-# improved error checking. To accomodate that, we first create a downloads
-# directory.
-$(shell mkdir -p ${MAKEFILE_DIR}/downloads)
+# For some invocations of the makefile, it is useful to avoid downloads. This
+# can be achieved by explicitly passing in DISABLE_DOWNLOADS=true on the command
+# line. Note that for target-specific downloads (e.g. CMSIS) there will need to
+# be corresponding checking in the respecitve included makefiles (e.g.
+# ext_libs/cmsis_nn.inc)
+DISABLE_DOWNLOADS :=
 
-# Directly download the flatbuffers library.
-DOWNLOAD_RESULT := $(shell $(MAKEFILE_DIR)/flatbuffers_download.sh ${MAKEFILE_DIR}/downloads)
-ifneq ($(DOWNLOAD_RESULT), SUCCESS)
-  $(error Something went wrong with the flatbuffers download: $(DOWNLOAD_RESULT))
+ifneq ($(DISABLE_DOWNLOADS), true)
+  # The download scripts require that the downloads directory already exist for
+  # improved error checking. To accommodate that, we first create a downloads
+  # directory.
+  $(shell mkdir -p ${MAKEFILE_DIR}/downloads)
+
+  # Directly download the flatbuffers library.
+  DOWNLOAD_RESULT := $(shell $(MAKEFILE_DIR)/flatbuffers_download.sh ${MAKEFILE_DIR}/downloads)
+  ifneq ($(DOWNLOAD_RESULT), SUCCESS)
+    $(error Something went wrong with the flatbuffers download: $(DOWNLOAD_RESULT))
+  endif
+
+  include $(MAKEFILE_DIR)/third_party_downloads.inc
+  THIRD_PARTY_DOWNLOADS :=
+  $(eval $(call add_third_party_download,$(GEMMLOWP_URL),$(GEMMLOWP_MD5),gemmlowp,))
+  $(eval $(call add_third_party_download,$(RUY_URL),$(RUY_MD5),ruy,))
+  $(eval $(call add_third_party_download,$(PERSON_MODEL_URL),$(PERSON_MODEL_MD5),person_model_grayscale,))
+  $(eval $(call add_third_party_download,$(PERSON_MODEL_INT8_URL),$(PERSON_MODEL_INT8_MD5),person_model_int8,))
 endif
 
-include $(MAKEFILE_DIR)/third_party_downloads.inc
-THIRD_PARTY_DOWNLOADS :=
-$(eval $(call add_third_party_download,$(GEMMLOWP_URL),$(GEMMLOWP_MD5),gemmlowp,))
-$(eval $(call add_third_party_download,$(RUY_URL),$(RUY_MD5),ruy,))
-$(eval $(call add_third_party_download,$(PERSON_MODEL_URL),$(PERSON_MODEL_MD5),person_model_grayscale,))
-$(eval $(call add_third_party_download,$(PERSON_MODEL_INT8_URL),$(PERSON_MODEL_INT8_MD5),person_model_int8,))
-
 # The target-specific makefile must have a name that is exactly
 # TARGET_makefile.inc and is only needed for cross-compilation (i.e. when TARGET
 # is different from the HOST_OS).
diff --git a/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn.inc b/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn.inc
index a778d0fbdd8..f03bc8f5c4b 100644
--- a/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn.inc
+++ b/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn.inc
@@ -4,14 +4,16 @@ ifneq ($(filter cmsis-nn,$(ALL_TAGS)),)
         # CMSIS-NN optimizations not supported
     endif
 
-    # Setup CMSIS-NN lib and add required header files to microlite lib INCLUDE.
-    # Unless an external path is provided we force a download during the first phase of make so
-    # that the files exist prior to the call to recursive_find below. add_third_party_download
-    # prevents the use of wildcards and recursive_find in selecting which files to add to THIRD_PARTY_SRCS.
-    CMSIS_DEFAULT_DOWNLOAD_PATH := $(MAKEFILE_DIR)/downloads/cmsis
-    CMSIS_PATH := $(CMSIS_DEFAULT_DOWNLOAD_PATH)
-    ifeq ($(CMSIS_PATH), $(CMSIS_DEFAULT_DOWNLOAD_PATH))
-      $(call $(or $(shell $(DOWNLOAD_SCRIPT) $(CMSIS_URL) $(CMSIS_MD5) $(CMSIS_PATH) >&2 && echo SUCCESS), $(error $(DOWNLOAD_SCRIPT) failed)))
+    ifneq ($(DISABLE_DOWNLOADS), true)
+      # Setup CMSIS-NN lib and add required header files to microlite lib INCLUDE.
+      # Unless an external path is provided we force a download during the first phase of make so
+      # that the files exist prior to the call to recursive_find below. add_third_party_download
+      # prevents the use of wildcards and recursive_find in selecting which files to add to THIRD_PARTY_SRCS.
+      CMSIS_DEFAULT_DOWNLOAD_PATH := $(MAKEFILE_DIR)/downloads/cmsis
+      CMSIS_PATH := $(CMSIS_DEFAULT_DOWNLOAD_PATH)
+      ifeq ($(CMSIS_PATH), $(CMSIS_DEFAULT_DOWNLOAD_PATH))
+        $(call $(or $(shell $(DOWNLOAD_SCRIPT) $(CMSIS_URL) $(CMSIS_MD5) $(CMSIS_PATH) >&2 && echo SUCCESS), $(error $(DOWNLOAD_SCRIPT) failed)))
+      endif
     endif
 
     THIRD_PARTY_CC_SRCS += \
diff --git a/tensorflow/lite/micro/tools/make/ext_libs/xtensa.inc b/tensorflow/lite/micro/tools/make/ext_libs/xtensa.inc
index 25b034f4833..00fbf455a0b 100644
--- a/tensorflow/lite/micro/tools/make/ext_libs/xtensa.inc
+++ b/tensorflow/lite/micro/tools/make/ext_libs/xtensa.inc
@@ -1,6 +1,6 @@
 ifeq ($(TARGET_ARCH), hifi4)
 
-  DOWNLOAD_RESULT := $(shell $(MAKEFILE_DIR)/ext_libs/xtensa_download.sh ${MAKEFILE_DIR}/downloads)
+  DOWNLOAD_RESULT := $(shell $(MAKEFILE_DIR)/ext_libs/xtensa_download.sh ${MAKEFILE_DIR}/downloads hifi4)
   ifneq ($(DOWNLOAD_RESULT), SUCCESS)
     $(error Something went wrong with the xtensa download: $(DOWNLOAD_RESULT))
   endif
diff --git a/tensorflow/lite/micro/tools/make/ext_libs/xtensa_download.sh b/tensorflow/lite/micro/tools/make/ext_libs/xtensa_download.sh
index 16303109a8e..a427f638d09 100755
--- a/tensorflow/lite/micro/tools/make/ext_libs/xtensa_download.sh
+++ b/tensorflow/lite/micro/tools/make/ext_libs/xtensa_download.sh
@@ -19,6 +19,7 @@
 # Called with four arguments:
 # 1 - Path to the downloads folder which is typically
 #     tensorflow/lite/micro/tools/make/downloads
+# 2 - Xtensa variant to download for (e.g. hifi4)
 #
 # This script is called from the Makefile and uses the following convention to
 # enable determination of sucess/failure:
@@ -39,28 +40,31 @@ if [ ! -d ${DOWNLOADS_DIR} ]; then
   exit 1
 fi
 
-# Name of the xa_nnlib directory once it is unzipped.
-HIFI4_XA_NNLIB_DIRNAME="xa_nnlib_hifi4"
-
-HIFI4_PATH=${DOWNLOADS_DIR}/${HIFI4_XA_NNLIB_DIRNAME}
-if [ -d ${HIFI4_PATH} ]; then
-  echo >&2 "${HIFI4_PATH} already exists, skipping the download."
+if [[ ${2} == "hifi4" ]]; then
+  LIBRARY_URL="http://mirror.tensorflow.org/github.com/foss-xtensa/nnlib-hifi4/raw/master/archive/xa_nnlib_06_27.zip"
+  LIBRARY_DIRNAME="xa_nnlib_hifi4"
+  LIBRARY_MD5="45fdc1209a8da62ab568aa6040f7eabf"
 else
+  echo "Attempting to download an unsupported xtensa variant: ${2}"
+  exit 1
+fi
 
-  ZIP_ARCHIVE_NAME="xa_nnlib_06_27.zip"
-  HIFI4_URL="http://mirror.tensorflow.org/github.com/foss-xtensa/nnlib-hifi4/raw/master/archive/${ZIP_ARCHIVE_NAME}"
-  HIFI4_MD5="45fdc1209a8da62ab568aa6040f7eabf"
+LIBRARY_INSTALL_PATH=${DOWNLOADS_DIR}/${LIBRARY_DIRNAME}
 
-  wget ${HIFI4_URL} -O /tmp/${ZIP_ARCHIVE_NAME} >&2
-  MD5=`md5sum /tmp/${ZIP_ARCHIVE_NAME} | awk '{print $1}'`
+if [ -d ${LIBRARY_INSTALL_PATH} ]; then
+  echo >&2 "${LIBRARY_INSTALL_PATH} already exists, skipping the download."
+else
+  TMP_ZIP_ARCHIVE_NAME="${LIBRARY_DIRNAME}.zip"
+  wget ${LIBRARY_URL} -O /tmp/${TMP_ZIP_ARCHIVE_NAME} >&2
+  MD5=`md5sum /tmp/${TMP_ZIP_ARCHIVE_NAME} | awk '{print $1}'`
 
-  if [[ ${MD5} != ${HIFI4_MD5} ]]
+  if [[ ${MD5} != ${LIBRARY_MD5} ]]
   then
-    echo "Bad checksum. Expected: ${HIFI4_MD5}, Got: ${MD5}"
+    echo "Bad checksum. Expected: ${LIBRARY_MD5}, Got: ${MD5}"
     exit 1
   fi
 
-  unzip -qo /tmp/${ZIP_ARCHIVE_NAME} -d ${DOWNLOADS_DIR} >&2
+  unzip -qo /tmp/${TMP_ZIP_ARCHIVE_NAME} -d ${DOWNLOADS_DIR} >&2
 fi
 
 echo "SUCCESS"
diff --git a/tensorflow/lite/micro/tools/make/flatbuffers_download.sh b/tensorflow/lite/micro/tools/make/flatbuffers_download.sh
index 61f5f332131..8ac0c4df1cc 100755
--- a/tensorflow/lite/micro/tools/make/flatbuffers_download.sh
+++ b/tensorflow/lite/micro/tools/make/flatbuffers_download.sh
@@ -69,6 +69,15 @@ function patch_to_avoid_strtod() {
   mv ${temp_flexbuffers_path} ${input_flexbuffers_path}
 }
 
+# The BUILD files in the downloaded folder result in an error with:
+#  bazel build tensorflow/lite/micro/...
+#
+# Parameters:
+#   $1 - path to the downloaded flatbuffers code.
+function delete_build_files() {
+  rm -f `find ${1} -name BUILD`
+}
+
 DOWNLOADED_FLATBUFFERS_PATH=${DOWNLOADS_DIR}/flatbuffers
 
 if [ -d ${DOWNLOADED_FLATBUFFERS_PATH} ]; then
@@ -91,6 +100,8 @@ else
   mv /tmp/flatbuffers-${ZIP_PREFIX} ${DOWNLOADED_FLATBUFFERS_PATH}
 
   patch_to_avoid_strtod ${DOWNLOADED_FLATBUFFERS_PATH}/include/flatbuffers/flexbuffers.h
+  delete_build_files ${DOWNLOADED_FLATBUFFERS_PATH}
+
 fi
 
 echo "SUCCESS"
diff --git a/tensorflow/lite/shared_library.h b/tensorflow/lite/shared_library.h
index a7bd91b3a0a..90b3dba3b70 100644
--- a/tensorflow/lite/shared_library.h
+++ b/tensorflow/lite/shared_library.h
@@ -36,6 +36,8 @@ class SharedLibrary {
     return reinterpret_cast(
         GetProcAddress(static_cast(handle), symbol));
   }
+  // Warning: Unlike dlsym(RTLD_DEFAULT), it doesn't search for the symbol in
+  // dependent DLLs.
   static inline void* GetSymbol(const char* symbol) {
     return reinterpret_cast(GetProcAddress(nullptr, symbol));
   }
diff --git a/tensorflow/lite/testing/op_tests/depth_to_space.py b/tensorflow/lite/testing/op_tests/depth_to_space.py
index 9693a664c54..c4647e1110c 100644
--- a/tensorflow/lite/testing/op_tests/depth_to_space.py
+++ b/tensorflow/lite/testing/op_tests/depth_to_space.py
@@ -28,9 +28,15 @@ def make_depth_to_space_tests(options):
   """Make a set of tests to do depth_to_space."""
 
   test_parameters = [{
-      "dtype": [tf.float32, tf.int32, tf.uint8, tf.int64],
+      "dtype": [tf.int32, tf.uint8, tf.int64],
       "input_shape": [[2, 3, 4, 16]],
       "block_size": [2, 4],
+      "fully_quantize": [False],
+  }, {
+      "dtype": [tf.float32],
+      "input_shape": [[2, 3, 4, 16]],
+      "block_size": [2, 4],
+      "fully_quantize": [True, False],
   }]
 
   def build_graph(parameters):
@@ -43,8 +49,15 @@ def make_depth_to_space_tests(options):
     return [input_tensor], [out]
 
   def build_inputs(parameters, sess, inputs, outputs):
-    input_values = create_tensor_data(parameters["dtype"],
-                                      parameters["input_shape"])
+    if not parameters["fully_quantize"]:
+      input_values = create_tensor_data(parameters["dtype"],
+                                        parameters["input_shape"])
+    else:
+      input_values = create_tensor_data(
+          parameters["dtype"],
+          parameters["input_shape"],
+          min_value=-1,
+          max_value=1)
     return [input_values], sess.run(
         outputs, feed_dict=dict(zip(inputs, [input_values])))
 
diff --git a/tensorflow/lite/testing/selective_build_test.cc b/tensorflow/lite/testing/selective_build_test.cc
index c3a0cf20ecc..ed0cc8d6701 100644
--- a/tensorflow/lite/testing/selective_build_test.cc
+++ b/tensorflow/lite/testing/selective_build_test.cc
@@ -18,8 +18,8 @@ limitations under the License.
 #include 
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/create_op_resolver.h"
 #include "tensorflow/lite/interpreter.h"
-#include "tensorflow/lite/java/src/main/native/op_resolver.h"
 #include "tensorflow/lite/model.h"
 #include "tensorflow/lite/model_builder.h"
 
diff --git a/tensorflow/lite/tools/make/Makefile b/tensorflow/lite/tools/make/Makefile
index 3aac9030f0a..99fe540d22f 100644
--- a/tensorflow/lite/tools/make/Makefile
+++ b/tensorflow/lite/tools/make/Makefile
@@ -190,6 +190,7 @@ $(wildcard tensorflow/lite/*/*/*/*/*/*tool.cc) \
 $(wildcard tensorflow/lite/kernels/*test_main.cc) \
 $(wildcard tensorflow/lite/kernels/*test_util*.cc) \
 $(wildcard tensorflow/lite/tools/make/downloads/cpuinfo/src/*/mock*.c) \
+tensorflow/lite/create_op_resolver_with_selected_ops.cc \
 tensorflow/lite/tflite_with_xnnpack.cc \
 $(MINIMAL_SRCS)
 
diff --git a/tensorflow/lite/tools/optimize/operator_property.cc b/tensorflow/lite/tools/optimize/operator_property.cc
index 45dff78ef92..d5ef9451659 100644
--- a/tensorflow/lite/tools/optimize/operator_property.cc
+++ b/tensorflow/lite/tools/optimize/operator_property.cc
@@ -106,6 +106,13 @@ OperatorProperty GetOperatorProperty(OpVariant op_variant) {
       property.version = 2;
       property.quantizable_int16 = false;
       break;
+    case BuiltinOperator_DEPTH_TO_SPACE:
+      property.inputs = {{0, {}}};
+      property.outputs = {{0, {}}};
+      property.restrict_same_input_output_scale = true;
+      property.version = 2;
+      property.quantizable_int16 = false;
+      break;
     case BuiltinOperator_SPLIT:
       // We skip input 0 since it is the split dim which is not real valued.
       property.inputs = {{1, {}}};
diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc
index e657a66c328..c7c08c81c66 100644
--- a/tensorflow/lite/tools/versioning/op_version.cc
+++ b/tensorflow/lite/tools/versioning/op_version.cc
@@ -620,10 +620,7 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
     case BuiltinOperator_SELECT:
     case BuiltinOperator_RSQRT:
     case BuiltinOperator_SQUARED_DIFFERENCE:
-      if (op_sig.input_types.at(0) == TensorType_INT8) {
-        return 2;
-      }
-      return 1;
+    case BuiltinOperator_DEPTH_TO_SPACE:
     case BuiltinOperator_MIRROR_PAD:
       if (op_sig.input_types.at(0) == TensorType_INT8) {
         return 2;
diff --git a/tensorflow/lite/tools/versioning/runtime_version.cc b/tensorflow/lite/tools/versioning/runtime_version.cc
index bda02ec6473..dc3a5b93366 100644
--- a/tensorflow/lite/tools/versioning/runtime_version.cc
+++ b/tensorflow/lite/tools/versioning/runtime_version.cc
@@ -100,6 +100,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
               {{BuiltinOperator_CONCATENATION, 2}, "1.14.0"},
               {{BuiltinOperator_CONCATENATION, 3}, "2.3.0"},
               {{BuiltinOperator_DEPTH_TO_SPACE, 1}, "2.1.0"},
+              {{BuiltinOperator_DEPTH_TO_SPACE, 2}, kPendingReleaseVersion},
               {{BuiltinOperator_EMBEDDING_LOOKUP, 1}, "1.13.0"},
               {{BuiltinOperator_EMBEDDING_LOOKUP, 2}, "1.14.0"},
               {{BuiltinOperator_EMBEDDING_LOOKUP, 3}, "1.14.0"},
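
Taken together, the register.cc, operator_property.cc, op_version.cc and runtime_version.cc hunks above make DEPTH_TO_SPACE quantizable as op version 2. A hypothetical Python sketch of a model that would exercise this path (standard converter APIs; whether this graph actually ends up fully int8 also depends on the TFLite version carrying these changes):

```python
import numpy as np
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([1, 3, 4, 16], tf.float32)])
def model(x):
  # DEPTH_TO_SPACE with block_size=2, mirroring the op_tests parameters above.
  return tf.nn.depth_to_space(x, block_size=2)

def representative_dataset():
  for _ in range(100):
    yield [np.random.rand(1, 3, 4, 16).astype(np.float32)]

converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [model.get_concrete_function()])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
tflite_model = converter.convert()
```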
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 9985f274ba8..d7ddce89fd1 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -387,16 +387,19 @@ cc_library(
     ],
 )
 
+# bfloat16_lib is shared with JAX, and must not depend on any other parts of
+# TensorFlow.
+# TODO(phawkins): move bfloat16 into its own pip package.
 cc_library(
     name = "bfloat16_lib",
     srcs = ["lib/core/bfloat16.cc"],
     hdrs = ["lib/core/bfloat16.h"],
     deps = [
         ":numpy_lib",
-        ":safe_ptr",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
+        "//tensorflow/core/platform:logging",
+        "//third_party/eigen3",
         "//third_party/python_runtime:headers",
+        "@com_google_absl//absl/strings",
     ],
 )
 
@@ -5597,25 +5600,25 @@ tf_python_pybind_extension(
     hdrs = [
         "//tensorflow/c:headers",
         "//tensorflow/c/eager:headers",
+        # Using header directly is required to avoid ODR violations.
+        "util/stack_trace.h",
     ],
     # TODO(b/138203821): change to "util._tf_stack" once the bug is fixed.
     module_name = "_tf_stack",
     deps = [
-        ":stack_trace",
-        "//tensorflow/c:pywrap_required_hdrs",
-        "//tensorflow/core/common_runtime:core_cpu_headers_lib",
-        "//tensorflow/core/framework:pywrap_required_hdrs",
-        "//tensorflow/core/platform:path",
-        "//third_party/python_runtime:headers",  # buildcleaner: keep
         "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/hash",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/types:span",
         "@pybind11",
-    ],
+        "//third_party/python_runtime:headers",  # buildcleaner: keep
+        "//tensorflow/c:pywrap_required_hdrs",
+        "//tensorflow/core/common_runtime:core_cpu_headers_lib",
+        "//tensorflow/core/framework:pywrap_required_hdrs",
+        "//tensorflow/core/platform:path",
+    ] + if_static([
+        ":stack_trace",
+    ]),
 )
 
 tf_py_test(
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index 22b4884dd71..6efba380ca0 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -36,9 +36,9 @@ import traceback
 
 # go/tf-wildcard-import
 # pylint: disable=wildcard-import,g-bad-import-order,g-import-not-at-top
-from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
 
 from tensorflow.python.eager import context
+from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
 
 # pylint: enable=wildcard-import
 
diff --git a/tensorflow/python/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py
index 25856834814..2b1c0a33530 100644
--- a/tensorflow/python/autograph/impl/conversion.py
+++ b/tensorflow/python/autograph/impl/conversion.py
@@ -153,9 +153,6 @@ def is_allowlisted(
   # The check for __code__ below is because isgeneratorfunction crashes
   # without one.
   if hasattr(o, '__code__') and tf_inspect.isgeneratorfunction(o):
-    logging.warn(
-        'Entity %s appears to be a generator function. It will not be converted'
-        ' by AutoGraph.', o)
     logging.log(2, 'Allowlisted: %s: generator functions are not converted', o)
     return True
 
diff --git a/tensorflow/python/client/tf_session_wrapper.cc b/tensorflow/python/client/tf_session_wrapper.cc
index d8399a41f1c..306381347c7 100644
--- a/tensorflow/python/client/tf_session_wrapper.cc
+++ b/tensorflow/python/client/tf_session_wrapper.cc
@@ -711,6 +711,18 @@ PYBIND11_MODULE(_pywrap_tf_session, m) {
       },
       py::return_value_policy::reference);
 
+  m.def(
+      "TF_LoadPluggableDeviceLibrary",
+      [](const char* library_filename) {
+        tensorflow::Safe_TF_StatusPtr status =
+            tensorflow::make_safe(TF_NewStatus());
+        auto output =
+            TF_LoadPluggableDeviceLibrary(library_filename, status.get());
+        tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
+        return output;
+      },
+      py::return_value_policy::reference);
+
   m.def("TF_GetOpList", [](TF_Library* lib_handle) {
     TF_Buffer output_buffer = TF_GetOpList(lib_handle);
     return tensorflow::PyoOrThrow(PyBytes_FromStringAndSize(
@@ -720,6 +732,11 @@ PYBIND11_MODULE(_pywrap_tf_session, m) {
 
   m.def("TF_DeleteLibraryHandle", TF_DeleteLibraryHandle,
         py::call_guard());
+
+  m.def("TF_PluggableDeviceLibraryHandle",
+        TF_DeletePluggableDeviceLibraryHandle,
+        py::call_guard<py::gil_scoped_release>());
+
   m.def("TF_AddControlInput", TF_AddControlInput);
   m.def(
       "TF_AddInputList", [](TF_OperationDescription* desc, py::handle& inputs) {
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index bc0ae540b7b..fb3a6560d1b 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -33,7 +33,7 @@ from tensorflow.python.util.tf_export import tf_export
 # This value changes every day with an automatic CL. It can be modified in code
 # via `forward_compatibility_horizon()` or with the environment variable
 # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 12, 7)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 12, 8)
 _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
 _FORWARD_COMPATIBILITY_DATE_NUMBER = None
 
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 046a09f9638..5e004ad8583 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -1245,12 +1245,17 @@ class Context(object):
   def invoking_op_callbacks(self, value):
     self._thread_local_data.invoking_op_callbacks = value
 
-  def _initialize_physical_devices(self):
-    """Get local devices visible to the system."""
+  def _initialize_physical_devices(self, reinitialize=False):
+    """Gets local devices visible to the system.
+
+    Args:
+      reinitialize: If True, reinitializes self._physical_devices so that
+        dynamically registered devices will also be visible to the Python
+        front-end.
+    """
     # We lazy initialize self._physical_devices since we do not want to do this
     # the constructor since the backend may not be initialized yet.
     with self._device_lock:
-      if self._physical_devices is not None:
+      if not reinitialize and self._physical_devices is not None:
         return
 
       devs = pywrap_tfe.TF_ListPhysicalDevices()
@@ -1269,6 +1274,12 @@ class Context(object):
     # Import device settings that may have been passed into the constructor
     self._import_config()
 
+  def reinitialize_physical_devices(self):
+    """Gets local devices visible to the system."""
+    # Reinitialize the physical device list after registering
+    # the pluggable device.
+    self._initialize_physical_devices(True)
+
   def list_physical_devices(self, device_type=None):
     """List local devices visible to the system.
 
diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py
index ec33d318e6b..44336a760a4 100644
--- a/tensorflow/python/eager/def_function.py
+++ b/tensorflow/python/eager/def_function.py
@@ -51,84 +51,94 @@ from tensorflow.python.util.tf_export import tf_export
 
 FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY = 10
 FREQUENT_TRACING_WARNING_THRESHOLD = 5
+FREQUENT_TRACING_WARNING_MAX_WARNING_PER_DETECTOR = 2
 
 
-class _CallCounter(object):
+class _FrequentTracingDetector(object):
   """Class keeping track of how many recent calls triggered tracing."""
 
-  __slots__ = ["_max_call_history", "_calls_per_tracings", "call_count"]
+  __slots__ = ["_calls_per_tracings", "_call_count", "_total_warning_count"]
 
-  def __init__(self, max_call_history):
-    self._max_call_history = max_call_history
+  def __init__(self):
     self._calls_per_tracings = []
-    self.call_count = 0
+    self._total_warning_count = 0
+    self._call_count = 0
 
-  def called_with_tracing(self):
-    self.call_count += 1
+  def called_with_tracing(self, function_name, omit_warning):
+    """Updates the list of most recent calls' tracing information.
+
+    Warns the user when recent calls caused retracing too often.
+
+    Args:
+      function_name: The Python function being traced.
+      omit_warning: If `True`, this call will not warn the user even if
+        retracing happens too often.
+    """
+    self._call_count += 1
     self._calls_per_tracings.append(1)
 
     while self._calls_per_tracings:
-      if self.call_count - self._calls_per_tracings[0] > self._max_call_history:
-        self.call_count -= self._calls_per_tracings.pop(0)
+      if (self._call_count - self._calls_per_tracings[0] >
+          FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY):
+        self._call_count -= self._calls_per_tracings.pop(0)
       else:
         break
 
+    if (omit_warning or self._total_warning_count >=
+        FREQUENT_TRACING_WARNING_MAX_WARNING_PER_DETECTOR):
+      return
+    if len(self._calls_per_tracings) >= FREQUENT_TRACING_WARNING_THRESHOLD:
+      self._total_warning_count += 1
+      logging.warning(
+          "{} out of the last {} calls to {} triggered tf.function "
+          "retracing. Tracing is expensive and the excessive number of "
+          "tracings could be due to (1) creating @tf.function repeatedly in "
+          "a loop, (2) passing tensors with different shapes, (3) passing "
+          "Python objects instead of tensors. For (1), please define your "
+          "@tf.function outside of the loop. For (2), @tf.function has "
+          "experimental_relax_shapes=True option that relaxes argument "
+          "shapes that can avoid unnecessary retracing. For (3), please "
+          "refer to "
+          "https://www.tensorflow.org/guide/function#controlling_retracing"
+          " and https://www.tensorflow.org/api_docs/python/tf/function for "
+          " more details.".format(
+              len(self._calls_per_tracings), self._call_count, function_name))
+
   def called_without_tracing(self):
     # We don't count tracing when users load a concrete function directly or
     # call get_concrete_function, so the first call can be not a tracing call.
     if not self._calls_per_tracings:
       self._calls_per_tracings = [0]
     self._calls_per_tracings[-1] += 1
-    self.call_count += 1
-
-  def get_tracing_count(self):
-    return len(self._calls_per_tracings)
+    self._call_count += 1
 
 
-class _FrequentTracingDetector(object):
-  """Class for frequent retracing detection and warning."""
+class _FrequentTracingDetectorManager(object):
+  """Class for the management of all _FrequentTracingDetector objects."""
 
-  __slots__ = ["_counters", "_lock"]
+  __slots__ = ["_detectors", "_lock"]
 
   def __init__(self):
-    self._counters = weakref.WeakKeyDictionary()  # GUARDED_BY(self._lock)
+    self._detectors = weakref.WeakKeyDictionary()  # GUARDED_BY(self._lock)
     self._lock = threading.Lock()
 
-  def _get_counter(self, key):
-    if key not in self._counters:
-      self._counters[key] = _CallCounter(
-          FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY)
-    return self._counters[key]
+  def _get_detector(self, key):
+    if key not in self._detectors:
+      self._detectors[key] = _FrequentTracingDetector()
+    return self._detectors[key]
 
   def called_without_tracing(self, key):
     with self._lock:
-      counter = self._get_counter(key)
-      counter.called_without_tracing()
+      detector = self._get_detector(key)
+      detector.called_without_tracing()
 
   def called_with_tracing(self, key, function_name, omit_warning):
     with self._lock:
-      counter = self._get_counter(key)
-      counter.called_with_tracing()
-      if omit_warning:
-        return
-      if counter.get_tracing_count() >= FREQUENT_TRACING_WARNING_THRESHOLD:
-        logging.warning(
-            "{} out of the last {} calls to {} triggered tf.function "
-            "retracing. Tracing is expensive and the excessive number of "
-            "tracings could be due to (1) creating @tf.function repeatedly in "
-            "a loop, (2) passing tensors with different shapes, (3) passing "
-            "Python objects instead of tensors. For (1), please define your "
-            "@tf.function outside of the loop. For (2), @tf.function has "
-            "experimental_relax_shapes=True option that relaxes argument "
-            "shapes that can avoid unnecessary retracing. For (3), please "
-            "refer to "
-            "https://www.tensorflow.org/guide/function#controlling_retracing"
-            " and https://www.tensorflow.org/api_docs/python/tf/function for "
-            " more details.".format(counter.get_tracing_count(),
-                                    counter.call_count, function_name))
+      detector = self._get_detector(key)
+      detector.called_with_tracing(function_name, omit_warning)
 
 
-_frequent_tracing_detector = _FrequentTracingDetector()
+_frequent_tracing_detector_manager = _FrequentTracingDetectorManager()
 
 
 class UnliftedInitializerVariable(resource_variable_ops.UninitializedVariable):
@@ -794,10 +804,10 @@ class Function(object):
 
     if context.executing_eagerly():
       if without_tracing:
-        _frequent_tracing_detector.called_without_tracing(
+        _frequent_tracing_detector_manager.called_without_tracing(
             self._key_for_call_stats)
       else:
-        _frequent_tracing_detector.called_with_tracing(
+        _frequent_tracing_detector_manager.called_with_tracing(
             self._key_for_call_stats, self._python_function,
             self._omit_frequent_tracing_warning)
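Aside: a small Python sketch of the behavior the new detector limits. Each call below with a fresh Python integer forces a retrace; previously a warning could be logged on every call past the threshold, whereas the detector now emits at most FREQUENT_TRACING_WARNING_MAX_WARNING_PER_DETECTOR (2) warnings per function.

    import tensorflow as tf

    @tf.function
    def f(x):
        return x

    # Ten distinct Python ints -> ten traces; the retracing warning is capped
    # at two occurrences for this function instead of repeating indefinitely.
    for i in range(10):
        f(i)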
 
diff --git a/tensorflow/python/eager/def_function_test.py b/tensorflow/python/eager/def_function_test.py
index 03970d8a322..2b1dad4ea70 100644
--- a/tensorflow/python/eager/def_function_test.py
+++ b/tensorflow/python/eager/def_function_test.py
@@ -956,6 +956,18 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
       self.assertLen(logs.output, 1)
       self.assertIn('Tracing is expensive', logs.output[0])
 
+  def test_retracing_warning_limits(self):
+
+    @def_function.function
+    def my_func(x):
+      return x
+
+    with self.assertLogs(level='WARN') as logs:
+      for i in range(10):
+        my_func(i)
+
+      self.assertLen(logs.output, 2)
+
   def test_experimental_get_tracing_count_function(self):
 
     @def_function.function
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 570e4af9574..3ac7025a4f2 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -389,13 +389,9 @@ class _DefinedFunction(object):
     variable_keys.extend(ops.GraphKeys._VARIABLE_COLLECTIONS)  # pylint: disable=protected-access
     variable_keys.append(vs._VARSTORE_KEY)  # pylint: disable=protected-access
 
-    collections_ref = {}
-    parent_collections_ref = ops.get_default_graph()._collections  # pylint: disable=protected-access
-    for key in variable_keys:
-      if key not in parent_collections_ref:
-        parent_collections_ref[key] = collections_ref[key] = []
-      else:
-        collections_ref[key] = parent_collections_ref[key]
+    parent_graph = ops.get_default_graph()
+    collections_ref = {
+        key: parent_graph.get_collection_ref(key) for key in variable_keys}
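Aside: the simplification relies on `Graph.get_collection_ref` returning the live, mutable list for a key (creating an empty one if missing). A rough sketch of that contract, using the v1 compat API:

    import tensorflow.compat.v1 as tf1

    g = tf1.Graph()
    with g.as_default():
        # get_collection_ref hands back the graph's own list object for the
        # key, so later additions (e.g. from get_variable) are visible in it.
        ref = g.get_collection_ref(tf1.GraphKeys.GLOBAL_VARIABLES)
        v = tf1.get_variable("v", shape=[])
        assert ref is g.get_collection_ref(tf1.GraphKeys.GLOBAL_VARIABLES)
        assert any(x is v for x in ref)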
 
     temp_graph = func_graph_from_py_func(
         self._func,
diff --git a/tensorflow/python/framework/function_def_to_graph.py b/tensorflow/python/framework/function_def_to_graph.py
index 69aa38dade3..243d33a84cf 100644
--- a/tensorflow/python/framework/function_def_to_graph.py
+++ b/tensorflow/python/framework/function_def_to_graph.py
@@ -18,16 +18,21 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import itertools
+
+
 from tensorflow.core.framework import function_pb2
 from tensorflow.core.framework import graph_pb2
 from tensorflow.core.framework import tensor_shape_pb2
 from tensorflow.core.framework import types_pb2
 from tensorflow.core.framework import versions_pb2
 from tensorflow.python.eager import context
+from tensorflow.python.framework import cpp_shape_inference_pb2
 from tensorflow.python.framework import importer
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import versions
 from tensorflow.python.framework.func_graph import FuncGraph
+from tensorflow.python.ops import resource_variable_ops
 
 
 def function_def_to_graph(fdef, input_shapes=None):
@@ -84,6 +89,9 @@ def function_def_to_graph(fdef, input_shapes=None):
         func_graph.get_operation_by_name(fdef.control_ret[ret_name])
         for ret_name in fdef.signature.control_output
     ]
+
+    _set_handle_data(func_graph, fdef)
+
     for node in graph_def.node:
       output_shapes = node.attr.get("_output_shapes", None)
       if output_shapes is not None:
@@ -264,3 +272,19 @@ def _get_num_args(arg_def, node_def):
     return 1
   else:
     raise ValueError("Invalid arg_def:\n\n{}".format(str(arg_def)))
+
+
+def _set_handle_data(func_graph, fdef):
+  """Adds handle data for resource type inputs and outputs."""
+  for tensor, arg_def in itertools.chain(
+      zip(func_graph.inputs, fdef.signature.input_arg),
+      zip(func_graph.outputs, fdef.signature.output_arg)):
+    if arg_def.handle_data:
+      shape_and_dtype = arg_def.handle_data[0]
+      handle_data = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData()
+      handle_data.is_set = True
+      handle_data.shape_and_type.append(
+          cpp_shape_inference_pb2.CppShapeInferenceResult.HandleShapeAndType(
+              shape=shape_and_dtype.shape, dtype=shape_and_dtype.dtype))
+      resource_variable_ops._set_handle_shapes_and_types(  # pylint: disable=protected-access
+          tensor, handle_data, True)
diff --git a/tensorflow/python/framework/load_library.py b/tensorflow/python/framework/load_library.py
index d1a0c261a55..73ef3f7d4e6 100644
--- a/tensorflow/python/framework/load_library.py
+++ b/tensorflow/python/framework/load_library.py
@@ -27,6 +27,7 @@ import sys
 
 from tensorflow.python import _pywrap_python_op_gen
 from tensorflow.python.client import pywrap_tf_session as py_tf
+from tensorflow.python.eager import context
 from tensorflow.python.util import deprecation
 from tensorflow.python.util.tf_export import tf_export
 
@@ -159,6 +160,45 @@ def load_library(library_location):
         library_location)
 
 
+def load_pluggable_device_library(library_location):
+  """Loads a TensorFlow PluggableDevice plugin.
+
+  "library_location" can be a path to a specific shared object, or a folder.
+  If it is a folder, all shared objects found in the folder will be loaded.
+  When a library is loaded, devices and kernels registered in it through the
+  StreamExecutor C API and the Kernel/Op Registration C API become available
+  in the TensorFlow process.
+
+  Args:
+    library_location: Path to the plugin or folder of plugins. Relative or
+      absolute filesystem path to a dynamic library file or folder.
+
+  Raises:
+    OSError: When the file to be loaded is not found.
+    RuntimeError: When the library cannot be loaded.
+  """
+  if os.path.exists(library_location):
+    if os.path.isdir(library_location):
+      directory_contents = os.listdir(library_location)
+
+      pluggable_device_libraries = [
+          os.path.join(library_location, f)
+          for f in directory_contents
+          if _is_shared_object(f)
+      ]
+    else:
+      pluggable_device_libraries = [library_location]
+
+    for lib in pluggable_device_libraries:
+      py_tf.TF_LoadPluggableDeviceLibrary(lib)
+    # Reinitialize the physical device list after plugin registration.
+    context.context().reinitialize_physical_devices()
+  else:
+    raise OSError(
+        errno.ENOENT,
+        'The file or folder to load pluggable device libraries from does not '
+        'exist.', library_location)
+
+
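Aside: a rough usage sketch of the loader added above. The plugin path is hypothetical; any shared object built against the StreamExecutor C API and the Kernel/Op Registration C API would do.

    from tensorflow.python.eager import context
    from tensorflow.python.framework import load_library

    # Hypothetical plugin path -- substitute a real PluggableDevice .so file.
    load_library.load_pluggable_device_library("/tmp/libdemo_plugin.so")
    # The physical device list has been refreshed, so the plugin's device type
    # now shows up alongside CPU/GPU.
    print(context.context().list_physical_devices())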
 @tf_export('experimental.register_filesystem_plugin')
 def register_filesystem_plugin(plugin_location):
   """Loads a TensorFlow FileSystem plugin.
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 376122b3c30..50dfcb40bb4 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -50,6 +50,7 @@ from tensorflow.python.eager import monitoring
 from tensorflow.python.eager import tape
 from tensorflow.python.framework import c_api_util
 from tensorflow.python.framework import composite_tensor
+from tensorflow.python.framework import cpp_shape_inference_pb2
 from tensorflow.python.framework import device as pydev
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -3292,18 +3293,18 @@ class Graph(object):
             continue
           # TODO(b/141471245): Fix the inconsistency when inputs of func graph
           # are appended during gradient computation of while/cond.
-          for input_tensor, _ in zip(func_graph_inputs,
-                                     function_def.signature.input_arg):
+          for input_tensor, arg_def in zip(func_graph_inputs,
+                                           function_def.signature.input_arg):
+            input_shapes.list.shape.add().CopyFrom(
+                input_tensor.get_shape().as_proto())
             if input_tensor.dtype == dtypes.resource:
-              # TODO(allenl): Save and restore handle data, then save the
-              # resource placeholder's shape. Right now some shape functions get
-              # confused if we set the shape of the resource placeholder (to a
-              # scalar of course) and there isn't any handle data.
-              input_shapes.list.shape.add().CopyFrom(
-                  tensor_shape.TensorShape(None).as_proto())
-            else:
-              input_shapes.list.shape.add().CopyFrom(
-                  input_tensor.get_shape().as_proto())
+              _copy_handle_data_to_arg_def(input_tensor, arg_def)
+
+          for output_tensor, arg_def in zip(func_graph.outputs,
+                                            function_def.signature.output_arg):
+            if output_tensor.dtype == dtypes.resource:
+              _copy_handle_data_to_arg_def(output_tensor, arg_def)
+
           for node in function_def.node_def:
             try:
               op = func_graph.get_operation_by_name(node.name)
@@ -6979,3 +6980,22 @@ def _get_enclosing_context(graph):
 
   if graph.building_function and hasattr(graph, "outer_graph"):
     return _get_enclosing_context(graph.outer_graph)
+
+
+def get_resource_handle_data(graph_op):
+  assert type(graph_op) == Tensor  # pylint: disable=unidiomatic-typecheck
+
+  handle_data = pywrap_tf_session.GetHandleShapeAndType(
+      graph_op.graph._c_graph, graph_op._as_tf_output())  # pylint: disable=protected-access
+
+  return cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString(
+      compat.as_bytes(handle_data))
+
+
+def _copy_handle_data_to_arg_def(tensor, arg_def):
+  handle_data = get_resource_handle_data(tensor)
+  if handle_data.shape_and_type:
+    shape_and_type = handle_data.shape_and_type[0]
+    proto = arg_def.handle_data.add()
+    proto.dtype = shape_and_type.dtype
+    proto.shape.CopyFrom(handle_data.shape_and_type[0].shape)
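Aside: to make the "handle data" being copied here concrete, a rough sketch using the helper added above. These are internal APIs, so the exact names may shift between versions; the point is that a variable's resource tensor carries the dtype and shape of the variable it refers to, and that metadata is what now flows into the FunctionDef's arg_def.

    import tensorflow as tf
    from tensorflow.python.framework import ops

    v = tf.Variable(tf.zeros([2, 3]))

    @tf.function
    def read(var):
        return var.read_value()

    cf = read.get_concrete_function(v)
    # The captured resource placeholder carries handle data describing the
    # variable (dtype float32, shape [2, 3]).
    resource_input = [t for t in cf.graph.inputs if t.dtype.name == "resource"][0]
    print(ops.get_resource_handle_data(resource_input))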
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 4ba67989d1e..d7772f0accf 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -2579,6 +2579,12 @@ class TensorFlowTestCase(googletest.TestCase):
     self.assertEqual(a.shape, b.shape, shape_mismatch_msg)
 
     msgs = [msg]
+    # np.allclose does not always work for our custom bfloat16 extension type
+    # when type promotions are involved, so we first cast any bfloat16 arrays
+    # to float32.
+    a_dtype = a.dtype
+    a = a.astype(np.float32) if a.dtype == dtypes.bfloat16.as_numpy_dtype else a
+    b = b.astype(np.float32) if b.dtype == dtypes.bfloat16.as_numpy_dtype else b
     if not np.allclose(a, b, rtol=rtol, atol=atol):
       # Adds more details to np.testing.assert_allclose.
       #
@@ -2602,7 +2608,7 @@ class TensorFlowTestCase(googletest.TestCase):
       msgs.append("not close rhs = {}".format(y))
       msgs.append("not close dif = {}".format(np.abs(x - y)))
       msgs.append("not close tol = {}".format(atol + rtol * np.abs(y)))
-      msgs.append("dtype = {}, shape = {}".format(a.dtype, a.shape))
+      msgs.append("dtype = {}, shape = {}".format(a_dtype, a.shape))
       # TODO(xpan): There seems to be a bug:
       # tensorflow/compiler/tests:binary_ops_test pass with float32
       # nan even though the equal_nan is False by default internally.
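Aside: a small illustration of why the cast is needed. `tf.bfloat16.as_numpy_dtype` exposes the registered NumPy type; because it is registered with kind 'V', NumPy does not promote it against other float types, so comparisons go through float32 first.

    import numpy as np
    import tensorflow as tf

    bf16 = tf.bfloat16.as_numpy_dtype
    a = np.array([1.0, 2.5], dtype=bf16)
    b = np.array([1.0, 2.5], dtype=np.float64)
    # Casting the bfloat16 operand to float32 sidesteps the missing promotion.
    assert np.allclose(a.astype(np.float32), b)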
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 612a3855ff9..4e8fed1b4e3 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -236,6 +236,7 @@ py_library(
     ],
     srcs_version = "PY2AND3",
     deps = [
+        ":activations",
         ":backend",
         ":losses",
         "//tensorflow/python:array_ops",
diff --git a/tensorflow/python/keras/applications/mobilenet_v3.py b/tensorflow/python/keras/applications/mobilenet_v3.py
index ab396a2bf73..055d277a29b 100644
--- a/tensorflow/python/keras/applications/mobilenet_v3.py
+++ b/tensorflow/python/keras/applications/mobilenet_v3.py
@@ -61,7 +61,7 @@ BASE_DOCSTRING = """Instantiates the {name} architecture.
   The following table describes the performance of MobileNets:
   ------------------------------------------------------------------------
   MACs stands for Multiply Adds
-  
+
   |Classification Checkpoint|MACs(M)|Parameters(M)|Top1 Accuracy|Pixel1 CPU(ms)|
   |---|---|---|---|---|
   | mobilenet_v3_large_1.0_224              | 217 | 5.4 |   75.6   |   51.2  |
@@ -77,11 +77,6 @@ BASE_DOCSTRING = """Instantiates the {name} architecture.
 
   Optionally loads weights pre-trained on ImageNet.
 
-  Note: each Keras Application expects a specific kind of input preprocessing.
-  For MobileNetV3, call
-  `tf.keras.applications.mobilenet_v3.preprocess_input` on your
-  inputs before passing them to the model.
-
   Arguments:
     input_shape: Optional shape tuple, to be specified if you would
       like to use a model with an input image resolution that is not
@@ -136,6 +131,10 @@ BASE_DOCSTRING = """Instantiates the {name} architecture.
       on the "top" layer. Ignored unless `include_top=True`. Set
       `classifier_activation=None` to return the logits of the "top" layer.
 
+  Call arguments:
+    inputs: A floating point `numpy.array` or a `tf.Tensor`, 4D with 3 color
+      channels, with values in the range [0, 255].
+
   Returns:
     A `keras.Model` instance.
 
@@ -555,6 +554,24 @@ def _inverted_res_block(x, expansion, filters, kernel_size, stride, se_ratio,
 
 @keras_export('keras.applications.mobilenet_v3.preprocess_input')
 def preprocess_input(x, data_format=None):  # pylint: disable=unused-argument
+  """A placeholder method for backward compatibility.
+
+  The preprocessing logic has been included in the mobilenet_v3 model
+  implementation. Users are no longer required to call this method to normalize
+  the input data. This method does nothing and is kept only as a placeholder
+  to align the API surface between the old and new versions of the model.
+
+  Args:
+    x: A floating point `numpy.array` or a `tf.Tensor`.
+    data_format: Optional data format of the image tensor/array. Defaults to
+      None, in which case the global setting
+      `tf.keras.backend.image_data_format()` is used (unless you changed it,
+      it defaults to "channels_last").
+
+  Returns:
+    Unchanged `numpy.array` or `tf.Tensor`.
+  """
+
   return x
 
 
@@ -563,8 +580,4 @@ def decode_predictions(preds, top=5):
   return imagenet_utils.decode_predictions(preds, top=top)
 
 
-preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
-    mode='',
-    ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF,
-    error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC)
 decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__
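Aside: since preprocessing now happens inside the model, a short sketch of feeding raw [0, 255] images directly. `weights=None` is used here only to avoid a weight download; with pretrained weights the call pattern is the same.

    import numpy as np
    import tensorflow as tf

    model = tf.keras.applications.MobileNetV3Small(weights=None)
    images = np.random.randint(0, 256, size=(1, 224, 224, 3)).astype("float32")
    # No call to preprocess_input is needed; the model rescales internally.
    preds = model(images)
    print(preds.shape)  # (1, 1000)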
diff --git a/tensorflow/python/keras/benchmarks/layer_benchmarks/layer_benchmarks_test.py b/tensorflow/python/keras/benchmarks/layer_benchmarks/layer_benchmarks_test.py
index 0fc90150127..a8a00310465 100644
--- a/tensorflow/python/keras/benchmarks/layer_benchmarks/layer_benchmarks_test.py
+++ b/tensorflow/python/keras/benchmarks/layer_benchmarks/layer_benchmarks_test.py
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 import functools
+import numpy as np
 import six
 
 import tensorflow as tf
@@ -37,6 +38,16 @@ def _get_metadata(name):
   }
 
 
+def _get_input_data(inputs):
+  if "input_shape" in inputs:
+    return tf.ones(inputs["input_shape"])
+  elif "input" in inputs:
+    return inputs["input"]
+  else:
+    raise ValueError("Please specificy either `input_shape` or `input`"
+                     "for the benchmark test")
+
+
 def _generate_benchmark_params(*params_list):
   benchmark_params = []
   for params in params_list:
@@ -66,19 +77,25 @@ class KerasLayerBenchmarks(six.with_metaclass(
   _benchmark_parameters = _generate_benchmark_params([
       ("Conv2D_small_shape", tf.keras.layers.Conv2D,
        {"filters": 1, "kernel_size": 1, "activation": "relu"},
-       (1, 1, 1, 1), 10000),
+       {"input_shape": (1, 1, 1, 1)}, 10),
       ("Conv2D_normal_shape", tf.keras.layers.Conv2D,
        {"filters": 1, "kernel_size": 1, "activation": "relu"},
-       (64, 28, 28, 3), 10000),
+       {"input_shape": (64, 28, 28, 3)}, 10),
       ("LSTM_small_shape", tf.keras.layers.LSTM,
-       {"units": 1}, (1, 1, 1), 10000),
+       {"units": 1}, {"input_shape": (1, 1, 1)}, 10),
       ("LSTM_normal_shape", tf.keras.layers.LSTM,
-       {"units": 4}, (32, 10, 8), 10000),
+       {"units": 4}, {"input_shape": (32, 10, 8)}, 10),
+      ("Embedding_small_shape", tf.keras.layers.Embedding,
+       {"input_dim": 1, "output_dim": 1, "input_length": 1},
+       {"input": np.random.randint(1, size=(1, 1))}, 10),
+      ("Embedding_normal_shape", tf.keras.layers.Embedding,
+       {"input_dim": 1000, "output_dim": 64, "input_length": 10},
+       {"input": np.random.randint(1000, size=(32, 10))}, 10),
   ])
 
-  def benchmark_layer_call(self, layer_cls, layer_args, input_shape, num_iters):
+  def benchmark_layer_call(self, layer_cls, layer_args, inputs, num_iters):
     layer = layer_cls(**layer_args)
-    x = tf.ones(input_shape)
+    x = _get_input_data(inputs)
 
     fn = functools.partial(layer, x)
     name = _get_benchmark_name(self._get_name())
@@ -87,9 +104,9 @@ class KerasLayerBenchmarks(six.with_metaclass(
     self.run_report(fn, num_iters, metadata)
 
   def benchmark_layer_call_with_function(
-      self, layer_cls, layer_args, input_shape, num_iters):
+      self, layer_cls, layer_args, inputs, num_iters):
     layer = layer_cls(**layer_args)
-    x = tf.ones(input_shape)
+    x = _get_input_data(inputs)
     layer.call = tf.function(layer.call)
 
     fn = functools.partial(layer, x)
@@ -99,22 +116,25 @@ class KerasLayerBenchmarks(six.with_metaclass(
     self.run_report(fn, num_iters, metadata)
 
   def benchmark_layer_call_with_xla(
-      self, layer_cls, layer_args, input_shape, num_iters):
+      self, layer_cls, layer_args, inputs, num_iters):
+    name = _get_benchmark_name(self._get_name())
+    # TODO(b/173461426)
+    if layer_cls is tf.keras.layers.Embedding and name[-1] == "GPU":
+      return
     layer = layer_cls(**layer_args)
-    x = tf.ones(input_shape)
+    x = _get_input_data(inputs)
     layer.call = tf.function(
         layer.call, jit_compile=True)
 
     fn = functools.partial(layer, x)
-    name = _get_benchmark_name(self._get_name())
     metadata = {"implementation": name[0] + ".layer.call.xla"}
     metadata.update(_get_metadata(name))
     self.run_report(fn, num_iters, metadata)
 
   def benchmark_layer_call_backward(
-      self, layer_cls, layer_args, input_shape, num_iters):
+      self, layer_cls, layer_args, inputs, num_iters):
     layer = layer_cls(**layer_args)
-    x = tf.ones(input_shape)
+    x = _get_input_data(inputs)
 
     fn = functools.partial(_layer_call_backward, layer, x)
     name = _get_benchmark_name(self._get_name())
@@ -123,9 +143,9 @@ class KerasLayerBenchmarks(six.with_metaclass(
     self.run_report(fn, num_iters, metadata)
 
   def benchmark_layer_call_backward_with_function(
-      self, layer_cls, layer_args, input_shape, num_iters):
+      self, layer_cls, layer_args, inputs, num_iters):
     layer = layer_cls(**layer_args)
-    x = tf.ones(input_shape)
+    x = _get_input_data(inputs)
     layer.call = tf.function(layer.call)
 
     fn = functools.partial(_layer_call_backward, layer, x)
@@ -151,17 +171,26 @@ class KerasLayerBenchmarksBackwardXLA(six.with_metaclass(
       #  {"units": 1}, (1, 1, 1), 10000),
       # ("LSTM_normal_shape", tf.keras.layers.LSTM,
       #  {"units": 4}, (32, 10, 8), 10000),
+      ("Embedding_small_shape", tf.keras.layers.Embedding,
+       {"input_dim": 1, "output_dim": 1, "input_length": 1},
+       {"input": np.random.randint(1, size=(1, 1))}, 10),
+      ("Embedding_normal_shape", tf.keras.layers.Embedding,
+       {"input_dim": 1000, "output_dim": 64, "input_length": 10},
+       {"input": np.random.randint(1000, size=(32, 10))}, 10),
   ])
 
   def benchmark_layer_call_backward_with_xla(
-      self, layer_cls, layer_args, input_shape, num_iters):
+      self, layer_cls, layer_args, inputs, num_iters):
+    name = _get_benchmark_name(self._get_name())
+    # TODO(b/173461426)
+    if layer_cls is tf.keras.layers.Embedding and name[-1] == "GPU":
+      return
     layer = layer_cls(**layer_args)
-    x = tf.ones(input_shape)
+    x = _get_input_data(inputs)
     layer.call = tf.function(
         layer.call, jit_compile=True)
 
     fn = functools.partial(_layer_call_backward, layer, x)
-    name = _get_benchmark_name(self._get_name())
     metadata = {"implementation": name[0] + ".layer.call.backward.xla"}
     metadata.update(_get_metadata(name))
     self.run_report(fn, num_iters, metadata)
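Aside: the new parameter format accepts either a shape (filled with ones) or explicit input data, which is what lets integer-id layers such as Embedding join the benchmark set. A standalone sketch of the convention, mirroring the `_get_input_data` helper above:

    import numpy as np
    import tensorflow as tf

    def get_input_data(inputs):
        # Either a shape to fill with ones, or explicit data such as the
        # integer ids an Embedding layer expects.
        if "input_shape" in inputs:
            return tf.ones(inputs["input_shape"])
        if "input" in inputs:
            return inputs["input"]
        raise ValueError("Specify either `input_shape` or `input`.")

    layer = tf.keras.layers.Embedding(input_dim=1000, output_dim=64)
    x = get_input_data({"input": np.random.randint(1000, size=(32, 10))})
    print(layer(x).shape)  # (32, 10, 64)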
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index f05fb910a72..5b7019793e6 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -37,8 +37,8 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
+from tensorflow.python.keras import activations
 from tensorflow.python.keras import backend as K
-from tensorflow.python.keras.activations import sigmoid
 from tensorflow.python.keras.engine import base_layer
 from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.engine import keras_tensor
@@ -2134,7 +2134,7 @@ class AUC(Metric):
     label_weights = None if self.multi_label else self.label_weights
 
     if self._from_logits:
-      y_pred = sigmoid(y_pred)
+      y_pred = activations.sigmoid(y_pred)
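Aside: assuming the `from_logits` constructor argument referenced by this code path, a minimal sketch of what it enables: raw logits are squashed with `activations.sigmoid` before the confusion-matrix variables are updated.

    import tensorflow as tf

    auc = tf.keras.metrics.AUC(from_logits=True)
    # Raw logits; the metric applies a sigmoid internally before thresholding.
    auc.update_state([0, 1, 1, 0], [-2.0, 1.5, 0.3, -0.7])
    print(float(auc.result()))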
 
     with ops.control_dependencies(deps):
       return metrics_utils.update_confusion_matrix_variables(
diff --git a/tensorflow/python/kernel_tests/sparse_xent_op_test.py b/tensorflow/python/kernel_tests/sparse_xent_op_test.py
index 99f70c16999..c53f196ecb9 100644
--- a/tensorflow/python/kernel_tests/sparse_xent_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_xent_op_test.py
@@ -182,23 +182,6 @@ class SparseXentTest(test.TestCase):
           np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float64),
           np.array([0, 3]).astype(label_dtype))
 
-  def testBfloat16(self):
-    for label_dtype in np.int32, np.int64:
-      np_features = np.array([[1., 1., 1., 1.], [1., 2., 3.,
-                                                 4.]]).astype(np.float32)
-      np_labels = np.array([0, 3]).astype(label_dtype)
-      np_loss, np_backprop = self._npXent(np_features, np_labels)
-
-      np_features_bf16 = math_ops.cast(np_features, dtypes.bfloat16)
-      np_loss_bf16 = math_ops.cast(np_loss, dtypes.bfloat16)
-      np_backprop_bf16 = math_ops.cast(np_backprop, dtypes.bfloat16)
-      with self.cached_session(use_gpu=False):
-        loss, backprop = gen_nn_ops.sparse_softmax_cross_entropy_with_logits(
-            np_features_bf16, np_labels)
-        tf_loss, tf_backprop = self.evaluate([loss, backprop])
-      self.assertAllCloseAccordingToType(np_loss_bf16, tf_loss)
-      self.assertAllCloseAccordingToType(np_backprop_bf16, tf_backprop)
-
   def testHalf(self):
     for label_dtype in np.int32, np.int64:
       self._testXent(
diff --git a/tensorflow/python/lib/core/bfloat16.cc b/tensorflow/python/lib/core/bfloat16.cc
index 31def39a98e..3f4b2a52d25 100644
--- a/tensorflow/python/lib/core/bfloat16.cc
+++ b/tensorflow/python/lib/core/bfloat16.cc
@@ -13,64 +13,46 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include 
-
 #include "tensorflow/python/lib/core/bfloat16.h"
 
-#include "tensorflow/core/framework/numeric_types.h"
-#include "tensorflow/core/lib/strings/strcat.h"
+#include <array>
+#include <locale>
+// Place `<locale>` before <Python.h> to avoid a build failure in macOS.
+#include <Python.h>
+
+#include "absl/strings/str_cat.h"
+#include "third_party/eigen3/Eigen/Core"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/python/lib/core/numpy.h"
-#include "tensorflow/python/lib/core/safe_ptr.h"
 
 namespace tensorflow {
 namespace {
 
-// Workarounds for Python 2 vs 3 API differences.
-#if PY_MAJOR_VERSION < 3
+using bfloat16 = Eigen::bfloat16;
 
-PyObject* MakePyString(const string& s) {
-  return PyString_FromString(s.c_str());
+struct PyDecrefDeleter {
+  void operator()(PyObject* p) const { Py_DECREF(p); }
+};
+
+// Safe container for an owned PyObject. On destruction, the reference count of
+// the contained object will be decremented.
+using Safe_PyObjectPtr = std::unique_ptr<PyObject, PyDecrefDeleter>;
+Safe_PyObjectPtr make_safe(PyObject* object) {
+  return Safe_PyObjectPtr(object);
 }
 
-typedef long HashType;  // NOLINT
-
-bool TfPyInt_Check(PyObject* object) { return PyInt_Check(object); }
-
-PyObject* TfPyInt_FromLong(long x) {  // NOLINT
-  return PyInt_FromLong(x);
-}
-
-long TfPyInt_AsLong(PyObject* x) {  // NOLINT
-  return PyInt_AsLong(x);
-}
-
-#else  // PY_MAJOR_VERSION < 3
-
-PyObject* MakePyString(const string& s) {
-  return PyUnicode_FromString(s.c_str());
-}
-
-bool TfPyInt_Check(PyObject* object) {
+bool PyLong_CheckNoOverflow(PyObject* object) {
   if (!PyLong_Check(object)) {
-    return 0;
+    return false;
   }
   int overflow = 0;
   PyLong_AsLongAndOverflow(object, &overflow);
   return (overflow == 0);
 }
 
-PyObject* TfPyInt_FromLong(long x) {  // NOLINT
-  return PyLong_FromLong(x);
-}
-
-long TfPyInt_AsLong(PyObject* x) {  // NOLINT
-  return PyLong_AsLong(x);
-}
-
-typedef Py_hash_t HashType;
-
-#endif  // PY_MAJOR_VERSION < 3
+// Registered numpy type ID. Global variable populated by the registration code.
+// Protected by the GIL.
+int npy_bfloat16 = -1;
 
 // Forward declaration.
 extern PyTypeObject PyBfloat16_Type;
@@ -105,7 +87,7 @@ Safe_PyObjectPtr PyBfloat16_FromBfloat16(bfloat16 x) {
 
 // Converts a Python object to a bfloat16 value. Returns true on success,
 // returns false and reports a Python error on failure.
-bool AsBfloat16(PyObject* arg, bfloat16* output) {
+bool CastToBfloat16(PyObject* arg, bfloat16* output) {
   if (PyBfloat16_Check(arg)) {
     *output = PyBfloat16_Bfloat16(arg);
     return true;
@@ -119,8 +101,8 @@ bool AsBfloat16(PyObject* arg, bfloat16* output) {
     *output = bfloat16(d);
     return true;
   }
-  if (TfPyInt_Check(arg)) {
-    long l = TfPyInt_AsLong(arg);  // NOLINT
+  if (PyLong_CheckNoOverflow(arg)) {
+    long l = PyLong_AsLong(arg);  // NOLINT
     if (PyErr_Occurred()) {
       return false;
     }
@@ -128,14 +110,46 @@ bool AsBfloat16(PyObject* arg, bfloat16* output) {
     *output = bfloat16(static_cast<float>(l));
     return true;
   }
+  if (PyArray_IsScalar(arg, Half)) {
+    Eigen::half f;
+    PyArray_ScalarAsCtype(arg, &f);
+    *output = bfloat16(f);
+    return true;
+  }
   if (PyArray_IsScalar(arg, Float)) {
     float f;
     PyArray_ScalarAsCtype(arg, &f);
     *output = bfloat16(f);
     return true;
   }
-  PyErr_Format(PyExc_TypeError, "expected number, got %s",
-               arg->ob_type->tp_name);
+  if (PyArray_IsScalar(arg, Double)) {
+    double f;
+    PyArray_ScalarAsCtype(arg, &f);
+    *output = bfloat16(f);
+    return true;
+  }
+  if (PyArray_IsZeroDim(arg)) {
+    Safe_PyObjectPtr ref;
+    PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(arg);
+    if (PyArray_TYPE(arr) != npy_bfloat16) {
+      ref = make_safe(PyArray_Cast(arr, npy_bfloat16));
+      if (PyErr_Occurred()) {
+        return false;
+      }
+      arg = ref.get();
+      arr = reinterpret_cast<PyArrayObject*>(arg);
+    }
+    *output = *reinterpret_cast<bfloat16*>(PyArray_DATA(arr));
+    return true;
+  }
+  return false;
+}
+
+bool SafeCastToBfloat16(PyObject* arg, bfloat16* output) {
+  if (PyBfloat16_Check(arg)) {
+    *output = PyBfloat16_Bfloat16(arg);
+    return true;
+  }
   return false;
 }
 
@@ -149,7 +163,7 @@ PyObject* PyBfloat16_Float(PyObject* self) {
 PyObject* PyBfloat16_Int(PyObject* self) {
   bfloat16 x = PyBfloat16_Bfloat16(self);
   long y = static_cast<long>(x);  // NOLINT
-  return TfPyInt_FromLong(y);
+  return PyLong_FromLong(y);
 }
 
 // Negates a PyBfloat16.
@@ -158,28 +172,43 @@ PyObject* PyBfloat16_Negative(PyObject* self) {
   return PyBfloat16_FromBfloat16(-x).release();
 }
 
-// Binary arithmetic operators on PyBfloat16 values.
-#define BFLOAT16_BINOP(name, op)                                  \
-  PyObject* PyBfloat16_##name(PyObject* a, PyObject* b) {         \
-    bfloat16 x, y;                                                \
-    if (!AsBfloat16(a, &x) || !AsBfloat16(b, &y)) return nullptr; \
-    bfloat16 z = x op y;                                          \
-    return PyBfloat16_FromBfloat16(z).release();                  \
+PyObject* PyBfloat16_Add(PyObject* a, PyObject* b) {
+  bfloat16 x, y;
+  if (SafeCastToBfloat16(a, &x) && SafeCastToBfloat16(b, &y)) {
+    return PyBfloat16_FromBfloat16(x + y).release();
   }
-BFLOAT16_BINOP(Add, +)
-BFLOAT16_BINOP(Subtract, -)
-BFLOAT16_BINOP(Multiply, *)
-BFLOAT16_BINOP(Divide, /)
-#undef BFLOAT16_BINOP
+  return PyArray_Type.tp_as_number->nb_add(a, b);
+}
+
+PyObject* PyBfloat16_Subtract(PyObject* a, PyObject* b) {
+  bfloat16 x, y;
+  if (SafeCastToBfloat16(a, &x) && SafeCastToBfloat16(b, &y)) {
+    return PyBfloat16_FromBfloat16(x - y).release();
+  }
+  return PyArray_Type.tp_as_number->nb_subtract(a, b);
+}
+
+PyObject* PyBfloat16_Multiply(PyObject* a, PyObject* b) {
+  bfloat16 x, y;
+  if (SafeCastToBfloat16(a, &x) && SafeCastToBfloat16(b, &y)) {
+    return PyBfloat16_FromBfloat16(x * y).release();
+  }
+  return PyArray_Type.tp_as_number->nb_multiply(a, b);
+}
+
+PyObject* PyBfloat16_TrueDivide(PyObject* a, PyObject* b) {
+  bfloat16 x, y;
+  if (SafeCastToBfloat16(a, &x) && SafeCastToBfloat16(b, &y)) {
+    return PyBfloat16_FromBfloat16(x / y).release();
+  }
+  return PyArray_Type.tp_as_number->nb_true_divide(a, b);
+}
 
 // Python number methods for PyBfloat16 objects.
 PyNumberMethods PyBfloat16_AsNumber = {
     PyBfloat16_Add,       // nb_add
     PyBfloat16_Subtract,  // nb_subtract
     PyBfloat16_Multiply,  // nb_multiply
-#if PY_MAJOR_VERSION < 3
-    PyBfloat16_Divide,  // nb_divide
-#endif
     nullptr,              // nb_remainder
     nullptr,              // nb_divmod
     nullptr,              // nb_power
@@ -193,27 +222,13 @@ PyNumberMethods PyBfloat16_AsNumber = {
     nullptr,              // nb_and
     nullptr,              // nb_xor
     nullptr,              // nb_or
-#if PY_MAJOR_VERSION < 3
-    nullptr,  // nb_coerce
-#endif
-    PyBfloat16_Int,  // nb_int
-#if PY_MAJOR_VERSION < 3
-    PyBfloat16_Int,  // nb_long
-#else
-    nullptr,  // reserved
-#endif
-    PyBfloat16_Float,  // nb_float
-#if PY_MAJOR_VERSION < 3
-    nullptr,  // nb_oct
-    nullptr,  // nb_hex
-#endif
+    PyBfloat16_Int,       // nb_int
+    nullptr,              // reserved
+    PyBfloat16_Float,     // nb_float
 
     nullptr,  // nb_inplace_add
     nullptr,  // nb_inplace_subtract
     nullptr,  // nb_inplace_multiply
-#if PY_MAJOR_VERSION < 3
-    nullptr,  // nb_inplace_divide
-#endif
     nullptr,  // nb_inplace_remainder
     nullptr,  // nb_inplace_power
     nullptr,  // nb_inplace_lshift
@@ -222,11 +237,11 @@ PyNumberMethods PyBfloat16_AsNumber = {
     nullptr,  // nb_inplace_xor
     nullptr,  // nb_inplace_or
 
-    nullptr,            // nb_floor_divide
-    PyBfloat16_Divide,  // nb_true_divide
-    nullptr,            // nb_inplace_floor_divide
-    nullptr,            // nb_inplace_true_divide
-    nullptr,            // nb_index
+    nullptr,                // nb_floor_divide
+    PyBfloat16_TrueDivide,  // nb_true_divide
+    nullptr,                // nb_inplace_floor_divide
+    nullptr,                // nb_inplace_true_divide
+    nullptr,                // nb_index
 };
 
 // Constructs a new PyBfloat16.
@@ -243,22 +258,32 @@ PyObject* PyBfloat16_New(PyTypeObject* type, PyObject* args, PyObject* kwds) {
   }
   PyObject* arg = PyTuple_GetItem(args, 0);
 
+  bfloat16 value;
   if (PyBfloat16_Check(arg)) {
     Py_INCREF(arg);
     return arg;
-  } else {
-    bfloat16 value;
-    if (!AsBfloat16(arg, &value)) {
-      return nullptr;
-    }
+  } else if (CastToBfloat16(arg, &value)) {
     return PyBfloat16_FromBfloat16(value).release();
+  } else if (PyArray_Check(arg)) {
+    PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(arg);
+    if (PyArray_TYPE(arr) != npy_bfloat16) {
+      return PyArray_Cast(arr, npy_bfloat16);
+    } else {
+      Py_INCREF(arg);
+      return arg;
+    }
   }
+  PyErr_Format(PyExc_TypeError, "expected number, got %s",
+               arg->ob_type->tp_name);
+  return nullptr;
 }
 
 // Comparisons on PyBfloat16s.
 PyObject* PyBfloat16_RichCompare(PyObject* a, PyObject* b, int op) {
   bfloat16 x, y;
-  if (!AsBfloat16(a, &x) || !AsBfloat16(b, &y)) return nullptr;
+  if (!SafeCastToBfloat16(a, &x) || !SafeCastToBfloat16(b, &y)) {
+    return PyGenericArrType_Type.tp_richcompare(a, b, op);
+  }
   bool result;
   switch (op) {
     case Py_LT:
@@ -288,81 +313,77 @@ PyObject* PyBfloat16_RichCompare(PyObject* a, PyObject* b, int op) {
 // Implementation of repr() for PyBfloat16.
 PyObject* PyBfloat16_Repr(PyObject* self) {
   bfloat16 x = reinterpret_cast<PyBfloat16*>(self)->value;
-  string v = strings::StrCat("bfloat16(", static_cast<float>(x), ")");
-  return MakePyString(v);
+  std::string v = absl::StrCat(static_cast<float>(x));
+  return PyUnicode_FromString(v.c_str());
 }
 
 // Implementation of str() for PyBfloat16.
 PyObject* PyBfloat16_Str(PyObject* self) {
   bfloat16 x = reinterpret_cast<PyBfloat16*>(self)->value;
-  string v = strings::StrCat(static_cast<float>(x));
-  return MakePyString(v);
+  std::string v = absl::StrCat(static_cast<float>(x));
+  return PyUnicode_FromString(v.c_str());
 }
 
 // Hash function for PyBfloat16. We use the identity function, which is a weak
 // hash function.
-HashType PyBfloat16_Hash(PyObject* self) {
+Py_hash_t PyBfloat16_Hash(PyObject* self) {
   bfloat16 x = reinterpret_cast<PyBfloat16*>(self)->value;
   return x.value;
 }
 
 // Python type for PyBfloat16 objects.
 PyTypeObject PyBfloat16_Type = {
-#if PY_MAJOR_VERSION < 3
-    PyObject_HEAD_INIT(nullptr) 0,  // ob_size
-#else
-    PyVarObject_HEAD_INIT(nullptr, 0)
-#endif
-    "bfloat16",          // tp_name
-    sizeof(PyBfloat16),  // tp_basicsize
-    0,                   // tp_itemsize
-    nullptr,             // tp_dealloc
+    PyVarObject_HEAD_INIT(nullptr, 0) "bfloat16",  // tp_name
+    sizeof(PyBfloat16),                            // tp_basicsize
+    0,                                             // tp_itemsize
+    nullptr,                                       // tp_dealloc
 #if PY_VERSION_HEX < 0x03080000
     nullptr,  // tp_print
 #else
     0,  // tp_vectorcall_offset
 #endif
-    nullptr,                                   // tp_getattr
-    nullptr,                                   // tp_setattr
-    nullptr,                                   // tp_compare / tp_reserved
-    PyBfloat16_Repr,                           // tp_repr
-    &PyBfloat16_AsNumber,                      // tp_as_number
-    nullptr,                                   // tp_as_sequence
-    nullptr,                                   // tp_as_mapping
-    PyBfloat16_Hash,                           // tp_hash
-    nullptr,                                   // tp_call
-    PyBfloat16_Str,                            // tp_str
-    nullptr,                                   // tp_getattro
-    nullptr,                                   // tp_setattro
-    nullptr,                                   // tp_as_buffer
-    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,  // tp_flags
-    "bfloat16 floating-point values",          // tp_doc
-    nullptr,                                   // tp_traverse
-    nullptr,                                   // tp_clear
-    PyBfloat16_RichCompare,                    // tp_richcompare
-    0,                                         // tp_weaklistoffset
-    nullptr,                                   // tp_iter
-    nullptr,                                   // tp_iternext
-    nullptr,                                   // tp_methods
-    nullptr,                                   // tp_members
-    nullptr,                                   // tp_getset
-    nullptr,                                   // tp_base
-    nullptr,                                   // tp_dict
-    nullptr,                                   // tp_descr_get
-    nullptr,                                   // tp_descr_set
-    0,                                         // tp_dictoffset
-    nullptr,                                   // tp_init
-    nullptr,                                   // tp_alloc
-    PyBfloat16_New,                            // tp_new
-    nullptr,                                   // tp_free
-    nullptr,                                   // tp_is_gc
-    nullptr,                                   // tp_bases
-    nullptr,                                   // tp_mro
-    nullptr,                                   // tp_cache
-    nullptr,                                   // tp_subclasses
-    nullptr,                                   // tp_weaklist
-    nullptr,                                   // tp_del
-    0,                                         // tp_version_tag
+    nullptr,               // tp_getattr
+    nullptr,               // tp_setattr
+    nullptr,               // tp_compare / tp_reserved
+    PyBfloat16_Repr,       // tp_repr
+    &PyBfloat16_AsNumber,  // tp_as_number
+    nullptr,               // tp_as_sequence
+    nullptr,               // tp_as_mapping
+    PyBfloat16_Hash,       // tp_hash
+    nullptr,               // tp_call
+    PyBfloat16_Str,        // tp_str
+    nullptr,               // tp_getattro
+    nullptr,               // tp_setattro
+    nullptr,               // tp_as_buffer
+                           // tp_flags
+    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,
+    "bfloat16 floating-point values",  // tp_doc
+    nullptr,                           // tp_traverse
+    nullptr,                           // tp_clear
+    PyBfloat16_RichCompare,            // tp_richcompare
+    0,                                 // tp_weaklistoffset
+    nullptr,                           // tp_iter
+    nullptr,                           // tp_iternext
+    nullptr,                           // tp_methods
+    nullptr,                           // tp_members
+    nullptr,                           // tp_getset
+    nullptr,                           // tp_base
+    nullptr,                           // tp_dict
+    nullptr,                           // tp_descr_get
+    nullptr,                           // tp_descr_set
+    0,                                 // tp_dictoffset
+    nullptr,                           // tp_init
+    nullptr,                           // tp_alloc
+    PyBfloat16_New,                    // tp_new
+    nullptr,                           // tp_free
+    nullptr,                           // tp_is_gc
+    nullptr,                           // tp_bases
+    nullptr,                           // tp_mro
+    nullptr,                           // tp_cache
+    nullptr,                           // tp_subclasses
+    nullptr,                           // tp_weaklist
+    nullptr,                           // tp_del
+    0,                                 // tp_version_tag
 };
 
 // Numpy support
@@ -370,31 +391,32 @@ PyTypeObject PyBfloat16_Type = {
 PyArray_ArrFuncs NPyBfloat16_ArrFuncs;
 
 PyArray_Descr NPyBfloat16_Descr = {
-    PyObject_HEAD_INIT(nullptr) & PyBfloat16_Type,  // typeobj
+    PyObject_HEAD_INIT(nullptr)  //
+                                 /*typeobj=*/
+    (&PyBfloat16_Type),
     // We must register bfloat16 with a kind other than "f", because numpy
     // considers two types with the same kind and size to be equal, but
     // float16 != bfloat16.
-    'V',  // kind
+    // The downside of this is that NumPy scalar promotion does not work with
+    // bfloat16 values.
+    /*kind=*/'V',
     // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
     // character is unique.
-    'E',                                                  // type
-    '=',                                                  // byteorder
-    NPY_NEEDS_PYAPI | NPY_USE_GETITEM | NPY_USE_SETITEM,  // hasobject
-    0,                                                    // type_num
-    sizeof(bfloat16),                                     // elsize
-    alignof(bfloat16),                                    // alignment
-    nullptr,                                              // subarray
-    nullptr,                                              // fields
-    nullptr,                                              // names
-    &NPyBfloat16_ArrFuncs,                                // f
-    nullptr,                                              // metadata
-    nullptr,                                              // c_metadata
-    -1,                                                   // hash
+    /*type=*/'E',
+    /*byteorder=*/'=',
+    /*flags=*/NPY_NEEDS_PYAPI | NPY_USE_GETITEM | NPY_USE_SETITEM,
+    /*type_num=*/0,
+    /*elsize=*/sizeof(bfloat16),
+    /*alignment=*/alignof(bfloat16),
+    /*subarray=*/nullptr,
+    /*fields=*/nullptr,
+    /*names=*/nullptr,
+    /*f=*/&NPyBfloat16_ArrFuncs,
+    /*metadata=*/nullptr,
+    /*c_metadata=*/nullptr,
+    /*hash=*/-1,  // -1 means "not computed yet".
 };
 
-// Registered numpy type ID. Global variable populated by the registration code.
-int npy_bfloat16_ = -1;
-
 // Implementations of NumPy array methods.
 
 PyObject* NPyBfloat16_GetItem(void* data, void* arr) {
@@ -405,7 +427,11 @@ PyObject* NPyBfloat16_GetItem(void* data, void* arr) {
 
 int NPyBfloat16_SetItem(PyObject* item, void* data, void* arr) {
   bfloat16 x;
-  if (!AsBfloat16(item, &x)) return -1;
+  if (!CastToBfloat16(item, &x)) {
+    PyErr_Format(PyExc_TypeError, "expected number, got %s",
+                 item->ob_type->tp_name);
+    return -1;
+  }
   memcpy(data, &x, sizeof(bfloat16));
   return 0;
 }
@@ -486,16 +512,183 @@ int NPyBfloat16_Fill(void* buffer_raw, npy_intp length, void* ignored) {
   return 0;
 }
 
+void NPyBfloat16_DotFunc(void* ip1, npy_intp is1, void* ip2, npy_intp is2,
+                         void* op, npy_intp n, void* arr) {
+  char* c1 = reinterpret_cast<char*>(ip1);
+  char* c2 = reinterpret_cast<char*>(ip2);
+  float acc = 0.0f;
+  for (npy_intp i = 0; i < n; ++i) {
+    bfloat16* const b1 = reinterpret_cast<bfloat16*>(c1);
+    bfloat16* const b2 = reinterpret_cast<bfloat16*>(c2);
+    acc += static_cast<float>(*b1) * static_cast<float>(*b2);
+    c1 += is1;
+    c2 += is2;
+  }
+  bfloat16* out = reinterpret_cast<bfloat16*>(op);
+  *out = static_cast<bfloat16>(acc);
+}
+
+int NPyBfloat16_CompareFunc(const void* v1, const void* v2, void* arr) {
+  bfloat16 b1 = *reinterpret_cast<const bfloat16*>(v1);
+  bfloat16 b2 = *reinterpret_cast<const bfloat16*>(v2);
+  if (b1 < b2) {
+    return -1;
+  }
+  if (b1 > b2) {
+    return 1;
+  }
+  return 0;
+}
+
+int NPyBfloat16_ArgMaxFunc(void* data, npy_intp n, npy_intp* max_ind,
+                           void* arr) {
+  const bfloat16* bdata = reinterpret_cast<const bfloat16*>(data);
+  float max_val = -std::numeric_limits<float>::infinity();
+  for (npy_intp i = 0; i < n; ++i) {
+    if (static_cast<float>(bdata[i]) > max_val) {
+      max_val = static_cast<float>(bdata[i]);
+      *max_ind = i;
+    }
+  }
+  return 0;
+}
+
+int NPyBfloat16_ArgMinFunc(void* data, npy_intp n, npy_intp* min_ind,
+                           void* arr) {
+  const bfloat16* bdata = reinterpret_cast<const bfloat16*>(data);
+  float min_val = std::numeric_limits<float>::infinity();
+  for (npy_intp i = 0; i < n; ++i) {
+    if (static_cast<float>(bdata[i]) < min_val) {
+      min_val = static_cast<float>(bdata[i]);
+      *min_ind = i;
+    }
+  }
+  return 0;
+}
+
 // NumPy casts
 
+template <typename T>
+struct TypeDescriptor {
+  // typedef ... T;  // Representation type in memory for NumPy values of type
+  // static int Dtype() { return NPY_...; }  // Numpy type number for T.
+};
+
+template <>
+struct TypeDescriptor<bfloat16> {
+  typedef bfloat16 T;
+  static int Dtype() { return npy_bfloat16; }
+};
+
+template <>
+struct TypeDescriptor<uint8> {
+  typedef uint8 T;
+  static int Dtype() { return NPY_UINT8; }
+};
+
+template <>
+struct TypeDescriptor<uint16> {
+  typedef uint16 T;
+  static int Dtype() { return NPY_UINT16; }
+};
+
+// We register "int", "long", and "long long" types for portability across
+// Linux, where "int" and "long" are the same type, and Windows, where "long"
+// and "longlong" are the same type.
+template <>
+struct TypeDescriptor<unsigned int> {
+  typedef unsigned int T;
+  static int Dtype() { return NPY_UINT; }
+};
+
+template <>
+struct TypeDescriptor<unsigned long> {  // NOLINT
+  typedef unsigned long T;              // NOLINT
+  static int Dtype() { return NPY_ULONG; }
+};
+
+template <>
+struct TypeDescriptor<unsigned long long> {  // NOLINT
+  typedef unsigned long long T;              // NOLINT
+  static int Dtype() { return NPY_ULONGLONG; }
+};
+
+template <>
+struct TypeDescriptor<int8> {
+  typedef int8 T;
+  static int Dtype() { return NPY_INT8; }
+};
+
+template <>
+struct TypeDescriptor<int16> {
+  typedef int16 T;
+  static int Dtype() { return NPY_INT16; }
+};
+
+template <>
+struct TypeDescriptor<int> {
+  typedef int T;
+  static int Dtype() { return NPY_INT; }
+};
+
+template <>
+struct TypeDescriptor<long> {  // NOLINT
+  typedef long T;              // NOLINT
+  static int Dtype() { return NPY_LONG; }
+};
+
+template <>
+struct TypeDescriptor<long long> {  // NOLINT
+  typedef long long T;              // NOLINT
+  static int Dtype() { return NPY_LONGLONG; }
+};
+
+template <>
+struct TypeDescriptor<bool> {
+  typedef int8 T;
+  static int Dtype() { return NPY_BOOL; }
+};
+
+template <>
+struct TypeDescriptor<Eigen::half> {
+  typedef Eigen::half T;
+  static int Dtype() { return NPY_HALF; }
+};
+
+template <>
+struct TypeDescriptor<float> {
+  typedef float T;
+  static int Dtype() { return NPY_FLOAT; }
+};
+
+template <>
+struct TypeDescriptor<double> {
+  typedef double T;
+  static int Dtype() { return NPY_DOUBLE; }
+};
+
+template <>
+struct TypeDescriptor<std::complex<float>> {
+  typedef std::complex<float> T;
+  static int Dtype() { return NPY_COMPLEX64; }
+};
+
+template <>
+struct TypeDescriptor<std::complex<double>> {
+  typedef std::complex<double> T;
+  static int Dtype() { return NPY_COMPLEX128; }
+};
+
 // Performs a NumPy array cast from type 'From' to 'To'.
 template <typename From, typename To>
 void NPyCast(void* from_void, void* to_void, npy_intp n, void* fromarr,
              void* toarr) {
-  const From* from = reinterpret_cast<const From*>(from_void);
-  To* to = reinterpret_cast<To*>(to_void);
+  const auto* from =
+      reinterpret_cast<typename TypeDescriptor<From>::T*>(from_void);
+  auto* to = reinterpret_cast<typename TypeDescriptor<To>::T*>(to_void);
   for (npy_intp i = 0; i < n; ++i) {
-    to[i] = static_cast<To>(from[i]);
+    to[i] =
+        static_cast<typename TypeDescriptor<To>::T>(static_cast<To>(from[i]));
   }
 }
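Aside: seen from Python, the effect of the TypeDescriptor table and the ufunc loops registered in this file is that array arithmetic stays in the custom dtype end to end. A small sketch, assuming the floor_divide/remainder loops registered further down in this file; `tf.bfloat16.as_numpy_dtype` exposes the registered NumPy type.

    import numpy as np
    import tensorflow as tf

    bf16 = tf.bfloat16.as_numpy_dtype
    a = np.array([1.0, 5.5, -3.25], dtype=bf16)
    b = np.array([2.0, 2.0, 2.0], dtype=bf16)
    # Registered loops keep the result dtype as bfloat16.
    print((a + b).dtype, (a / b).dtype)
    # Floor division and remainder follow the Python sign conventions
    # implemented by the divmod helper later in this file.
    print(a // b, a % b)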
 
@@ -504,7 +697,7 @@ void NPyCast(void* from_void, void* to_void, npy_intp n, void* fromarr,
 // safely coerced to T.
 template <typename T>
 bool RegisterBfloat16Cast(int numpy_type, bool cast_is_safe) {
-  if (PyArray_RegisterCastFunc(PyArray_DescrFromType(numpy_type), npy_bfloat16_,
+  if (PyArray_RegisterCastFunc(PyArray_DescrFromType(numpy_type), npy_bfloat16,
                                NPyCast<T, bfloat16>) < 0) {
     return false;
   }
@@ -520,60 +713,591 @@ bool RegisterBfloat16Cast(int numpy_type, bool cast_is_safe) {
 }
 
 template 
-void BinaryUFunc(char** args, const npy_intp* dimensions, const npy_intp* steps,
-                 void* data) {
-  const char* i0 = args[0];
-  const char* i1 = args[1];
-  char* o = args[2];
-  for (npy_intp k = 0; k < *dimensions; k++) {
-    InType x = *reinterpret_cast(i0);
-    InType y = *reinterpret_cast(i1);
-    *reinterpret_cast(o) = Functor()(x, y);
-    i0 += steps[0];
-    i1 += steps[1];
-    o += steps[2];
+struct UnaryUFunc {
+  static std::vector Types() {
+    return {TypeDescriptor::Dtype(), TypeDescriptor::Dtype()};
   }
+  static void Call(char** args, const npy_intp* dimensions,
+                   const npy_intp* steps, void* data) {
+    const char* i0 = args[0];
+    char* o = args[1];
+    for (npy_intp k = 0; k < *dimensions; k++) {
+      auto x = *reinterpret_cast::T*>(i0);
+      *reinterpret_cast::T*>(o) = Functor()(x);
+      i0 += steps[0];
+      o += steps[1];
+    }
+  }
+};
+
+template 
+struct UnaryUFunc2 {
+  static std::vector Types() {
+    return {TypeDescriptor::Dtype(), TypeDescriptor::Dtype(),
+            TypeDescriptor::Dtype()};
+  }
+  static void Call(char** args, const npy_intp* dimensions,
+                   const npy_intp* steps, void* data) {
+    const char* i0 = args[0];
+    char* o0 = args[1];
+    char* o1 = args[2];
+    for (npy_intp k = 0; k < *dimensions; k++) {
+      auto x = *reinterpret_cast::T*>(i0);
+      std::tie(*reinterpret_cast::T*>(o0),
+               *reinterpret_cast::T*>(o1)) =
+          Functor()(x);
+      i0 += steps[0];
+      o0 += steps[1];
+      o1 += steps[2];
+    }
+  }
+};
+
+template 
+struct BinaryUFunc {
+  static std::vector Types() {
+    return {TypeDescriptor::Dtype(), TypeDescriptor::Dtype(),
+            TypeDescriptor::Dtype()};
+  }
+  static void Call(char** args, const npy_intp* dimensions,
+                   const npy_intp* steps, void* data) {
+    const char* i0 = args[0];
+    const char* i1 = args[1];
+    char* o = args[2];
+    for (npy_intp k = 0; k < *dimensions; k++) {
+      auto x = *reinterpret_cast::T*>(i0);
+      auto y = *reinterpret_cast::T*>(i1);
+      *reinterpret_cast::T*>(o) =
+          Functor()(x, y);
+      i0 += steps[0];
+      i1 += steps[1];
+      o += steps[2];
+    }
+  }
+};
+
+template 
+struct BinaryUFunc2 {
+  static std::vector Types() {
+    return {TypeDescriptor::Dtype(), TypeDescriptor::Dtype(),
+            TypeDescriptor::Dtype()};
+  }
+  static void Call(char** args, const npy_intp* dimensions,
+                   const npy_intp* steps, void* data) {
+    const char* i0 = args[0];
+    const char* i1 = args[1];
+    char* o = args[2];
+    for (npy_intp k = 0; k < *dimensions; k++) {
+      auto x = *reinterpret_cast::T*>(i0);
+      auto y =
+          *reinterpret_cast::T*>(i1);
+      *reinterpret_cast::T*>(o) =
+          Functor()(x, y);
+      i0 += steps[0];
+      i1 += steps[1];
+      o += steps[2];
+    }
+  }
+};
+
+template <typename UFunc>
+bool RegisterUFunc(PyObject* numpy, const char* name) {
+  std::vector<int> types = UFunc::Types();
+  PyUFuncGenericFunction fn =
+      reinterpret_cast<PyUFuncGenericFunction>(UFunc::Call);
+  Safe_PyObjectPtr ufunc_obj = make_safe(PyObject_GetAttrString(numpy, name));
+  if (!ufunc_obj) {
+    return false;
+  }
+  PyUFuncObject* ufunc = reinterpret_cast<PyUFuncObject*>(ufunc_obj.get());
+  if (static_cast<int>(types.size()) != ufunc->nargs) {
+    PyErr_Format(PyExc_AssertionError,
+                 "ufunc %s takes %d arguments, loop takes %lu", name,
+                 ufunc->nargs, types.size());
+    return false;
+  }
+  if (PyUFunc_RegisterLoopForType(ufunc, npy_bfloat16, fn,
+                                  const_cast<int*>(types.data()),
+                                  nullptr) < 0) {
+    return false;
+  }
+  return true;
 }
 
-// Numpy changed const-ness of PyUFuncGenericFunction, provide overload.
-template <typename Functor>
-void CompareUFunc(char** args, npy_intp* dimensions, npy_intp* steps,
-                  void* data) {
-  BinaryUFunc<bfloat16, npy_bool, Functor>(args, dimensions, steps, data);
-}
-template <typename Functor>
-void CompareUFunc(char** args, const npy_intp* dimensions,
-                  const npy_intp* steps, void* data) {
-  BinaryUFunc<bfloat16, npy_bool, Functor>(args, dimensions, steps, data);
+namespace ufuncs {
+
+struct Add {
+  bfloat16 operator()(bfloat16 a, bfloat16 b) { return a + b; }
+};
+struct Subtract {
+  bfloat16 operator()(bfloat16 a, bfloat16 b) { return a - b; }
+};
+struct Multiply {
+  bfloat16 operator()(bfloat16 a, bfloat16 b) { return a * b; }
+};
+struct TrueDivide {
+  bfloat16 operator()(bfloat16 a, bfloat16 b) { return a / b; }
+};
+
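+// Python-style divmod on floats: the remainder takes the sign of the divisor
+// and the quotient is rounded toward negative infinity. Shared by the
+// floor_divide, remainder and divmod loops below.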
+std::pair<float, float> divmod(float a, float b) {
+  if (b == 0.0f) {
+    float nan = std::numeric_limits<float>::quiet_NaN();
+    return {nan, nan};
+  }
+  float mod = std::fmod(a, b);
+  float div = (a - mod) / b;
+  if (mod != 0.0f) {
+    if ((b < 0.0f) != (mod < 0.0f)) {
+      mod += b;
+      div -= 1.0f;
+    }
+  } else {
+    mod = std::copysign(0.0f, b);
+  }
+
+  float floordiv;
+  if (div != 0.0f) {
+    floordiv = std::floor(div);
+    if (div - floordiv > 0.5f) {
+      floordiv += 1.0f;
+    }
+  } else {
+    floordiv = std::copysign(0.0f, a / b);
+  }
+  return {floordiv, mod};
 }
 
-struct Bfloat16EqFunctor {
+struct FloorDivide {
+  bfloat16 operator()(bfloat16 a, bfloat16 b) {
+    return bfloat16(divmod(static_cast<float>(a), static_cast<float>(b)).first);
+  }
+};
+struct Remainder {
+  bfloat16 operator()(bfloat16 a, bfloat16 b) {
+    return bfloat16(
+        divmod(static_cast<float>(a), static_cast<float>(b)).second);
+  }
+};
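+// divmod has two inputs and two outputs, so it does not fit the UFunc
+// templates above and gets a hand-written loop.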
+struct DivmodUFunc {
+  static std::vector<int> Types() {
+    return {npy_bfloat16, npy_bfloat16, npy_bfloat16, npy_bfloat16};
+  }
+  static void Call(char** args, npy_intp* dimensions, npy_intp* steps,
+                   void* data) {
+    const char* i0 = args[0];
+    const char* i1 = args[1];
+    char* o0 = args[2];
+    char* o1 = args[3];
+    for (npy_intp k = 0; k < *dimensions; k++) {
+      bfloat16 x = *reinterpret_cast<const bfloat16*>(i0);
+      bfloat16 y = *reinterpret_cast<const bfloat16*>(i1);
+      float floordiv, mod;
+      std::tie(floordiv, mod) =
+          divmod(static_cast<float>(x), static_cast<float>(y));
+      *reinterpret_cast<bfloat16*>(o0) = bfloat16(floordiv);
+      *reinterpret_cast<bfloat16*>(o1) = bfloat16(mod);
+      i0 += steps[0];
+      i1 += steps[1];
+      o0 += steps[2];
+      o1 += steps[3];
+    }
+  }
+};
+struct Fmod {
+  bfloat16 operator()(bfloat16 a, bfloat16 b) {
+    return bfloat16(std::fmod(static_cast<float>(a), static_cast<float>(b)));
+  }
+};
+struct Negative {
+  bfloat16 operator()(bfloat16 a) { return -a; }
+};
+struct Positive {
+  bfloat16 operator()(bfloat16 a) { return a; }
+};
+struct Power {
+  bfloat16 operator()(bfloat16 a, bfloat16 b) {
+    return bfloat16(std::pow(static_cast<float>(a), static_cast<float>(b)));
+  }
+};
+struct Abs {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::abs(static_cast<float>(a)));
+  }
+};
+struct Cbrt {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::cbrt(static_cast<float>(a)));
+  }
+};
+struct Ceil {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::ceil(static_cast<float>(a)));
+  }
+};
+struct CopySign {
+  bfloat16 operator()(bfloat16 a, bfloat16 b) {
+    return bfloat16(
+        std::copysign(static_cast<float>(a), static_cast<float>(b)));
+  }
+};
+struct Exp {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::exp(static_cast<float>(a)));
+  }
+};
+struct Exp2 {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::exp2(static_cast<float>(a)));
+  }
+};
+struct Expm1 {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::expm1(static_cast<float>(a)));
+  }
+};
+struct Floor {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::floor(static_cast<float>(a)));
+  }
+};
+struct Frexp {
+  std::pair<bfloat16, int> operator()(bfloat16 a) {
+    int exp;
+    float f = std::frexp(static_cast<float>(a), &exp);
+    return {bfloat16(f), exp};
+  }
+};
+struct Heaviside {
+  bfloat16 operator()(bfloat16 bx, bfloat16 h0) {
+    float x = static_cast<float>(bx);
+    if (Eigen::numext::isnan(x)) {
+      return bx;
+    }
+    if (x < 0) {
+      return bfloat16(0.0f);
+    }
+    if (x > 0) {
+      return bfloat16(1.0f);
+    }
+    return h0;  // x == 0
+  }
+};
+struct Conjugate {
+  bfloat16 operator()(bfloat16 a) { return a; }
+};
+struct IsFinite {
+  bool operator()(bfloat16 a) { return std::isfinite(static_cast<float>(a)); }
+};
+struct IsInf {
+  bool operator()(bfloat16 a) { return std::isinf(static_cast<float>(a)); }
+};
+struct IsNan {
+  bool operator()(bfloat16 a) {
+    return Eigen::numext::isnan(static_cast<float>(a));
+  }
+};
+struct Ldexp {
+  bfloat16 operator()(bfloat16 a, int exp) {
+    return bfloat16(std::ldexp(static_cast<float>(a), exp));
+  }
+};
+struct Log {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::log(static_cast<float>(a)));
+  }
+};
+struct Log2 {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::log2(static_cast<float>(a)));
+  }
+};
+struct Log10 {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::log10(static_cast<float>(a)));
+  }
+};
+struct Log1p {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::log1p(static_cast<float>(a)));
+  }
+};
+struct LogAddExp {
+  bfloat16 operator()(bfloat16 bx, bfloat16 by) {
+    float x = static_cast<float>(bx);
+    float y = static_cast<float>(by);
+    if (x == y) {
+      // Handles infinities of the same sign.
+      return bfloat16(x + std::log(2.0f));
+    }
+    float out = std::numeric_limits<float>::quiet_NaN();
+    if (x > y) {
+      out = x + std::log1p(std::exp(y - x));
+    } else if (x < y) {
+      out = y + std::log1p(std::exp(x - y));
+    }
+    return bfloat16(out);
+  }
+};
+struct LogAddExp2 {
+  bfloat16 operator()(bfloat16 bx, bfloat16 by) {
+    float x = static_cast<float>(bx);
+    float y = static_cast<float>(by);
+    if (x == y) {
+      // Handles infinities of the same sign.
+      return bfloat16(x + 1.0f);
+    }
+    float out = std::numeric_limits<float>::quiet_NaN();
+    if (x > y) {
+      out = x + std::log1p(std::exp2(y - x)) / std::log(2.0f);
+    } else if (x < y) {
+      out = y + std::log1p(std::exp2(x - y)) / std::log(2.0f);
+    }
+    return bfloat16(out);
+  }
+};
+struct Modf {
+  std::pair<bfloat16, bfloat16> operator()(bfloat16 a) {
+    float integral;
+    float f = std::modf(static_cast<float>(a), &integral);
+    return {bfloat16(f), bfloat16(integral)};
+  }
+};
+
+struct Reciprocal {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(1.f / static_cast<float>(a));
+  }
+};
+struct Rint {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::rint(static_cast<float>(a)));
+  }
+};
+struct Sign {
+  bfloat16 operator()(bfloat16 a) {
+    float f(a);
+    if (f < 0) {
+      return bfloat16(-1);
+    }
+    if (f > 0) {
+      return bfloat16(1);
+    }
+    return a;
+  }
+};
+struct SignBit {
+  bool operator()(bfloat16 a) { return std::signbit(static_cast<float>(a)); }
+};
+struct Sqrt {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::sqrt(static_cast<float>(a)));
+  }
+};
+struct Square {
+  bfloat16 operator()(bfloat16 a) {
+    float f(a);
+    return bfloat16(f * f);
+  }
+};
+struct Trunc {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::trunc(static_cast<float>(a)));
+  }
+};
+
+// Trigonometric functions
+struct Sin {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::sin(static_cast<float>(a)));
+  }
+};
+struct Cos {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::cos(static_cast<float>(a)));
+  }
+};
+struct Tan {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::tan(static_cast<float>(a)));
+  }
+};
+struct Arcsin {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::asin(static_cast<float>(a)));
+  }
+};
+struct Arccos {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::acos(static_cast<float>(a)));
+  }
+};
+struct Arctan {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::atan(static_cast<float>(a)));
+  }
+};
+struct Arctan2 {
+  bfloat16 operator()(bfloat16 a, bfloat16 b) {
+    return bfloat16(std::atan2(static_cast<float>(a), static_cast<float>(b)));
+  }
+};
+struct Hypot {
+  bfloat16 operator()(bfloat16 a, bfloat16 b) {
+    return bfloat16(std::hypot(static_cast<float>(a), static_cast<float>(b)));
+  }
+};
+struct Sinh {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::sinh(static_cast<float>(a)));
+  }
+};
+struct Cosh {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::cosh(static_cast<float>(a)));
+  }
+};
+struct Tanh {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::tanh(static_cast<float>(a)));
+  }
+};
+struct Arcsinh {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::asinh(static_cast<float>(a)));
+  }
+};
+struct Arccosh {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::acosh(static_cast<float>(a)));
+  }
+};
+struct Arctanh {
+  bfloat16 operator()(bfloat16 a) {
+    return bfloat16(std::atanh(static_cast<float>(a)));
+  }
+};
+struct Deg2rad {
+  bfloat16 operator()(bfloat16 a) {
+    static constexpr float radians_per_degree = M_PI / 180.0f;
+    return bfloat16(static_cast<float>(a) * radians_per_degree);
+  }
+};
+struct Rad2deg {
+  bfloat16 operator()(bfloat16 a) {
+    static constexpr float degrees_per_radian = 180.0f / M_PI;
+    return bfloat16(static_cast<float>(a) * degrees_per_radian);
+  }
+};
+
+struct Eq {
   npy_bool operator()(bfloat16 a, bfloat16 b) { return a == b; }
 };
-struct Bfloat16NeFunctor {
+struct Ne {
   npy_bool operator()(bfloat16 a, bfloat16 b) { return a != b; }
 };
-struct Bfloat16LtFunctor {
+struct Lt {
   npy_bool operator()(bfloat16 a, bfloat16 b) { return a < b; }
 };
-struct Bfloat16GtFunctor {
+struct Gt {
   npy_bool operator()(bfloat16 a, bfloat16 b) { return a > b; }
 };
-struct Bfloat16LeFunctor {
+struct Le {
   npy_bool operator()(bfloat16 a, bfloat16 b) { return a <= b; }
 };
-struct Bfloat16GeFunctor {
+struct Ge {
   npy_bool operator()(bfloat16 a, bfloat16 b) { return a >= b; }
 };
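+// Maximum and Minimum propagate NaN, matching np.maximum/np.minimum, while
+// Fmax and Fmin below prefer the non-NaN operand, matching np.fmax/np.fmin.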
+struct Maximum {
+  bfloat16 operator()(bfloat16 a, bfloat16 b) {
+    float fa(a), fb(b);
+    return Eigen::numext::isnan(fa) || fa > fb ? a : b;
+  }
+};
+struct Minimum {
+  bfloat16 operator()(bfloat16 a, bfloat16 b) {
+    float fa(a), fb(b);
+    return Eigen::numext::isnan(fa) || fa < fb ? a : b;
+  }
+};
+struct Fmax {
+  bfloat16 operator()(bfloat16 a, bfloat16 b) {
+    float fa(a), fb(b);
+    return Eigen::numext::isnan(fb) || fa > fb ? a : b;
+  }
+};
+struct Fmin {
+  bfloat16 operator()(bfloat16 a, bfloat16 b) {
+    float fa(a), fb(b);
+    return Eigen::numext::isnan(fb) || fa < fb ? a : b;
+  }
+};
+
+struct LogicalNot {
+  npy_bool operator()(bfloat16 a) { return !a; }
+};
+struct LogicalAnd {
+  npy_bool operator()(bfloat16 a, bfloat16 b) { return a && b; }
+};
+struct LogicalOr {
+  npy_bool operator()(bfloat16 a, bfloat16 b) { return a || b; }
+};
+struct LogicalXor {
+  npy_bool operator()(bfloat16 a, bfloat16 b) {
+    return static_cast<bool>(a) ^ static_cast<bool>(b);
+  }
+};
+
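+// Computed directly on the 16-bit representation: for a fixed sign, the
+// bfloat16 bit pattern is monotonic in magnitude, so stepping the integer by
+// one moves exactly one ulp.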
+struct NextAfter {
+  bfloat16 operator()(bfloat16 from, bfloat16 to) {
+    uint16_t from_as_int, to_as_int;
+    const uint16_t sign_mask = 1 << 15;
+    float from_as_float(from), to_as_float(to);
+    memcpy(&from_as_int, &from, sizeof(bfloat16));
+    memcpy(&to_as_int, &to, sizeof(bfloat16));
+    if (Eigen::numext::isnan(from_as_float) ||
+        Eigen::numext::isnan(to_as_float)) {
+      return bfloat16(std::numeric_limits<float>::quiet_NaN());
+    }
+    if (from_as_int == to_as_int) {
+      return to;
+    }
+    if (from_as_float == 0) {
+      if (to_as_float == 0) {
+        return to;
+      } else {
+        // Smallest subnormal signed like `to`.
+        uint16_t out_int = (to_as_int & sign_mask) | 1;
+        bfloat16 out;
+        memcpy(&out, &out_int, sizeof(bfloat16));
+        return out;
+      }
+    }
+    uint16_t from_sign = from_as_int & sign_mask;
+    uint16_t to_sign = to_as_int & sign_mask;
+    uint16_t from_abs = from_as_int & ~sign_mask;
+    uint16_t to_abs = to_as_int & ~sign_mask;
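+    // 0xFFFF acts as -1 under uint16 wrap-around: step the representation down
+    // (toward zero) when `to` is smaller in magnitude or has the opposite
+    // sign, otherwise step it up (away from zero).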
+    uint16_t magnitude_adjustment =
+        (from_abs > to_abs || from_sign != to_sign) ? 0xFFFF : 0x0001;
+    uint16_t out_int = from_as_int + magnitude_adjustment;
+    bfloat16 out;
+    memcpy(&out, &out_int, sizeof(bfloat16));
+    return out;
+  }
+};
+
+// TODO(phawkins): implement spacing
+
+}  // namespace ufuncs
+
+}  // namespace
 
 // Initializes the module.
 bool Initialize() {
-  // It's critical to ImportNumpy and import umath
-  // to avoid crash in open source build.
   ImportNumpy();
   import_umath1(false);
 
-  Safe_PyObjectPtr numpy_str = make_safe(MakePyString("numpy"));
+  Safe_PyObjectPtr numpy_str = make_safe(PyUnicode_FromString("numpy"));
   if (!numpy_str) {
     return false;
   }
@@ -582,7 +1306,6 @@ bool Initialize() {
     return false;
   }
 
-  // We hit a mysterious crash if we haven't initialized numpy before this:
   PyBfloat16_Type.tp_base = &PyGenericArrType_Type;
 
   if (PyType_Ready(&PyBfloat16_Type) < 0) {
@@ -598,10 +1321,16 @@ bool Initialize() {
   NPyBfloat16_ArrFuncs.copyswap = NPyBfloat16_CopySwap;
   NPyBfloat16_ArrFuncs.nonzero = NPyBfloat16_NonZero;
   NPyBfloat16_ArrFuncs.fill = NPyBfloat16_Fill;
+  NPyBfloat16_ArrFuncs.dotfunc = NPyBfloat16_DotFunc;
+  NPyBfloat16_ArrFuncs.compare = NPyBfloat16_CompareFunc;
+  NPyBfloat16_ArrFuncs.argmax = NPyBfloat16_ArgMaxFunc;
+  NPyBfloat16_ArrFuncs.argmin = NPyBfloat16_ArgMinFunc;
 
   Py_TYPE(&NPyBfloat16_Descr) = &PyArrayDescr_Type;
-  npy_bfloat16_ = PyArray_RegisterDataType(&NPyBfloat16_Descr);
-  if (npy_bfloat16_ < 0) return false;
+  npy_bfloat16 = PyArray_RegisterDataType(&NPyBfloat16_Descr);
+  if (npy_bfloat16 < 0) {
+    return false;
+  }
 
   // Support dtype(bfloat16)
   if (PyDict_SetItemString(PyBfloat16_Type.tp_dict, "dtype",
@@ -611,114 +1340,243 @@ bool Initialize() {
   }
 
   // Register casts
-
-  // We lie shamelessly and say that a cast from half to bfloat16 is safe.
-  // Numpy frequently uses the smallest legal representation type for small
-  // float constants (e.g., 1.0), which is often float16. Things break if these
-  // cannot be converted transparently to bfloat16.
-  if (!RegisterBfloat16Cast<Eigen::half>(NPY_HALF, /*cast_is_safe=*/true)) {
+  if (!RegisterBfloat16Cast<Eigen::half>(NPY_HALF, /*cast_is_safe=*/false)) {
     return false;
   }
-
   if (!RegisterBfloat16Cast<float>(NPY_FLOAT, /*cast_is_safe=*/true)) {
     return false;
   }
   if (!RegisterBfloat16Cast<double>(NPY_DOUBLE, /*cast_is_safe=*/true)) {
     return false;
   }
-  if (!RegisterBfloat16Cast<int32>(NPY_INT32, /*cast_is_safe=*/false)) {
+  if (!RegisterBfloat16Cast<bool>(NPY_BOOL, /*cast_is_safe=*/false)) {
     return false;
   }
-  if (!RegisterBfloat16Cast<int64>(NPY_INT64, /*cast_is_safe=*/false)) {
+  if (!RegisterBfloat16Cast<uint8>(NPY_UINT8, /*cast_is_safe=*/false)) {
+    return false;
+  }
+  if (!RegisterBfloat16Cast<uint16>(NPY_UINT16, /*cast_is_safe=*/false)) {
+    return false;
+  }
+  if (!RegisterBfloat16Cast<unsigned int>(NPY_UINT, /*cast_is_safe=*/false)) {
+    return false;
+  }
+  if (!RegisterBfloat16Cast<unsigned long>(NPY_ULONG,  // NOLINT
+                                           /*cast_is_safe=*/false)) {
+    return false;
+  }
+  if (!RegisterBfloat16Cast<unsigned long long>(  // NOLINT
+          NPY_ULONGLONG, /*cast_is_safe=*/false)) {
+    return false;
+  }
+  if (!RegisterBfloat16Cast<uint64>(NPY_UINT64, /*cast_is_safe=*/false)) {
+    return false;
+  }
+  if (!RegisterBfloat16Cast<int8>(NPY_INT8, /*cast_is_safe=*/false)) {
+    return false;
+  }
+  if (!RegisterBfloat16Cast<int16>(NPY_INT16, /*cast_is_safe=*/false)) {
+    return false;
+  }
+  if (!RegisterBfloat16Cast<int>(NPY_INT, /*cast_is_safe=*/false)) {
+    return false;
+  }
+  if (!RegisterBfloat16Cast<long>(NPY_LONG,  // NOLINT
+                                  /*cast_is_safe=*/false)) {
+    return false;
+  }
+  if (!RegisterBfloat16Cast<long long>(  // NOLINT
+          NPY_LONGLONG, /*cast_is_safe=*/false)) {
     return false;
   }
   // Following the numpy convention. imag part is dropped when converting to
   // float.
-  if (!RegisterBfloat16Cast<complex64>(NPY_COMPLEX64, /*cast_is_safe=*/true)) {
+  if (!RegisterBfloat16Cast<std::complex<float>>(NPY_COMPLEX64,
+                                                 /*cast_is_safe=*/true)) {
     return false;
   }
-  if (!RegisterBfloat16Cast<complex128>(NPY_COMPLEX128,
-                                        /*cast_is_safe=*/true)) {
+  if (!RegisterBfloat16Cast<std::complex<double>>(NPY_COMPLEX128,
+                                                  /*cast_is_safe=*/true)) {
     return false;
   }
 
-  // Register ufuncs
-  auto register_ufunc = [&](const char* name, PyUFuncGenericFunction fn,
-                            const std::array<int, 3>& types) {
-    Safe_PyObjectPtr ufunc_obj =
-        make_safe(PyObject_GetAttrString(numpy.get(), name));
-    if (!ufunc_obj) {
-      return false;
-    }
-    PyUFuncObject* ufunc = reinterpret_cast<PyUFuncObject*>(ufunc_obj.get());
-    if (types.size() != ufunc->nargs) {
-      PyErr_Format(PyExc_AssertionError,
-                   "ufunc %s takes %d arguments, loop takes %lu", name,
-                   ufunc->nargs, types.size());
-      return false;
-    }
-    if (PyUFunc_RegisterLoopForType(ufunc, npy_bfloat16_, fn,
-                                    const_cast<int*>(types.data()),
-                                    nullptr) < 0) {
-      return false;
-    }
-    return true;
-  };
+  bool ok =
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Add>>(numpy.get(),
+                                                                  "add") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Subtract>>(
+          numpy.get(), "subtract") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Multiply>>(
+          numpy.get(), "multiply") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::TrueDivide>>(
+          numpy.get(), "divide") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::LogAddExp>>(
+          numpy.get(), "logaddexp") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::LogAddExp2>>(
+          numpy.get(), "logaddexp2") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Negative>>(
+          numpy.get(), "negative") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Positive>>(
+          numpy.get(), "positive") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::TrueDivide>>(
+          numpy.get(), "true_divide") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::FloorDivide>>(
+          numpy.get(), "floor_divide") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Power>>(numpy.get(),
+                                                                    "power") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Remainder>>(
+          numpy.get(), "remainder") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Remainder>>(
+          numpy.get(), "mod") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Fmod>>(numpy.get(),
+                                                                   "fmod") &&
+      RegisterUFunc<ufuncs::DivmodUFunc>(numpy.get(), "divmod") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Abs>>(numpy.get(),
+                                                                 "absolute") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Abs>>(numpy.get(),
+                                                                 "fabs") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Rint>>(numpy.get(),
+                                                                  "rint") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Sign>>(numpy.get(),
+                                                                  "sign") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Heaviside>>(
+          numpy.get(), "heaviside") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Conjugate>>(
+          numpy.get(), "conjugate") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Exp>>(numpy.get(),
+                                                                 "exp") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Exp2>>(numpy.get(),
+                                                                  "exp2") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Expm1>>(numpy.get(),
+                                                                   "expm1") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Log>>(numpy.get(),
+                                                                 "log") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Log2>>(numpy.get(),
+                                                                  "log2") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Log10>>(numpy.get(),
+                                                                   "log10") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Log1p>>(numpy.get(),
+                                                                   "log1p") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Sqrt>>(numpy.get(),
+                                                                  "sqrt") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Square>>(numpy.get(),
+                                                                    "square") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Cbrt>>(numpy.get(),
+                                                                  "cbrt") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Reciprocal>>(
+          numpy.get(), "reciprocal") &&

-  // Comparisons
-  const std::array<int, 3> compare_types = {
-      {npy_bfloat16_, npy_bfloat16_, NPY_BOOL}};
+      // Trigonometric functions
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Sin>>(numpy.get(),
+                                                                 "sin") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Cos>>(numpy.get(),
+                                                                 "cos") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Tan>>(numpy.get(),
+                                                                 "tan") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arcsin>>(numpy.get(),
+                                                                    "arcsin") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arccos>>(numpy.get(),
+                                                                    "arccos") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arctan>>(numpy.get(),
+                                                                    "arctan") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Arctan2>>(
+          numpy.get(), "arctan2") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Hypot>>(numpy.get(),
+                                                                    "hypot") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Sinh>>(numpy.get(),
+                                                                  "sinh") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Cosh>>(numpy.get(),
+                                                                  "cosh") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Tanh>>(numpy.get(),
+                                                                  "tanh") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arcsinh>>(
+          numpy.get(), "arcsinh") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arccosh>>(
+          numpy.get(), "arccosh") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arctanh>>(
+          numpy.get(), "arctanh") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Deg2rad>>(
+          numpy.get(), "deg2rad") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Rad2deg>>(
+          numpy.get(), "rad2deg") &&

-  if (!register_ufunc("equal", CompareUFunc<Bfloat16EqFunctor>,
-                      compare_types)) {
-    return false;
-  }
-  if (!register_ufunc("not_equal", CompareUFunc<Bfloat16NeFunctor>,
-                      compare_types)) {
-    return false;
-  }
-  if (!register_ufunc("less", CompareUFunc<Bfloat16LtFunctor>, compare_types)) {
-    return false;
-  }
-  if (!register_ufunc("greater", CompareUFunc<Bfloat16GtFunctor>,
-                      compare_types)) {
-    return false;
-  }
-  if (!register_ufunc("less_equal", CompareUFunc<Bfloat16LeFunctor>,
-                      compare_types)) {
-    return false;
-  }
-  if (!register_ufunc("greater_equal", CompareUFunc<Bfloat16GeFunctor>,
-                      compare_types)) {
-    return false;
-  }
-  return true;
+      // Comparison functions
+      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Eq>>(numpy.get(),
+                                                             "equal") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Ne>>(numpy.get(),
+                                                             "not_equal") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Lt>>(numpy.get(),
+                                                             "less") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Gt>>(numpy.get(),
+                                                             "greater") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Le>>(numpy.get(),
+                                                             "less_equal") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Ge>>(numpy.get(),
+                                                             "greater_equal") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Maximum>>(
+          numpy.get(), "maximum") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Minimum>>(
+          numpy.get(), "minimum") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Fmax>>(numpy.get(),
+                                                                   "fmax") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Fmin>>(numpy.get(),
+                                                                   "fmin") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::LogicalAnd>>(
+          numpy.get(), "logical_and") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::LogicalOr>>(
+          numpy.get(), "logical_or") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::LogicalXor>>(
+          numpy.get(), "logical_xor") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bool, ufuncs::LogicalNot>>(
+          numpy.get(), "logical_not") &&
+
+      // Floating point functions
+      RegisterUFunc<UnaryUFunc<bfloat16, bool, ufuncs::IsFinite>>(numpy.get(),
+                                                                  "isfinite") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bool, ufuncs::IsInf>>(numpy.get(),
+                                                               "isinf") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bool, ufuncs::IsNan>>(numpy.get(),
+                                                               "isnan") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bool, ufuncs::SignBit>>(numpy.get(),
+                                                                 "signbit") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::CopySign>>(
+          numpy.get(), "copysign") &&
+      RegisterUFunc<UnaryUFunc2<bfloat16, bfloat16, bfloat16, ufuncs::Modf>>(
+          numpy.get(), "modf") &&
+      RegisterUFunc<BinaryUFunc2<bfloat16, int, bfloat16, ufuncs::Ldexp>>(
+          numpy.get(), "ldexp") &&
+      RegisterUFunc<UnaryUFunc2<bfloat16, bfloat16, int, ufuncs::Frexp>>(
+          numpy.get(), "frexp") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Floor>>(numpy.get(),
+                                                                   "floor") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Ceil>>(numpy.get(),
+                                                                  "ceil") &&
+      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Trunc>>(numpy.get(),
+                                                                   "trunc") &&
+      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::NextAfter>>(
+          numpy.get(), "nextafter");
+
+  return ok;
 }
 
-}  // namespace
-
-void RegisterNumpyBfloat16() {
-  if (npy_bfloat16_ >= 0) {
+bool RegisterNumpyBfloat16() {
+  if (npy_bfloat16 >= 0) {
     // Already initialized.
-    return;
+    return true;
   }
   if (!Initialize()) {
     if (!PyErr_Occurred()) {
       PyErr_SetString(PyExc_RuntimeError, "cannot load bfloat16 module.");
     }
     PyErr_Print();
+    return false;
   }
+  return true;
 }
 
-PyObject* Bfloat16PyType() {
-  CHECK(PyBfloat16_Type.tp_base != nullptr);
-  Py_INCREF(&PyBfloat16_Type);
+PyObject* Bfloat16Dtype() {
   return reinterpret_cast<PyObject*>(&PyBfloat16_Type);
 }
 
-int Bfloat16NumpyType() {
-  CHECK_GE(npy_bfloat16_, 0);
-  return npy_bfloat16_;
-}
+int Bfloat16NumpyType() { return npy_bfloat16; }
 
 }  // namespace tensorflow
diff --git a/tensorflow/python/lib/core/bfloat16.h b/tensorflow/python/lib/core/bfloat16.h
index a609928ba90..e40207b5f8a 100644
--- a/tensorflow/python/lib/core/bfloat16.h
+++ b/tensorflow/python/lib/core/bfloat16.h
@@ -20,11 +20,11 @@ limitations under the License.
 
 namespace tensorflow {
 
-// Register the bfloat16 numpy type.
-void RegisterNumpyBfloat16();
+// Register the bfloat16 numpy type. Returns true on success.
+bool RegisterNumpyBfloat16();
 
-// Returns the PyObject for the bfloat16 type.
-PyObject* Bfloat16PyType();
+// Returns a pointer to the bfloat16 dtype object.
+PyObject* Bfloat16Dtype();
 
 // Returns the id number of the bfloat16 numpy type.
 int Bfloat16NumpyType();
diff --git a/tensorflow/python/lib/core/bfloat16_test.py b/tensorflow/python/lib/core/bfloat16_test.py
index f19029911bf..9bd5375df81 100644
--- a/tensorflow/python/lib/core/bfloat16_test.py
+++ b/tensorflow/python/lib/core/bfloat16_test.py
@@ -12,54 +12,80 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
 """Test cases for the bfloat16 Python type."""
 
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
+import copy
+import itertools
 import math
 
+from absl.testing import absltest
+from absl.testing import parameterized
+
 import numpy as np
 
 # pylint: disable=unused-import,g-bad-import-order
 from tensorflow.python import _pywrap_bfloat16
-from tensorflow.python.framework import dtypes
-from tensorflow.python.platform import test
-
 
 bfloat16 = _pywrap_bfloat16.TF_bfloat16_type()
 
 
-def float_values():
-  """Returns values that should round trip exactly to float and back."""
-  epsilon = float.fromhex("1.0p-7")
-  return [
-      0.0, 1.0, -1, 0.5, -0.5, epsilon, 1.0 + epsilon, 1.0 - epsilon,
-      -1.0 - epsilon, -1.0 + epsilon, 3.5, 42.0, 255.0, 256.0,
-      float("inf"),
-      float("-inf"),
-      float("nan")
-  ]
+def numpy_assert_allclose(a, b, **kwargs):
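+  # np.testing.assert_allclose does not understand bfloat16 arithmetic, so
+  # upcast any bfloat16 arrays to float32 before comparing.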
+  a = a.astype(np.float32) if a.dtype == bfloat16 else a
+  b = b.astype(np.float32) if b.dtype == bfloat16 else b
+  return np.testing.assert_allclose(a, b, **kwargs)
 
 
-class Bfloat16Test(test.TestCase):
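+# Machine epsilon of bfloat16: 2**-7, since the format keeps 7 explicit
+# mantissa bits.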
+epsilon = float.fromhex("1.0p-7")
 
-  def _assertFloatIdentical(self, v, w):
-    if math.isnan(v):
-      self.assertTrue(math.isnan(w))
-    else:
-      self.assertEqual(v, w)
+# Values that should round trip exactly to float and back.
+FLOAT_VALUES = [
+    0.0, 1.0, -1, 0.5, -0.5, epsilon, 1.0 + epsilon, 1.0 - epsilon,
+    -1.0 - epsilon, -1.0 + epsilon, 3.5, 42.0, 255.0, 256.0,
+    float("inf"),
+    float("-inf"),
+    float("nan")
+]
+
+
+class Bfloat16Test(parameterized.TestCase):
+  """Tests the non-numpy Python methods of the bfloat16 type."""
 
   def testRoundTripToFloat(self):
-    for v in float_values():
-      self._assertFloatIdentical(v, float(bfloat16(v)))
+    for v in FLOAT_VALUES:
+      np.testing.assert_equal(v, float(bfloat16(v)))
+
+  def testRoundTripNumpyTypes(self):
+    for dtype in [np.float16, np.float32, np.float64]:
+      np.testing.assert_equal(-3.75, dtype(bfloat16(dtype(-3.75))))
+      np.testing.assert_equal(1.5, float(bfloat16(dtype(1.5))))
+      np.testing.assert_equal(4.5, dtype(bfloat16(np.array(4.5, dtype))))
+      np.testing.assert_equal(
+          np.array([2, 5, -1], bfloat16), bfloat16(np.array([2, 5, -1], dtype)))
 
   def testRoundTripToInt(self):
     for v in [-256, -255, -34, -2, -1, 0, 1, 2, 10, 47, 128, 255, 256, 512]:
       self.assertEqual(v, int(bfloat16(v)))
 
+  # pylint: disable=g-complex-comprehension
+  @parameterized.named_parameters(({
+      "testcase_name": "_" + dtype.__name__,
+      "dtype": dtype
+  } for dtype in [bfloat16, np.float16, np.float32, np.float64]))
+  def testRoundTripToNumpy(self, dtype):
+    for v in FLOAT_VALUES:
+      np.testing.assert_equal(v, bfloat16(dtype(v)))
+      np.testing.assert_equal(v, dtype(bfloat16(dtype(v))))
+      np.testing.assert_equal(v, dtype(bfloat16(np.array(v, dtype))))
+    if dtype != bfloat16:
+      np.testing.assert_equal(
+          np.array(FLOAT_VALUES, dtype),
+          bfloat16(np.array(FLOAT_VALUES, dtype)).astype(dtype))
+
   def testStr(self):
     self.assertEqual("0", str(bfloat16(0.0)))
     self.assertEqual("1", str(bfloat16(1.0)))
@@ -70,14 +96,13 @@ class Bfloat16Test(test.TestCase):
     self.assertEqual("nan", str(bfloat16(float("nan"))))
 
   def testRepr(self):
-    self.assertEqual("bfloat16(0)", repr(bfloat16(0)))
-    self.assertEqual("bfloat16(1)", repr(bfloat16(1)))
-    self.assertEqual("bfloat16(-3.5)", repr(bfloat16(-3.5)))
-    self.assertEqual("bfloat16(0.0078125)",
-                     repr(bfloat16(float.fromhex("1.0p-7"))))
-    self.assertEqual("bfloat16(inf)", repr(bfloat16(float("inf"))))
-    self.assertEqual("bfloat16(-inf)", repr(bfloat16(float("-inf"))))
-    self.assertEqual("bfloat16(nan)", repr(bfloat16(float("nan"))))
+    self.assertEqual("0", repr(bfloat16(0)))
+    self.assertEqual("1", repr(bfloat16(1)))
+    self.assertEqual("-3.5", repr(bfloat16(-3.5)))
+    self.assertEqual("0.0078125", repr(bfloat16(float.fromhex("1.0p-7"))))
+    self.assertEqual("inf", repr(bfloat16(float("inf"))))
+    self.assertEqual("-inf", repr(bfloat16(float("-inf"))))
+    self.assertEqual("nan", repr(bfloat16(float("nan"))))
 
   def testHash(self):
     self.assertEqual(0, hash(bfloat16(0.0)))
@@ -86,115 +111,166 @@ class Bfloat16Test(test.TestCase):
 
   # Tests for Python operations
   def testNegate(self):
-    for v in float_values():
-      self._assertFloatIdentical(-v, float(-bfloat16(v)))
+    for v in FLOAT_VALUES:
+      np.testing.assert_equal(-v, float(-bfloat16(v)))
 
   def testAdd(self):
-    self._assertFloatIdentical(0, float(bfloat16(0) + bfloat16(0)))
-    self._assertFloatIdentical(1, float(bfloat16(1) + bfloat16(0)))
-    self._assertFloatIdentical(0, float(bfloat16(1) + bfloat16(-1)))
-    self._assertFloatIdentical(5.5, float(bfloat16(2) + bfloat16(3.5)))
-    self._assertFloatIdentical(1.25, float(bfloat16(3.5) + bfloat16(-2.25)))
-    self._assertFloatIdentical(float("inf"),
-                               float(bfloat16(float("inf")) + bfloat16(-2.25)))
-    self._assertFloatIdentical(float("-inf"),
-                               float(bfloat16(float("-inf")) + bfloat16(-2.25)))
+    np.testing.assert_equal(0, float(bfloat16(0) + bfloat16(0)))
+    np.testing.assert_equal(1, float(bfloat16(1) + bfloat16(0)))
+    np.testing.assert_equal(0, float(bfloat16(1) + bfloat16(-1)))
+    np.testing.assert_equal(5.5, float(bfloat16(2) + bfloat16(3.5)))
+    np.testing.assert_equal(1.25, float(bfloat16(3.5) + bfloat16(-2.25)))
+    np.testing.assert_equal(
+        float("inf"), float(bfloat16(float("inf")) + bfloat16(-2.25)))
+    np.testing.assert_equal(
+        float("-inf"), float(bfloat16(float("-inf")) + bfloat16(-2.25)))
     self.assertTrue(math.isnan(float(bfloat16(3.5) + bfloat16(float("nan")))))
 
+    # Test type promotion against Numpy scalar values.
+    self.assertEqual(np.float32, type(bfloat16(3.5) + np.float16(2.25)))
+    self.assertEqual(np.float32, type(np.float16(3.5) + bfloat16(2.25)))
+    self.assertEqual(np.float32, type(bfloat16(3.5) + np.float32(2.25)))
+    self.assertEqual(np.float32, type(np.float32(3.5) + bfloat16(2.25)))
+    self.assertEqual(np.float64, type(bfloat16(3.5) + np.float64(2.25)))
+    self.assertEqual(np.float64, type(np.float64(3.5) + bfloat16(2.25)))
+    self.assertEqual(np.float64, type(bfloat16(3.5) + float(2.25)))
+    self.assertEqual(np.float64, type(float(3.5) + bfloat16(2.25)))
+    self.assertEqual(np.float32,
+                     type(bfloat16(3.5) + np.array(2.25, np.float32)))
+    self.assertEqual(np.float32,
+                     type(np.array(3.5, np.float32) + bfloat16(2.25)))
+
   def testSub(self):
-    self._assertFloatIdentical(0, float(bfloat16(0) - bfloat16(0)))
-    self._assertFloatIdentical(1, float(bfloat16(1) - bfloat16(0)))
-    self._assertFloatIdentical(2, float(bfloat16(1) - bfloat16(-1)))
-    self._assertFloatIdentical(-1.5, float(bfloat16(2) - bfloat16(3.5)))
-    self._assertFloatIdentical(5.75, float(bfloat16(3.5) - bfloat16(-2.25)))
-    self._assertFloatIdentical(float("-inf"),
-                               float(bfloat16(-2.25) - bfloat16(float("inf"))))
-    self._assertFloatIdentical(float("inf"),
-                               float(bfloat16(-2.25) - bfloat16(float("-inf"))))
+    np.testing.assert_equal(0, float(bfloat16(0) - bfloat16(0)))
+    np.testing.assert_equal(1, float(bfloat16(1) - bfloat16(0)))
+    np.testing.assert_equal(2, float(bfloat16(1) - bfloat16(-1)))
+    np.testing.assert_equal(-1.5, float(bfloat16(2) - bfloat16(3.5)))
+    np.testing.assert_equal(5.75, float(bfloat16(3.5) - bfloat16(-2.25)))
+    np.testing.assert_equal(
+        float("-inf"), float(bfloat16(-2.25) - bfloat16(float("inf"))))
+    np.testing.assert_equal(
+        float("inf"), float(bfloat16(-2.25) - bfloat16(float("-inf"))))
     self.assertTrue(math.isnan(float(bfloat16(3.5) - bfloat16(float("nan")))))
 
   def testMul(self):
-    self._assertFloatIdentical(0, float(bfloat16(0) * bfloat16(0)))
-    self._assertFloatIdentical(0, float(bfloat16(1) * bfloat16(0)))
-    self._assertFloatIdentical(-1, float(bfloat16(1) * bfloat16(-1)))
-    self._assertFloatIdentical(-7.875, float(bfloat16(3.5) * bfloat16(-2.25)))
-    self._assertFloatIdentical(float("-inf"),
-                               float(bfloat16(float("inf")) * bfloat16(-2.25)))
-    self._assertFloatIdentical(float("inf"),
-                               float(bfloat16(float("-inf")) * bfloat16(-2.25)))
+    np.testing.assert_equal(0, float(bfloat16(0) * bfloat16(0)))
+    np.testing.assert_equal(0, float(bfloat16(1) * bfloat16(0)))
+    np.testing.assert_equal(-1, float(bfloat16(1) * bfloat16(-1)))
+    np.testing.assert_equal(-7.875, float(bfloat16(3.5) * bfloat16(-2.25)))
+    np.testing.assert_equal(
+        float("-inf"), float(bfloat16(float("inf")) * bfloat16(-2.25)))
+    np.testing.assert_equal(
+        float("inf"), float(bfloat16(float("-inf")) * bfloat16(-2.25)))
     self.assertTrue(math.isnan(float(bfloat16(3.5) * bfloat16(float("nan")))))
 
   def testDiv(self):
     self.assertTrue(math.isnan(float(bfloat16(0) / bfloat16(0))))
-    self._assertFloatIdentical(float("inf"), float(bfloat16(1) / bfloat16(0)))
-    self._assertFloatIdentical(-1, float(bfloat16(1) / bfloat16(-1)))
-    self._assertFloatIdentical(-1.75, float(bfloat16(3.5) / bfloat16(-2)))
-    self._assertFloatIdentical(float("-inf"),
-                               float(bfloat16(float("inf")) / bfloat16(-2.25)))
-    self._assertFloatIdentical(float("inf"),
-                               float(bfloat16(float("-inf")) / bfloat16(-2.25)))
+    np.testing.assert_equal(float("inf"), float(bfloat16(1) / bfloat16(0)))
+    np.testing.assert_equal(-1, float(bfloat16(1) / bfloat16(-1)))
+    np.testing.assert_equal(-1.75, float(bfloat16(3.5) / bfloat16(-2)))
+    np.testing.assert_equal(
+        float("-inf"), float(bfloat16(float("inf")) / bfloat16(-2.25)))
+    np.testing.assert_equal(
+        float("inf"), float(bfloat16(float("-inf")) / bfloat16(-2.25)))
     self.assertTrue(math.isnan(float(bfloat16(3.5) / bfloat16(float("nan")))))
 
   def testLess(self):
-    for v in float_values():
-      for w in float_values():
+    for v in FLOAT_VALUES:
+      for w in FLOAT_VALUES:
         self.assertEqual(v < w, bfloat16(v) < bfloat16(w))
 
   def testLessEqual(self):
-    for v in float_values():
-      for w in float_values():
+    for v in FLOAT_VALUES:
+      for w in FLOAT_VALUES:
         self.assertEqual(v <= w, bfloat16(v) <= bfloat16(w))
 
   def testGreater(self):
-    for v in float_values():
-      for w in float_values():
+    for v in FLOAT_VALUES:
+      for w in FLOAT_VALUES:
         self.assertEqual(v > w, bfloat16(v) > bfloat16(w))
 
   def testGreaterEqual(self):
-    for v in float_values():
-      for w in float_values():
+    for v in FLOAT_VALUES:
+      for w in FLOAT_VALUES:
         self.assertEqual(v >= w, bfloat16(v) >= bfloat16(w))
 
   def testEqual(self):
-    for v in float_values():
-      for w in float_values():
+    for v in FLOAT_VALUES:
+      for w in FLOAT_VALUES:
         self.assertEqual(v == w, bfloat16(v) == bfloat16(w))
 
   def testNotEqual(self):
-    for v in float_values():
-      for w in float_values():
+    for v in FLOAT_VALUES:
+      for w in FLOAT_VALUES:
         self.assertEqual(v != w, bfloat16(v) != bfloat16(w))
 
   def testNan(self):
     a = np.isnan(bfloat16(float("nan")))
     self.assertTrue(a)
-    np.testing.assert_allclose(np.array([1.0, a]), np.array([1.0, a]))
+    numpy_assert_allclose(np.array([1.0, a]), np.array([1.0, a]))
 
-    a = np.array(
-        [bfloat16(1.34375),
-         bfloat16(1.4375),
-         bfloat16(float("nan"))],
-        dtype=dtypes.bfloat16.as_numpy_dtype)
+    a = np.array([bfloat16(1.34375),
+                  bfloat16(1.4375),
+                  bfloat16(float("nan"))],
+                 dtype=bfloat16)
     b = np.array(
         [bfloat16(1.3359375),
          bfloat16(1.4375),
          bfloat16(float("nan"))],
-        dtype=dtypes.bfloat16.as_numpy_dtype)
-    np.testing.assert_allclose(
+        dtype=bfloat16)
+    numpy_assert_allclose(
         a, b, rtol=0.1, atol=0.1, equal_nan=True, err_msg="", verbose=True)
 
+  def testSort(self):
+    values_to_sort = np.float32(FLOAT_VALUES)
+    sorted_f32 = np.sort(values_to_sort)
+    sorted_bf16 = np.sort(values_to_sort.astype(bfloat16))
+    np.testing.assert_equal(sorted_f32, np.float32(sorted_bf16))
 
-class Bfloat16NumPyTest(test.TestCase):
+
+BinaryOp = collections.namedtuple("BinaryOp", ["op"])
+
+UNARY_UFUNCS = [
+    np.negative, np.positive, np.absolute, np.fabs, np.rint, np.sign,
+    np.conjugate, np.exp, np.exp2, np.expm1, np.log, np.log10, np.log1p,
+    np.log2, np.sqrt, np.square, np.cbrt, np.reciprocal, np.sin, np.cos, np.tan,
+    np.arcsin, np.arccos, np.arctan, np.sinh, np.cosh, np.tanh, np.arcsinh,
+    np.arccosh, np.arctanh, np.deg2rad, np.rad2deg, np.floor, np.ceil, np.trunc
+]
+
+BINARY_UFUNCS = [
+    np.add, np.subtract, np.multiply, np.divide, np.logaddexp, np.logaddexp2,
+    np.floor_divide, np.power, np.remainder, np.fmod, np.heaviside, np.arctan2,
+    np.hypot, np.maximum, np.minimum, np.fmax, np.fmin, np.copysign
+]
+
+BINARY_PREDICATE_UFUNCS = [
+    np.equal, np.not_equal, np.less, np.greater, np.less_equal,
+    np.greater_equal, np.logical_and, np.logical_or, np.logical_xor
+]
+
+
+class Bfloat16NumPyTest(parameterized.TestCase):
+  """Tests the NumPy integration of the bfloat16 type."""
 
   def testDtype(self):
     self.assertEqual(bfloat16, np.dtype(bfloat16))
 
+  def testDeepCopyDoesNotAlterHash(self):
+    # For context, see https://github.com/google/jax/issues/4651. If the hash
+    # value of the type descriptor is not initialized correctly, a deep copy
+    # can change the type hash.
+    dtype = np.dtype(bfloat16)
+    h = hash(dtype)
+    _ = copy.deepcopy(dtype)
+    self.assertEqual(h, hash(dtype))
+
   def testArray(self):
     x = np.array([[1, 2, 3]], dtype=bfloat16)
     self.assertEqual(bfloat16, x.dtype)
-    self.assertEqual("[[bfloat16(1) bfloat16(2) bfloat16(3)]]", str(x))
-    self.assertAllEqual(x, x)
-    self.assertAllClose(x, x)
+    self.assertEqual("[[1 2 3]]", str(x))
+    np.testing.assert_equal(x, x)
+    numpy_assert_allclose(x, x)
     self.assertTrue((x == x).all())
 
   def testComparisons(self):
@@ -202,12 +278,12 @@ class Bfloat16NumPyTest(test.TestCase):
     bx = x.astype(bfloat16)
     y = np.array([82432, 7, 0], dtype=np.float32)
     by = y.astype(bfloat16)
-    self.assertAllEqual(x == y, bx == by)
-    self.assertAllEqual(x != y, bx != by)
-    self.assertAllEqual(x < y, bx < by)
-    self.assertAllEqual(x > y, bx > by)
-    self.assertAllEqual(x <= y, bx <= by)
-    self.assertAllEqual(x >= y, bx >= by)
+    np.testing.assert_equal(x == y, bx == by)
+    np.testing.assert_equal(x != y, bx != by)
+    np.testing.assert_equal(x < y, bx < by)
+    np.testing.assert_equal(x > y, bx > by)
+    np.testing.assert_equal(x <= y, bx <= by)
+    np.testing.assert_equal(x >= y, bx >= by)
 
   def testEqual2(self):
     a = np.array([401408], bfloat16)
@@ -216,8 +292,10 @@ class Bfloat16NumPyTest(test.TestCase):
 
   def testCasts(self):
     for dtype in [
-        np.float16, np.float32, np.float64, np.int32, np.int64,
-        np.complex64, np.complex128]:
+        np.float16, np.float32, np.float64, np.int8, np.int16, np.int32,
+        np.int64, np.complex64, np.complex128, np.uint8, np.uint16, np.uint32,
+        np.uint64, np.intc, np.int_, np.longlong, np.uintc, np.ulonglong
+    ]:
       x = np.array([[1, 2, 3]], dtype=dtype)
       y = x.astype(bfloat16)
       z = y.astype(dtype)
@@ -231,44 +309,133 @@ class Bfloat16NumPyTest(test.TestCase):
       x = np.array([1.1, 2.2 + 2.2j, 3.3], dtype=dtype)
       y_np = x.astype(np.float32)
       y_tf = x.astype(bfloat16)
-      self.assertAllClose(y_np, y_tf, atol=2e-2)
+      numpy_assert_allclose(y_np, y_tf, atol=2e-2)
 
       z_np = y_np.astype(dtype)
       z_tf = y_tf.astype(dtype)
-      self.assertAllClose(z_np, z_tf, atol=2e-2)
-
-  def testAdd(self):
-    x = np.array([[1, 2, 3]], dtype=bfloat16)
-    y = np.array([[4, 5, 6]], dtype=bfloat16)
-    self.assertAllClose(np.array([[5, 7, 9]]), x + y)
-
-  def testLogSumExp(self):
-    x = np.array([[1, 2, 3]], dtype=np.float32)
-    y = np.array([[4, 5, 6]], dtype=np.float32)
-    self.assertAllClose(np.logaddexp(x, y),
-                        np.logaddexp(x.astype(bfloat16), y.astype(bfloat16)),
-                        atol=2e-2)
+      numpy_assert_allclose(z_np, z_tf, atol=2e-2)
 
   def testArange(self):
-    self.assertAllEqual(
+    np.testing.assert_equal(
         np.arange(100, dtype=np.float32).astype(bfloat16),
         np.arange(100, dtype=bfloat16))
-    self.assertAllEqual(
+    np.testing.assert_equal(
         np.arange(-10.5, 7.8, 0.5, dtype=np.float32).astype(bfloat16),
         np.arange(-10.5, 7.8, 0.5, dtype=bfloat16))
-    self.assertAllEqual(
+    np.testing.assert_equal(
         np.arange(-0., -7., -0.25, dtype=np.float32).astype(bfloat16),
         np.arange(-0., -7., -0.25, dtype=bfloat16))
-    self.assertAllEqual(
+    np.testing.assert_equal(
         np.arange(-16384., 16384., 64., dtype=np.float32).astype(bfloat16),
         np.arange(-16384., 16384., 64., dtype=bfloat16))
 
-  def testSort(self):
-    values_to_sort = np.float32(float_values())
-    sorted_f32 = np.sort(values_to_sort)
-    sorted_bf16 = np.sort(values_to_sort.astype(bfloat16))
-    self.assertAllEqual(sorted_f32, np.float32(sorted_bf16))
+  # pylint: disable=g-complex-comprehension
+  @parameterized.named_parameters(({
+      "testcase_name": "_" + op.__name__,
+      "op": op
+  } for op in UNARY_UFUNCS))
+  def testUnaryUfunc(self, op):
+    rng = np.random.RandomState(seed=42)
+    x = rng.randn(3, 7, 10).astype(bfloat16)
+    numpy_assert_allclose(
+        op(x).astype(np.float32), op(x.astype(np.float32)), rtol=1e-2)
+
+  @parameterized.named_parameters(({
+      "testcase_name": "_" + op.__name__,
+      "op": op
+  } for op in BINARY_UFUNCS))
+  def testBinaryUfunc(self, op):
+    rng = np.random.RandomState(seed=42)
+    x = rng.randn(3, 7, 10).astype(bfloat16)
+    y = rng.randn(4, 1, 7, 10).astype(bfloat16)
+    numpy_assert_allclose(
+        op(x, y).astype(np.float32),
+        op(x.astype(np.float32), y.astype(np.float32)),
+        rtol=1e-2)
+
+  @parameterized.named_parameters(({
+      "testcase_name": "_" + op.__name__,
+      "op": op
+  } for op in BINARY_PREDICATE_UFUNCS))
+  def testBinaryPredicateUfunc(self, op):
+    rng = np.random.RandomState(seed=42)
+    x = rng.randn(3, 7).astype(bfloat16)
+    y = rng.randn(4, 1, 7).astype(bfloat16)
+    np.testing.assert_equal(
+        op(x, y), op(x.astype(np.float32), y.astype(np.float32)))
+
+  @parameterized.named_parameters(({
+      "testcase_name": "_" + op.__name__,
+      "op": op
+  } for op in [np.isfinite, np.isinf, np.isnan, np.signbit, np.logical_not]))
+  def testPredicateUfunc(self, op):
+    rng = np.random.RandomState(seed=42)
+    shape = (3, 7, 10)
+    posinf_flips = rng.rand(*shape) < 0.1
+    neginf_flips = rng.rand(*shape) < 0.1
+    nan_flips = rng.rand(*shape) < 0.1
+    vals = rng.randn(*shape)
+    vals = np.where(posinf_flips, np.inf, vals)
+    vals = np.where(neginf_flips, -np.inf, vals)
+    vals = np.where(nan_flips, np.nan, vals)
+    vals = vals.astype(bfloat16)
+    np.testing.assert_equal(op(vals), op(vals.astype(np.float32)))
+
+  def testDivmod(self):
+    rng = np.random.RandomState(seed=42)
+    x = rng.randn(3, 7).astype(bfloat16)
+    y = rng.randn(4, 1, 7).astype(bfloat16)
+    o1, o2 = np.divmod(x, y)
+    e1, e2 = np.divmod(x.astype(np.float32), y.astype(np.float32))
+    numpy_assert_allclose(o1, e1, rtol=1e-2)
+    numpy_assert_allclose(o2, e2, rtol=1e-2)
+
+  def testModf(self):
+    rng = np.random.RandomState(seed=42)
+    x = rng.randn(3, 7).astype(bfloat16)
+    o1, o2 = np.modf(x)
+    e1, e2 = np.modf(x.astype(np.float32))
+    numpy_assert_allclose(o1.astype(np.float32), e1, rtol=1e-2)
+    numpy_assert_allclose(o2.astype(np.float32), e2, rtol=1e-2)
+
+  def testLdexp(self):
+    rng = np.random.RandomState(seed=42)
+    x = rng.randn(3, 7).astype(bfloat16)
+    y = rng.randint(-50, 50, (1, 7))
+    numpy_assert_allclose(
+        np.ldexp(x, y).astype(np.float32),
+        np.ldexp(x.astype(np.float32), y),
+        rtol=1e-2,
+        atol=1e-6)
+
+  def testFrexp(self):
+    rng = np.random.RandomState(seed=42)
+    x = rng.randn(3, 7).astype(bfloat16)
+    mant1, exp1 = np.frexp(x)
+    mant2, exp2 = np.frexp(x.astype(np.float32))
+    np.testing.assert_equal(exp1, exp2)
+    numpy_assert_allclose(mant1, mant2, rtol=1e-2)
+
+  def testNextAfter(self):
+    one = np.array(1., dtype=bfloat16)
+    two = np.array(2., dtype=bfloat16)
+    zero = np.array(0., dtype=bfloat16)
+    nan = np.array(np.nan, dtype=bfloat16)
+    np.testing.assert_equal(np.nextafter(one, two) - one, epsilon)
+    np.testing.assert_equal(np.nextafter(one, zero) - one, -epsilon / 2)
+    np.testing.assert_equal(np.isnan(np.nextafter(nan, one)), True)
+    np.testing.assert_equal(np.isnan(np.nextafter(one, nan)), True)
+    np.testing.assert_equal(np.nextafter(one, one), one)
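+    # 2**-133 is the smallest positive bfloat16 subnormal (minimum exponent
+    # -126 combined with the 7-bit fraction).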
+    smallest_denormal = float.fromhex("1.0p-133")
+    np.testing.assert_equal(np.nextafter(zero, one), smallest_denormal)
+    np.testing.assert_equal(np.nextafter(zero, -one), -smallest_denormal)
+    for a, b in itertools.permutations([0., -0., nan], 2):
+      np.testing.assert_equal(
+          np.nextafter(
+              np.array(a, dtype=np.float32), np.array(b, dtype=np.float32)),
+          np.nextafter(
+              np.array(a, dtype=bfloat16), np.array(b, dtype=bfloat16)))
 
 
 if __name__ == "__main__":
-  test.main()
+  absltest.main()
diff --git a/tensorflow/python/lib/core/bfloat16_wrapper.cc b/tensorflow/python/lib/core/bfloat16_wrapper.cc
index eb346af896a..741468bccd9 100644
--- a/tensorflow/python/lib/core/bfloat16_wrapper.cc
+++ b/tensorflow/python/lib/core/bfloat16_wrapper.cc
@@ -20,5 +20,5 @@ PYBIND11_MODULE(_pywrap_bfloat16, m) {
   tensorflow::RegisterNumpyBfloat16();
 
   m.def("TF_bfloat16_type",
-        [] { return pybind11::handle(tensorflow::Bfloat16PyType()); });
+        [] { return pybind11::handle(tensorflow::Bfloat16Dtype()); });
 }
diff --git a/tensorflow/python/ops/handle_data_util.py b/tensorflow/python/ops/handle_data_util.py
index d83bea3cb18..4f17cf4c667 100644
--- a/tensorflow/python/ops/handle_data_util.py
+++ b/tensorflow/python/ops/handle_data_util.py
@@ -19,20 +19,11 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.python.client import pywrap_tf_session
-from tensorflow.python.framework import cpp_shape_inference_pb2
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
-from tensorflow.python.util import compat
 
 
-def get_resource_handle_data(graph_op):
-  assert type(graph_op) == ops.Tensor  # pylint: disable=unidiomatic-typecheck
-
-  handle_data = pywrap_tf_session.GetHandleShapeAndType(
-      graph_op.graph._c_graph, graph_op._as_tf_output())  # pylint: disable=protected-access
-
-  return cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString(
-      compat.as_bytes(handle_data))
+get_resource_handle_data = ops.get_resource_handle_data
 
 
 def copy_handle_data(source_t, target_t):
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index bbd583162a2..e7467623fe3 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -4455,7 +4455,7 @@ def image_gradients(image):
     image = tf.reshape(tf.range(IMAGE_HEIGHT * IMAGE_WIDTH * CHANNELS,
       delta=1, dtype=tf.float32),
       shape=(BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, CHANNELS))
-    dx, dy = tf.image.image_gradients(image)
+    dy, dx = tf.image.image_gradients(image)
     print(image[0, :,:,0])
     tf.Tensor(
       [[ 0.  1.  2.  3.  4.]
@@ -4463,14 +4463,14 @@ def image_gradients(image):
       [10. 11. 12. 13. 14.]
       [15. 16. 17. 18. 19.]
       [20. 21. 22. 23. 24.]], shape=(5, 5), dtype=float32)
-    print(dx[0, :,:,0])
+    print(dy[0, :,:,0])
     tf.Tensor(
       [[5. 5. 5. 5. 5.]
       [5. 5. 5. 5. 5.]
       [5. 5. 5. 5. 5.]
       [5. 5. 5. 5. 5.]
       [0. 0. 0. 0. 0.]], shape=(5, 5), dtype=float32)
-    print(dy[0, :,:,0])
+    print(dx[0, :,:,0])
     tf.Tensor(
       [[1. 1. 1. 1. 0.]
       [1. 1. 1. 1. 0.]
diff --git a/tensorflow/python/saved_model/load_test.py b/tensorflow/python/saved_model/load_test.py
index c2fad724541..3f327cc6b45 100644
--- a/tensorflow/python/saved_model/load_test.py
+++ b/tensorflow/python/saved_model/load_test.py
@@ -798,6 +798,39 @@ class LoadTest(test.TestCase, parameterized.TestCase):
     self.assertIsNotNone(imported_gradient)
     self.assertAllClose(imported_gradient, 2.)
 
+  def test_nested_fn_backprop(self, cycles):
+    weight = variables.Variable(2., trainable=True)
+
+    @def_function.function(input_signature=[
+        tensor_spec.TensorSpec(dtype=dtypes.float32, shape=(None, None))])
+    def g(x):
+      weight.read_value()  # Just get the tape to watch the variable
+      handle = array_ops.identity(weight.handle)
+      @def_function.function
+      def launder_var_handle():
+        return array_ops.identity(handle)
+      return x + resource_variable_ops.read_variable_op(
+          launder_var_handle(), dtypes.float32)
+
+    root = tracking.AutoTrackable()
+    root.weight = weight
+    root.g = g
+    imported = cycle(root, cycles)
+    def get_gradient(obj, persistent):
+      with backprop.GradientTape(persistent=persistent) as t:
+        x = constant_op.constant([[1., 2., 3.], [1., -2, 3.]])
+        y = obj.g(x)
+        self.assertAllClose(y, obj.weight + x)
+        loss = math_ops.reduce_sum(y)
+        return t.gradient(loss, obj.weight)
+
+    imported_gradient = get_gradient(imported, persistent=False)
+    original_gradient = get_gradient(root, persistent=False)
+    self.assertIsNotNone(original_gradient)
+    self.assertAllClose(original_gradient, 6.)
+    self.assertIsNotNone(imported_gradient)
+    self.assertAllClose(imported_gradient, 6.)
+
   def test_restored_func_with_captured_var_backprop_float32(self, cycles):
     self._test_restored_func_with_captured_var_backprop(cycles, dtypes.float32)
 
@@ -2064,6 +2097,7 @@ class SingleCycleTests(test.TestCase, parameterized.TestCase):
   # allocations at a lower level.
   @test_util.assert_no_new_pyobjects_executing_eagerly
   def test_functions_cleaned(self):
+    self.skipTest("TODO(b/175152958): The test is leaking function definitions")
     if sys.version_info.major < 3:
       self.skipTest("Not working in Python 2")
     root = module.Module()
diff --git a/tensorflow/python/tpu/tpu_embedding_v2_correctness_test.py b/tensorflow/python/tpu/tpu_embedding_v2_correctness_test.py
index 32e5d541471..75235c3610c 100644
--- a/tensorflow/python/tpu/tpu_embedding_v2_correctness_test.py
+++ b/tensorflow/python/tpu/tpu_embedding_v2_correctness_test.py
@@ -123,11 +123,6 @@ class TPUEmbeddingCorrectness(parameterized.TestCase, test.TestCase):
     self.feature_friends_row_lengths = [1, 3, 1, 3]
     self.resolver = None
 
-  def tearDown(self):
-    if self.resolver:
-      tpu_strategy_util.shutdown_tpu_system(self.resolver)
-    super(TPUEmbeddingCorrectness, self).tearDown()
-
   def _get_strategy(self):
     self.resolver = tpu_cluster_resolver.TPUClusterResolver(
         tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project)
diff --git a/tensorflow/python/tpu/tpu_embedding_v2_test.py b/tensorflow/python/tpu/tpu_embedding_v2_test.py
index d5f9e6446a6..5843bb63fad 100644
--- a/tensorflow/python/tpu/tpu_embedding_v2_test.py
+++ b/tensorflow/python/tpu/tpu_embedding_v2_test.py
@@ -99,10 +99,6 @@ class TPUEmbeddingCheckpointTest(parameterized.TestCase, test.TestCase):
     self.cpu_mid_level = self.build_mid_level(
         self.second_mid_level_contents, self.cpu_mid_level_optimizer)
 
-  def tearDown(self):
-    tpu_strategy_util.shutdown_tpu_system(self.resolver)
-    super(TPUEmbeddingCheckpointTest, self).tearDown()
-
   def test_checkpoint_save_retrieves(self):
     # Ensure that the variables from the first model are loaded.
     self.first_mid_level._load_variables()
@@ -401,11 +397,6 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
     self.feature_friends_row_lengths = [1, 3, 1, 3]
     self.resolver = None
 
-  def tearDown(self):
-    if self.resolver:
-      tpu_strategy_util.shutdown_tpu_system(self.resolver)
-    super(TPUEmbeddingTest, self).tearDown()
-
   def test_tables_with_same_name(self):
     with self.assertRaisesRegex(
         ValueError, 'Multiple tables with name table found.'):
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index db3ad27310c..21a61e47d50 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -438,15 +438,62 @@ def assert_same_structure(nest1, nest2, check_types=True,
                           expand_composites=False):
   """Asserts that two structures are nested in the same way.
 
-  Note that namedtuples with identical name and fields are always considered
-  to have the same shallow structure (even with `check_types=True`).
-  For instance, this code will print `True`:
+  Note the method does not check the types of data inside the structures.
 
-  ```python
-  def nt(a, b):
-    return collections.namedtuple('foo', 'a b')(a, b)
-  print(assert_same_structure(nt(0, 1), nt(2, 3)))
-  ```
+  Examples:
+
+  * These scalar vs. scalar comparisons will pass:
+
+    >>> tf.nest.assert_same_structure(1.5, tf.Variable(1, tf.uint32))
+    >>> tf.nest.assert_same_structure("abc", np.array([1, 2]))
+
+  * These sequence vs. sequence comparisons will pass:
+
+    >>> structure1 = (((1, 2), 3), 4, (5, 6))
+    >>> structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
+    >>> structure3 = [(("a", "b"), "c"), "d", ["e", "f"]]
+    >>> tf.nest.assert_same_structure(structure1, structure2)
+    >>> tf.nest.assert_same_structure(structure1, structure3, check_types=False)
+
+    >>> import collections
+    >>> tf.nest.assert_same_structure(
+    ...     collections.namedtuple("bar", "a b")(1, 2),
+    ...     collections.namedtuple("foo", "a b")(2, 3),
+    ...     check_types=False)
+
+    >>> tf.nest.assert_same_structure(
+    ...     collections.namedtuple("bar", "a b")(1, 2),
+    ...     { "a": 1, "b": 2 },
+    ...     check_types=False)
+
+    >>> tf.nest.assert_same_structure(
+    ...     { "a": 1, "b": 2, "c": 3 },
+    ...     { "c": 6, "b": 5, "a": 4 })
+
+    >>> ragged_tensor1 = tf.RaggedTensor.from_row_splits(
+    ...       values=[3, 1, 4, 1, 5, 9, 2, 6],
+    ...       row_splits=[0, 4, 4, 7, 8, 8])
+    >>> ragged_tensor2 = tf.RaggedTensor.from_row_splits(
+    ...       values=[3, 1, 4],
+    ...       row_splits=[0, 3])
+    >>> tf.nest.assert_same_structure(
+    ...       ragged_tensor1,
+    ...       ragged_tensor2,
+    ...       expand_composites=True)
+
+  * These examples will raise exceptions:
+
+    >>> tf.nest.assert_same_structure([0, 1], np.array([0, 1]))
+    Traceback (most recent call last):
+    ...
+    ValueError: The two structures don't have the same nested structure
+
+    >>> tf.nest.assert_same_structure(
+    ...       collections.namedtuple('bar', 'a b')(1, 2),
+    ...       collections.namedtuple('foo', 'a b')(2, 3))
+    Traceback (most recent call last):
+    ...
+    TypeError: The two structures don't have the same nested structure
 
   Args:
     nest1: an arbitrarily nested structure.
diff --git a/tensorflow/python/util/tf_stack.cc b/tensorflow/python/util/tf_stack.cc
index d549896a889..452aa68f3e3 100644
--- a/tensorflow/python/util/tf_stack.cc
+++ b/tensorflow/python/util/tf_stack.cc
@@ -19,7 +19,7 @@ limitations under the License.
 // We store the retrieved stack trace within the Node object directly. Then
 // whenever the graph is instantiated/copied, we copy the stack trace with it.
 // Since the graph instantiation goes through the protobuf roundtrip, we store
-// the original Graph with stack traces attached in FunctionLibraryDefinition.
+// the original stack-trace mapping attached to the FunctionLibraryDefinition.
 
 #include 
 #include 
diff --git a/tensorflow/stream_executor/cuda/BUILD b/tensorflow/stream_executor/cuda/BUILD
index 0ee227d51f2..839950b1021 100644
--- a/tensorflow/stream_executor/cuda/BUILD
+++ b/tensorflow/stream_executor/cuda/BUILD
@@ -598,6 +598,16 @@ cc_library(
     ]),
 )
 
+cc_library(
+    name = "cuda_asm_compiler",
+    srcs = if_cuda_is_configured(["cuda_asm_compiler.cc"]),
+    deps = if_cuda_is_configured([
+        "//tensorflow/core:lib_proto_parsing",
+        "//tensorflow/stream_executor/gpu:asm_compiler",
+        "//tensorflow/stream_executor/gpu:gpu_driver_header",
+    ]),
+)
+
 cc_library(
     name = "cuda_gpu_executor",
     srcs = if_cuda_is_configured(["cuda_gpu_executor.cc"]),
@@ -611,6 +621,7 @@ cc_library(
         ":cuda_platform_id",
         ":cuda_stream",
         ":cuda_timer",
+        ":cuda_asm_compiler",
         "@com_google_absl//absl/strings",
         "//tensorflow/stream_executor:event",
         "//tensorflow/stream_executor:plugin_registry",
diff --git a/tensorflow/stream_executor/cuda/cuda_asm_compiler.cc b/tensorflow/stream_executor/cuda/cuda_asm_compiler.cc
new file mode 100644
index 00000000000..f92d3c487d0
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cuda_asm_compiler.cc
@@ -0,0 +1,55 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/stream_executor/gpu/asm_compiler.h"
+#include "tensorflow/stream_executor/gpu/gpu_driver.h"
+
+namespace stream_executor {
+
+#define RETURN_IF_CUDA_ERROR(expr)                                            \
+  do {                                                                        \
+    CUresult _status = expr;                                                  \
+    if (!SE_PREDICT_TRUE(_status == CUDA_SUCCESS)) {                          \
+      const char* error_string;                                               \
+      cuGetErrorString(_status, &error_string);                               \
+      std::ostringstream oss;                                                 \
+      oss << error_string << "\nin " << __FILE__ << "(" << __LINE__ << "): '" \
+          << #expr << "'";                                                    \
+      return port::Status(port::error::UNKNOWN, oss.str().c_str());           \
+    }                                                                         \
+  } while (false)
+
+port::StatusOr<std::vector<uint8>> LinkGpuAsm(
+    gpu::GpuContext* context, std::vector<CubinOrPTXImage> images) {
+  gpu::ScopedActivateContext activation(context);
+
+  CUlinkState link_state;
+  RETURN_IF_CUDA_ERROR(cuLinkCreate(0, nullptr, nullptr, &link_state));
+  for (auto& image : images) {
+    RETURN_IF_CUDA_ERROR(cuLinkAddData(
+        link_state, CU_JIT_INPUT_CUBIN, static_cast<void*>(image.bytes.data()),
+        image.bytes.size(), "", 0, nullptr, nullptr));
+  }
+  void* cubin_out;
+  size_t cubin_size;
+  RETURN_IF_CUDA_ERROR(cuLinkComplete(link_state, &cubin_out, &cubin_size));
+  std::vector<uint8> cubin(static_cast<uint8*>(cubin_out),
+                           static_cast<uint8*>(cubin_out) + cubin_size);
+  RETURN_IF_CUDA_ERROR(cuLinkDestroy(link_state));
+  return std::move(cubin);
+}
+
+}  // namespace stream_executor
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index c03eb0a57e9..e4e9914adf0 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -1468,7 +1468,9 @@ class CudnnRnnSequenceTensorDescriptor
   static port::StatusOr<CudnnRnnSequenceTensorDescriptor> Create(
       GpuExecutor* parent, int max_seq_length, int batch_size, int data_size,
       cudnnDataType_t data_type) {
-    CHECK_GT(max_seq_length, 0);
+    if (max_seq_length <= 0) {
+      return port::Status(port::error::INVALID_ARGUMENT, "max_seq_length <= 0");
+    }
     int dims[] = {batch_size, data_size, 1};
     int strides[] = {dims[1] * dims[2], dims[2], 1};
     TensorDescriptor tensor_desc = CreateTensorDescriptor();
@@ -1486,7 +1488,9 @@ class CudnnRnnSequenceTensorDescriptor
       GpuExecutor* parent, int max_seq_length, int batch_size, int data_size,
       const absl::Span<const int>& seq_lengths, bool time_major,
       cudnnDataType_t data_type) {
-    CHECK_GT(max_seq_length, 0);
+    if (max_seq_length <= 0) {
+      return port::Status(port::error::INVALID_ARGUMENT, "max_seq_length <= 0");
+    }
     int dims[] = {batch_size, data_size, 1};
     int strides[] = {dims[1] * dims[2], dims[2], 1};
     TensorDescriptor tensor_desc = CreateTensorDescriptor();
diff --git a/tensorflow/stream_executor/gpu/asm_compiler.h b/tensorflow/stream_executor/gpu/asm_compiler.h
index 1ac58aaddf3..388f919a3c3 100644
--- a/tensorflow/stream_executor/gpu/asm_compiler.h
+++ b/tensorflow/stream_executor/gpu/asm_compiler.h
@@ -24,6 +24,9 @@ limitations under the License.
 #include "tensorflow/stream_executor/platform/port.h"
 
 namespace stream_executor {
+namespace gpu {
+class GpuContext;
+}
 
 // Compiles the given PTX string using ptxas and returns the resulting machine
 // code (i.e. a cubin) as a byte array. The generated cubin matches the compute
@@ -72,6 +75,11 @@ struct HsacoImage {
 port::StatusOr<std::vector<uint8>> BundleGpuAsm(
     std::vector<HsacoImage> images, const std::string rocm_root_dir);
 
+// Links multiple relocatable GPU images (e.g. results of ptxas -c) into a
+// single image.
+port::StatusOr<std::vector<uint8>> LinkGpuAsm(
+    gpu::GpuContext* context, std::vector<CubinOrPTXImage> images);
+
 }  // namespace stream_executor
 
 #endif  // TENSORFLOW_STREAM_EXECUTOR_GPU_ASM_COMPILER_H_
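
The patch adds the LinkGpuAsm declaration and its CUDA implementation but no caller. As a rough illustration only, the minimal sketch below shows how the new entry point might be invoked, assuming a live gpu::GpuContext and two relocatable cubins produced elsewhere (e.g. with `ptxas -c`); the helper name LinkTwoCubins is hypothetical, and the use of CubinOrPTXImage's `bytes` member is inferred from the implementation above rather than part of this change.

#include <utility>
#include <vector>

#include "tensorflow/stream_executor/gpu/asm_compiler.h"

namespace stream_executor {

// Hypothetical helper (not part of this patch): link two relocatable cubins,
// e.g. produced with `ptxas -c`, into one loadable image.
port::StatusOr<std::vector<uint8>> LinkTwoCubins(gpu::GpuContext* context,
                                                 std::vector<uint8> cubin_a,
                                                 std::vector<uint8> cubin_b) {
  std::vector<CubinOrPTXImage> images(2);
  images[0].bytes = std::move(cubin_a);  // `bytes` member, as read by LinkGpuAsm
  images[1].bytes = std::move(cubin_b);
  // LinkGpuAsm activates `context` and drives the CUDA driver linker
  // (cuLinkCreate / cuLinkAddData / cuLinkComplete) over the images.
  return LinkGpuAsm(context, std::move(images));
}

}  // namespace stream_executor

Note that linking goes through the CUDA driver API under an activated context rather than shelling out to an external tool, which is why the new function takes a gpu::GpuContext* rather than just the images.
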
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt
index 88fd63d9693..d5476bcf054 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt
@@ -100,6 +100,12 @@ tf_proto {
       label: LABEL_OPTIONAL
       type: TYPE_INT64
     }
+    field {
+      name: "use_tfrt"
+      number: 18
+      label: LABEL_OPTIONAL
+      type: TYPE_BOOL
+    }
     enum_type {
       name: "MlirBridgeRollout"
       value: {
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt
index a598071b970..bb3bcde40cd 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt
@@ -229,6 +229,12 @@ tf_proto {
         label: LABEL_OPTIONAL
         type: TYPE_INT64
       }
+      field {
+        name: "use_tfrt"
+        number: 18
+        label: LABEL_OPTIONAL
+        type: TYPE_BOOL
+      }
       enum_type {
         name: "MlirBridgeRollout"
         value: {
diff --git a/tensorflow/tools/ci_build/builds/libtensorflow.sh b/tensorflow/tools/ci_build/builds/libtensorflow.sh
index bfd551f1772..aa7cbef7566 100755
--- a/tensorflow/tools/ci_build/builds/libtensorflow.sh
+++ b/tensorflow/tools/ci_build/builds/libtensorflow.sh
@@ -57,6 +57,8 @@ function build_libtensorflow_tarball() {
     BAZEL_OPTS="${BAZEL_OPTS} --config=cuda --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11:toolchain"
     export TF_NEED_ROCM=0
     export TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_50,sm_60,sm_70,sm_75,compute_80"
+  else
+    BAZEL_OPTS="${BAZEL_OPTS} --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain"
   fi
   bazel clean --expunge
   yes "" | ./configure
diff --git a/tensorflow/tools/ci_build/install/install_centos_pip_packages.sh b/tensorflow/tools/ci_build/install/install_centos_pip_packages.sh
index 0f0f182a1bc..ce7789b3704 100755
--- a/tensorflow/tools/ci_build/install/install_centos_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_centos_pip_packages.sh
@@ -102,7 +102,7 @@ pip3 install --upgrade termcolor
 pip2 install keras_preprocessing==1.0.5 --no-deps
 pip3 install keras_preprocessing==1.0.5 --no-deps
 pip2 install --upgrade h5py==2.8.0
-pip3 install --upgrade h5py==3.1.0
+pip3 install --upgrade h5py==2.8.0
 
 # Estimator
 pip2 install tf-estimator-nightly --no-deps
diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh
index f9893f070d5..578967a67cf 100755
--- a/tensorflow/tools/ci_build/install/install_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh
@@ -134,7 +134,7 @@ pip3 install --upgrade termcolor
 pip2 install keras_preprocessing==1.1.0 --no-deps
 pip3 install keras_preprocessing==1.1.0 --no-deps
 pip2 install --upgrade h5py==2.8.0
-pip3 install --upgrade h5py==3.1.0
+pip3 install --upgrade h5py==2.8.0
 
 # Estimator
 pip2 install tf-estimator-nightly --no-deps
diff --git a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
index 9530c9fdf22..bb53fc91981 100755
--- a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
@@ -87,7 +87,7 @@ pip3.5 install --upgrade termcolor
 
 # Keras
 pip3.5 install keras_preprocessing==1.0.5
-pip3.5 install --upgrade h5py==3.1.0
+pip3.5 install --upgrade h5py==2.8.0
 
 # Estimator
 pip3.5 install tf-estimator-nightly==1.12.0.dev20181203 --no-deps
diff --git a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
index f130ab87dc2..bcf0d0b87ab 100755
--- a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
@@ -101,7 +101,7 @@ pip3 install --upgrade astor
 pip3 install --upgrade gast
 pip3 install --upgrade termcolor
 
-pip3 install --upgrade h5py==3.1.0
+pip3 install --upgrade h5py==2.8.0
 
 # Keras
 pip3 install keras_preprocessing==1.0.5
diff --git a/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh b/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh
index 7b2ba29de8c..4fd671cb6d4 100755
--- a/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh
+++ b/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh
@@ -31,7 +31,7 @@ DOCKER_CONTEXT_PATH="$(realpath ${SCRIPT_DIR}/..)"
 ROOT_DIR="$(realpath ${SCRIPT_DIR}/../../../../)"
 
 DOCKER_IMAGE="tf-libtensorflow-cpu"
-DOCKER_FILE="Dockerfile.cpu"
+DOCKER_FILE="Dockerfile.rbe.ubuntu16.04-manylinux2010"
 DOCKER_BINARY="docker"
 if [ "${TF_NEED_CUDA}" == "1" ]; then
   DOCKER_IMAGE="tf-tensorflow-gpu"
diff --git a/tensorflow/tools/ci_build/rel/macos/cpu_py36_nonpip.sh b/tensorflow/tools/ci_build/rel/macos/cpu_py36_nonpip.sh
index dd7e2a56711..3f36d2c3220 100644
--- a/tensorflow/tools/ci_build/rel/macos/cpu_py36_nonpip.sh
+++ b/tensorflow/tools/ci_build/rel/macos/cpu_py36_nonpip.sh
@@ -27,7 +27,7 @@ python3.6 -m virtualenv tf_build_env --system-site-packages
 source tf_build_env/bin/activate
 
 # Install macos pip dependencies
-install_macos_pip_deps sudo pip3.6
+install_macos_pip_deps virtualenv
 
 # Run configure.
 export TF_NEED_CUDA=0
diff --git a/tensorflow/tools/ci_build/rel/macos/cpu_py37_nonpip.sh b/tensorflow/tools/ci_build/rel/macos/cpu_py37_nonpip.sh
index 2f73ad62e1e..b3cddd0c109 100644
--- a/tensorflow/tools/ci_build/rel/macos/cpu_py37_nonpip.sh
+++ b/tensorflow/tools/ci_build/rel/macos/cpu_py37_nonpip.sh
@@ -22,11 +22,11 @@ install_bazelisk
 # Pick a more recent version of xcode
 export DEVELOPER_DIR=/Applications/Xcode_11.3.app/Contents/Developer
 sudo xcode-select -s "${DEVELOPER_DIR}"
-python -m virtualenv tf_build_env --system-site-packages
+python3.7 -m virtualenv tf_build_env --system-site-packages
 source tf_build_env/bin/activate
 
 # Install macos pip dependencies
-install_macos_pip_deps sudo pip3.7
+install_macos_pip_deps virtualenv
 
 # Run configure.
 export TF_NEED_CUDA=0
diff --git a/tensorflow/tools/ci_build/rel/macos/cpu_py38_nonpip.sh b/tensorflow/tools/ci_build/rel/macos/cpu_py38_nonpip.sh
index 11b557adc96..70d742b6bf7 100644
--- a/tensorflow/tools/ci_build/rel/macos/cpu_py38_nonpip.sh
+++ b/tensorflow/tools/ci_build/rel/macos/cpu_py38_nonpip.sh
@@ -23,11 +23,11 @@ install_bazelisk
 export DEVELOPER_DIR=/Applications/Xcode_10.3.app/Contents/Developer
 export MACOSX_DEPLOYMENT_TARGET=10.10
 sudo xcode-select -s "${DEVELOPER_DIR}"
-python -m virtualenv tf_build_env --system-site-packages
+python3.8 -m virtualenv tf_build_env --system-site-packages
 source tf_build_env/bin/activate
 
 # Install macos pip dependencies
-install_macos_pip_deps sudo pip3.8
+install_macos_pip_deps virtualenv
 
 # Run configure.
 export TF_NEED_CUDA=0
diff --git a/tensorflow/tools/ci_build/release/common.sh b/tensorflow/tools/ci_build/release/common.sh
index d572ac18af9..11b7c2abcb4 100644
--- a/tensorflow/tools/ci_build/release/common.sh
+++ b/tensorflow/tools/ci_build/release/common.sh
@@ -126,7 +126,7 @@ function install_ubuntu_16_pip_deps {
   "${PIP_CMD}" install --user 'astunparse ~= 1.6.3'
   "${PIP_CMD}" install --user 'flatbuffers ~= 1.12.0'
   "${PIP_CMD}" install --user 'google_pasta ~= 0.2'
-  "${PIP_CMD}" install --user 'h5py ~= 3.1.0'
+  "${PIP_CMD}" install --user 'h5py ~= 2.10.0'
   "${PIP_CMD}" install --user 'keras_preprocessing ~= 1.1.2'
   "${PIP_CMD}" install --user 'numpy ~= 1.19.2'
   "${PIP_CMD}" install --user 'opt_einsum ~= 3.3.0'
@@ -188,7 +188,7 @@ function install_macos_pip_deps {
   ${PIP_CMD} install $USER_FLAG 'astunparse ~= 1.6.3'
   ${PIP_CMD} install $USER_FLAG 'flatbuffers ~= 1.12.0'
   ${PIP_CMD} install $USER_FLAG 'google_pasta ~= 0.2'
-  ${PIP_CMD} install $USER_FLAG 'h5py ~= 3.1.0'
+  ${PIP_CMD} install $USER_FLAG 'h5py ~= 2.10.0'
   ${PIP_CMD} install $USER_FLAG 'keras_preprocessing ~= 1.1.2'
   ${PIP_CMD} install $USER_FLAG 'numpy ~= 1.19.2'
   ${PIP_CMD} install $USER_FLAG 'opt_einsum ~= 3.3.0'
diff --git a/tensorflow/tools/ci_build/release/common_win.bat b/tensorflow/tools/ci_build/release/common_win.bat
index dbe159a6776..f27ec3117ed 100644
--- a/tensorflow/tools/ci_build/release/common_win.bat
+++ b/tensorflow/tools/ci_build/release/common_win.bat
@@ -18,7 +18,7 @@ echo on
 @REM Set Environment Variables
 @REM
 IF NOT DEFINED PYTHON_DIRECTORY (
-  SET PYTHON_DIRECTORY=Python37
+  SET PYTHON_DIRECTORY=Python36
 )
 SET PY_EXE=C:\%PYTHON_DIRECTORY%\python.exe
 SET PATH=%PATH%;C:\%PYTHON_DIRECTORY%
@@ -32,7 +32,7 @@ SET PATH=%PATH%;C:\%PYTHON_DIRECTORY%
 %PY_EXE% -m pip install "astunparse ~= 1.6.3"
 %PY_EXE% -m pip install "flatbuffers ~= 1.12.0"
 %PY_EXE% -m pip install "google_pasta ~= 0.2"
-%PY_EXE% -m pip install "h5py ~= 3.1.0"
+%PY_EXE% -m pip install "h5py ~= 2.10.0"
 %PY_EXE% -m pip install "keras_preprocessing ~= 1.1.2"
 %PY_EXE% -m pip install "numpy ~= 1.19.2"
 %PY_EXE% -m pip install "opt_einsum ~= 3.3.0"
diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt
index ca13f4370b7..1ceb0b33861 100644
--- a/tensorflow/tools/def_file_filter/symbols_pybind.txt
+++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt
@@ -48,7 +48,7 @@ tensorflow::grappler::graph_analyzer::GraphAnalyzerTool
 
 [//tensorflow/python:bfloat16_lib] # bfloat16
 tensorflow::RegisterNumpyBfloat16
-tensorflow::Bfloat16PyType
+tensorflow::Bfloat16Dtype
 
 [//tensorflow/python:py_func_lib] # py_func
 tensorflow::InitializePyTrampoline
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index d84b08de7a1..613ce9f5bf3 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -79,7 +79,7 @@ REQUIRED_PACKAGES = [
     'astunparse ~= 1.6.3',
     'flatbuffers ~= 1.12.0',
     'google_pasta ~= 0.2',
-    'h5py ~= 3.1.0',
+    'h5py ~= 2.10.0',
     'keras_preprocessing ~= 1.1.2',
     'numpy ~= 1.19.2',
     'opt_einsum ~= 3.3.0',
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 35eb42b3424..0314535a40d 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -685,8 +685,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
     )
 
     # Check out LLVM and MLIR from llvm-project.
-    LLVM_COMMIT = "ecaff13fc0bc1105ad910a72a5d0dcd164b35191"
-    LLVM_SHA256 = "d0178d6f6a23ce60752d11ee8b1d64784d8ce9625f03d76943b0e40a0043211a"
+    LLVM_COMMIT = "d553243fe4b5e1992c07aff7b54b16160a4d5e97"
+    LLVM_SHA256 = "46b06b63414c21d86d8a91e9011f07dd974e976bbda767af66ec77c7d764f091"
     LLVM_URLS = [
         "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
         "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),