Merge remote-tracking branch 'upstream/master' into xtensa-fusion-f1

Commit 6b24408403
@@ -61,6 +61,7 @@
*   Added support for saved model's session initializer through
    `TFLiteConverter.from_saved_model`.
*   Added dynamic range quantization support for the BatchMatMul op.
*   Added DEPTH_TO_SPACE support in post-training quantization.
*   Added `RFFT2D` as a builtin op. (`RFFT2D` also supports `RFFTD`.) Currently
    only supports float32 input.
*   TFLite supports SignatureDef:
@@ -145,6 +145,8 @@ if _running_from_pip_package():
    _plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
    if _os.path.exists(_plugin_dir):
      _ll.load_library(_plugin_dir)
      # Load Pluggable Device Library
      _ll.load_pluggable_device_library(_plugin_dir)

# Add module aliases
if hasattr(_current_module, 'keras'):

@@ -155,6 +155,8 @@ if _running_from_pip_package():
    _plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
    if _os.path.exists(_plugin_dir):
      _ll.load_library(_plugin_dir)
      # Load Pluggable Device Library
      _ll.load_pluggable_device_library(_plugin_dir)

# Delete modules that should be hidden from dir().
# Don't fail if these modules are not available.
@@ -684,7 +684,10 @@ tf_cc_test(
    name = "c_api_experimental_test",
    size = "medium",
    srcs = ["c_api_experimental_test.cc"],
    data = ["testdata/tf_record"],
    data = [
        "testdata/tf_record",
        "//tensorflow/c/experimental/stream_executor/test:test_pluggable_device.so",
    ],
    linkopts = select({
        "//tensorflow:macos": ["-headerpad_max_install_names"],
        "//conditions:default": [],
@@ -704,6 +707,7 @@ tf_cc_test(
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/platform:resource_loader",
        "@com_google_absl//absl/types:optional",
    ],
)
@@ -37,7 +37,9 @@ limitations under the License.
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/net.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/strcat.h"

@@ -630,6 +632,9 @@ void TF_DeleteShapeAndTypeListArray(TF_ShapeAndTypeList** shape_list_array,

namespace tensorflow {
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);

// Helpers for loading a TensorFlow PluggableDevice plugin (a .so file).
Status LoadPluggableDeviceLibrary(const char* library_filename, void** result);
}  // namespace tensorflow

void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,

@@ -743,3 +748,45 @@ void TF_ImportGraphDefOptionsSetValidateColocationConstraints(
    TF_ImportGraphDefOptions* opts, unsigned char enable) {
  opts->opts.validate_colocation_constraints = enable;
}

// Loads a PluggableDevice library.
// On success, returns the library handle and sets OK in `status`; otherwise
// returns nullptr and sets an error Status.
//
// If `library_filename` has already been loaded, a cached handle is returned.
// Devices and kernels/ops are registered as globals when a library is loaded
// for the first time.
TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
                                          TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "PluggableDevice plugin functionality is not supported on mobile");
  return nullptr;
#else
  TF_Library* lib_handle = new TF_Library;
  static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
  static std::unordered_map<std::string, void*>* loaded_libs =
      new std::unordered_map<std::string, void*>();
  tensorflow::Env* env = tensorflow::Env::Default();
  {
    tensorflow::mutex_lock lock(mu);
    auto it = loaded_libs->find(library_filename);
    if (it != loaded_libs->end()) {
      lib_handle->lib_handle = it->second;
    } else {
      status->status =
          env->LoadDynamicLibrary(library_filename, &lib_handle->lib_handle);
      if (!status->status.ok()) {
        delete lib_handle;
        return nullptr;
      }
    }
    return lib_handle;
  }
#endif
}

void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) {
  delete lib_handle;
}
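The function above follows a "load once, cache by filename" pattern: a static
mutex guards a static map from filename to the underlying OS handle, so a
plugin's device and kernel registrations only ever run once per process. The
following is a minimal standalone sketch of that same pattern; it is not
TensorFlow code and uses POSIX dlopen() as a stand-in for
Env::LoadDynamicLibrary.

#include <dlfcn.h>

#include <mutex>
#include <string>
#include <unordered_map>

// Returns a cached handle if `filename` was loaded before, otherwise loads it.
void* LoadLibraryOnce(const std::string& filename, std::string* error) {
  static std::mutex mu;
  static auto* loaded = new std::unordered_map<std::string, void*>();

  std::lock_guard<std::mutex> lock(mu);
  auto it = loaded->find(filename);
  if (it != loaded->end()) return it->second;  // Already loaded: reuse handle.

  void* handle = dlopen(filename.c_str(), RTLD_NOW | RTLD_LOCAL);
  if (handle == nullptr) {
    const char* msg = dlerror();
    *error = msg ? msg : "unknown dlopen error";  // Nothing was registered.
    return nullptr;
  }
  (*loaded)[filename] = handle;  // Global registrations ran exactly once.
  return handle;
}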
@@ -304,6 +304,27 @@ TF_CAPI_EXPORT extern void
TF_ImportGraphDefOptionsSetValidateColocationConstraints(
    TF_ImportGraphDefOptions* opts, unsigned char enable);

// Load the library specified by library_filename and register the pluggable
// device and related kernels present in that library. This function is not
// supported on mobile and embedded platforms and will fail if called.
//
// Pass "library_filename" to a platform-specific mechanism for dynamically
// loading a library. The rules for determining the exact location of the
// library are platform-specific and are not documented here.
//
// On success, returns the newly created library handle and places OK in
// status. The caller owns the library handle.
//
// On failure, returns nullptr and places an error status in status.
TF_CAPI_EXPORT extern TF_Library* TF_LoadPluggableDeviceLibrary(
    const char* library_filename, TF_Status* status);

// Frees the memory associated with the library handle.
// Does NOT unload the library.
TF_CAPI_EXPORT extern void TF_DeletePluggableDeviceLibraryHandle(
    TF_Library* lib_handle);

#ifdef __cplusplus
} /* end extern "C" */
#endif
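A minimal sketch of how client code might call these two new entry points; the
plugin path is a placeholder, and error handling follows the usual TF_Status
conventions of the C API.

#include <stdio.h>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"

int main() {
  TF_Status* status = TF_NewStatus();
  // Illustrative path only; a real plugin ships its own .so name.
  TF_Library* lib =
      TF_LoadPluggableDeviceLibrary("libmy_pluggable_device.so", status);
  if (TF_GetCode(status) != TF_OK) {
    fprintf(stderr, "load failed: %s\n", TF_Message(status));
  } else {
    // The plugin's devices and kernels are now registered globally.
    TF_DeletePluggableDeviceLibraryHandle(lib);  // Frees the handle only.
  }
  TF_DeleteStatus(status);
  return 0;
}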
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

@@ -234,5 +235,22 @@ TEST_F(ShapeInferenceTest, InfersShapesFromInputTensors) {
  TF_DeleteTensor(tensor_1X6);
}

TEST(CAPI_EXPERIMENTAL, LibraryPluggableDeviceLoadFunctions) {
#if !defined(TENSORFLOW_NO_SHARED_OBJECTS)
  // Load the library.
  TF_Status* status = TF_NewStatus();
  string lib_path =
      tensorflow::GetDataDependencyFilepath(tensorflow::io::JoinPath(
          "tensorflow", "c", "experimental", "stream_executor", "test",
          "test_pluggable_device.so"));
  TF_Library* lib = TF_LoadPluggableDeviceLibrary(lib_path.c_str(), status);
  TF_Code code = TF_GetCode(status);
  string status_msg(TF_Message(status));
  TF_DeleteStatus(status);
  ASSERT_EQ(TF_OK, code) << status_msg;
  TF_DeletePluggableDeviceLibraryHandle(lib);
#endif  // !defined(TENSORFLOW_NO_SHARED_OBJECTS)
}

}  // namespace
}  // namespace tensorflow
@@ -213,7 +213,11 @@ TF_Function* TF_GraphToFunctionWithControlOutputs(
    TF_DeleteFunction(tf_function);
    return nullptr;
  }
  tf_function->graph_with_debug_info = &fn_body->graph;

  for (const Node* n : fn_body->graph.nodes()) {
    tf_function->stack_traces[n->name()] = n->GetStackTrace();
  }

  return tf_function;
}
@@ -157,9 +157,7 @@ struct TF_DeviceList {

struct TF_Function {
  tensorflow::FunctionDef fdef;

  // Graph with nodes with debug stack traces.
  const tensorflow::Graph* graph_with_debug_info = nullptr;
  tensorflow::StackTracesMap stack_traces;
};

struct TF_ApiDefMap {
@@ -749,8 +749,8 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx,

void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
                            TF_Status* status) {
  status->status = tensorflow::unwrap(ctx)->AddFunctionDefWithDebugInfo(
      function->fdef, function->graph_with_debug_info);
  status->status = tensorflow::unwrap(ctx)->AddFunctionDefWithStackTraces(
      function->fdef, function->stack_traces);
}

void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
@@ -111,11 +111,11 @@ class ImmediateExecutionContext : public AbstractContext {
  // already exists.
  virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;

  // Same as `AddFunctionDef`, and additionally saves a pointer to the Graph
  // which has nodes containing stack traces for the nodes in `fdef`. Assumes
  // `graph` is alive while the function is alive.
  virtual Status AddFunctionDefWithDebugInfo(const FunctionDef& fdef,
                                             const Graph* graph) = 0;
  // Same as `AddFunctionDef`, but additionally saves the `stack_traces` under
  // the key of the function definition name (to be retrieved during function
  // instantiation).
  virtual Status AddFunctionDefWithStackTraces(
      const FunctionDef& fdef, const StackTracesMap& stack_traces) = 0;

  // Finds and returns an added function by its name.
  virtual const FunctionDef* FindFunctionDef(const string& name) const = 0;
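The interface change above swaps a borrowed Graph pointer, which came with a
lifetime assumption, for an owned map of stack traces keyed by function name.
The following standalone sketch (not TensorFlow code; the types are simplified
stand-ins) illustrates that ownership difference.

#include <map>
#include <string>
#include <utility>
#include <vector>

// Stand-in for tensorflow::StackTracesMap: node name -> captured frames.
using FakeStackTraces = std::map<std::string, std::vector<std::string>>;

class FunctionRegistry {
 public:
  // Old style: the caller had to keep a Graph* alive for as long as the
  // registered function existed, e.g. AddWithDebugInfo(name, const Graph*).

  // New style: the registry copies and owns the traces, keyed by name.
  void AddWithStackTraces(const std::string& name, FakeStackTraces traces) {
    traces_by_function_[name] = std::move(traces);
  }

  // Retrieved later, e.g. at function instantiation time.
  const FakeStackTraces* Find(const std::string& name) const {
    auto it = traces_by_function_.find(name);
    return it == traces_by_function_.end() ? nullptr : &it->second;
  }

 private:
  std::map<std::string, FakeStackTraces> traces_by_function_;
};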
tensorflow/c/experimental/stream_executor/test/BUILD (new file)
@@ -0,0 +1,17 @@
# Description:
#   test for stream_executor
load(
    "//tensorflow:tensorflow.bzl",
    "tf_cc_shared_object",
)

package(
    licenses = ["notice"],  # Apache 2.0
)

tf_cc_shared_object(
    name = "test_pluggable_device.so",
    srcs = ["test_pluggable_device.cc"],
    visibility = ["//tensorflow/c:__subpackages__"],
    deps = ["//tensorflow/c/experimental/stream_executor:stream_executor_hdrs"],
)
@@ -1,4 +1,4 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

@@ -13,16 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_BFLOAT16_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_BFLOAT16_H_
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"

#include "pybind11/pybind11.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

xla::StatusOr<pybind11::object> Bfloat16Dtype();

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_PYTHON_BFLOAT16_H_
void SE_InitPlugin(SE_PlatformRegistrationParams* const params,
                   TF_Status* const status) {
  params->platform->struct_size = SP_PLATFORM_STRUCT_SIZE;
  params->platform->name = "GPU";
  params->platform->type = "XGPU";
}
@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/stream_executor/stream.h"
#endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)

using tensorflow::errors::InvalidArgument;
// This file forms the basis of a stable ABI for third-party kernel
// implementations. It is crucial that changes to this file are made cautiously
// and with a focus on maintaining both source and binary compatibility.

@@ -87,9 +88,25 @@ void AddTypeConstraint(TF_KernelBuilder* kernel_builder, const char* attr_name,
  TF_SetStatus(status, TF_OK, "");
}
#undef CASE

}  // namespace
}  // namespace tensorflow

namespace {
const tensorflow::AttrValue* GetAttrValue(TF_OpKernelConstruction* ctx,
                                          const char* attr_name,
                                          TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
  const tensorflow::AttrValue* attr =
      ::tensorflow::AttrSlice(cc_ctx->def()).Find(attr_name);
  if (attr == nullptr) {
    status->status = InvalidArgument("Operation '", cc_ctx->def().name(),
                                     "' has no attr named '", attr_name, "'.");
  }
  return attr;
}
}  // namespace

void TF_KernelBuilder_TypeConstraint(TF_KernelBuilder* kernel_builder,
                                     const char* attr_name,
                                     const TF_DataType type,
@ -257,7 +274,81 @@ void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) {
|
||||
cc_ctx->CtxFailure(s);
|
||||
}
|
||||
|
||||
#define DEFINE_TF_GETATTR(func, c_type, cc_type) \
|
||||
void TF_OpKernelConstruction_GetAttrSize(TF_OpKernelConstruction* ctx,
|
||||
const char* attr_name,
|
||||
int32_t* list_size,
|
||||
int32_t* total_size,
|
||||
TF_Status* status) {
|
||||
const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status);
|
||||
if (!status->status.ok()) {
|
||||
*list_size = -1;
|
||||
*total_size = -1;
|
||||
return;
|
||||
}
|
||||
switch (attr->value_case()) {
|
||||
#define SINGLE_CASE(kK, attr_type, size_expr) \
|
||||
case tensorflow::AttrValue::kK: \
|
||||
*list_size = -1; \
|
||||
*total_size = size_expr; \
|
||||
break;
|
||||
|
||||
SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length());
|
||||
SINGLE_CASE(kI, TF_ATTR_INT, -1);
|
||||
SINGLE_CASE(kF, TF_ATTR_FLOAT, -1);
|
||||
SINGLE_CASE(kB, TF_ATTR_BOOL, -1);
|
||||
SINGLE_CASE(kType, TF_ATTR_TYPE, -1);
|
||||
SINGLE_CASE(kShape, TF_ATTR_SHAPE,
|
||||
attr->shape().unknown_rank() ? -1 : attr->shape().dim_size());
|
||||
SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1);
|
||||
#undef SINGLE_CASE
|
||||
|
||||
case tensorflow::AttrValue::kList:
|
||||
*list_size = 0;
|
||||
*total_size = -1;
|
||||
#define LIST_CASE(field, attr_type, ...) \
|
||||
if (attr->list().field##_size() > 0) { \
|
||||
*list_size = attr->list().field##_size(); \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
}
|
||||
|
||||
LIST_CASE(
|
||||
s, TF_ATTR_STRING, *total_size = 0;
|
||||
for (int i = 0; i < attr->list().s_size();
|
||||
++i) { *total_size += attr->list().s(i).size(); });
|
||||
LIST_CASE(i, TF_ATTR_INT);
|
||||
LIST_CASE(f, TF_ATTR_FLOAT);
|
||||
LIST_CASE(b, TF_ATTR_BOOL);
|
||||
LIST_CASE(type, TF_ATTR_TYPE);
|
||||
LIST_CASE(
|
||||
shape, TF_ATTR_SHAPE, *total_size = 0;
|
||||
for (int i = 0; i < attr->list().shape_size(); ++i) {
|
||||
const auto& s = attr->list().shape(i);
|
||||
*total_size += s.unknown_rank() ? 0 : s.dim_size();
|
||||
});
|
||||
LIST_CASE(tensor, TF_ATTR_TENSOR);
|
||||
LIST_CASE(tensor, TF_ATTR_FUNC);
|
||||
#undef LIST_CASE
|
||||
break;
|
||||
|
||||
case tensorflow::AttrValue::kPlaceholder:
|
||||
*list_size = -1;
|
||||
*total_size = -1;
|
||||
break;
|
||||
|
||||
case tensorflow::AttrValue::kFunc:
|
||||
*list_size = -1;
|
||||
*total_size = -1;
|
||||
break;
|
||||
|
||||
case tensorflow::AttrValue::VALUE_NOT_SET:
|
||||
status->status =
|
||||
InvalidArgument("Attribute '", attr_name, "' has no value set");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
#define DEFINE_TF_GETATTR(func, c_type, cc_type, attr_type, list_field) \
|
||||
void TF_OpKernelConstruction_GetAttr##func(TF_OpKernelConstruction* ctx, \
|
||||
const char* attr_name, \
|
||||
c_type* val, TF_Status* status) { \
|
||||
@ -269,10 +360,84 @@ void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) {
|
||||
if (s.ok()) { \
|
||||
*val = static_cast<c_type>(v); \
|
||||
} \
|
||||
} \
|
||||
void TF_OpKernelConstruction_GetAttr##func##List( \
|
||||
TF_OpKernelConstruction* ctx, const char* attr_name, c_type* vals, \
|
||||
int max_vals, TF_Status* status) { \
|
||||
TF_SetStatus(status, TF_OK, ""); \
|
||||
const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status); \
|
||||
if (!status->status.ok()) return; \
|
||||
if (attr->value_case() != tensorflow::AttrValue::kList) { \
|
||||
status->status = \
|
||||
InvalidArgument("Value for '", attr_name, "' is not a list."); \
|
||||
return; \
|
||||
} \
|
||||
status->status = \
|
||||
tensorflow::AttrValueHasType(*attr, "list(" attr_type ")"); \
|
||||
if (!status->status.ok()) return; \
|
||||
const auto len = std::min(max_vals, attr->list().list_field##_size()); \
|
||||
for (int i = 0; i < len; ++i) { \
|
||||
vals[i] = static_cast<c_type>(attr->list().list_field(i)); \
|
||||
} \
|
||||
}
|
||||
|
||||
DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType)
|
||||
DEFINE_TF_GETATTR(Int32, tensorflow::int32, int32_t)
|
||||
DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType, "type", type)
|
||||
DEFINE_TF_GETATTR(Int32, int32_t, tensorflow::int32, "int", i)
|
||||
DEFINE_TF_GETATTR(Int64, int64_t, tensorflow::int64, "int", i)
|
||||
DEFINE_TF_GETATTR(Float, float, float, "float", f)
|
||||
DEFINE_TF_GETATTR(Bool, TF_Bool, bool, "bool", b)
|
||||
|
||||
void TF_OpKernelConstruction_GetAttrString(TF_OpKernelConstruction* ctx,
|
||||
const char* attr_name, char* value,
|
||||
size_t max_length,
|
||||
TF_Status* status) {
|
||||
std::string v;
|
||||
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
|
||||
::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);
|
||||
::tensorflow::Set_TF_Status_from_Status(status, s);
|
||||
|
||||
if (!status->status.ok()) return;
|
||||
|
||||
if (max_length <= 0) {
|
||||
return;
|
||||
}
|
||||
std::memcpy(value, v.data(), std::min<size_t>(v.length(), max_length));
|
||||
}
|
||||
|
||||
void TF_OpKernelConstruction_GetAttrStringList(TF_OpKernelConstruction* ctx,
|
||||
const char* attr_name,
|
||||
char** values, size_t* lengths,
|
||||
int max_values, void* storage,
|
||||
size_t storage_size,
|
||||
TF_Status* status) {
|
||||
std::vector<std::string> v;
|
||||
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
|
||||
::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);
|
||||
::tensorflow::Set_TF_Status_from_Status(status, s);
|
||||
|
||||
if (!status->status.ok()) return;
|
||||
|
||||
const auto len = std::min(max_values, static_cast<int>(v.size()));
|
||||
char* p = static_cast<char*>(storage);
|
||||
for (int i = 0; i < len; ++i) {
|
||||
const std::string& s = v[i];
|
||||
values[i] = p;
|
||||
lengths[i] = s.size();
|
||||
if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
|
||||
status->status = InvalidArgument(
|
||||
"Not enough storage to hold the requested list of strings");
|
||||
return;
|
||||
}
|
||||
memcpy(values[i], s.data(), s.size());
|
||||
p += s.size();
|
||||
}
|
||||
}
|
||||
|
||||
bool TF_OpKernelConstruction_HasAttr(TF_OpKernelConstruction* ctx,
|
||||
const char* attr_name, TF_Status* status) {
|
||||
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
|
||||
return cc_ctx->HasAttr(attr_name);
|
||||
}
|
||||
|
||||
TF_StringView TF_OpKernelConstruction_GetName(TF_OpKernelConstruction* ctx) {
|
||||
auto* cc_ctx = reinterpret_cast<tensorflow::OpKernelConstruction*>(ctx);
|
||||
|
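For readability, here is approximately what the list half of the
DEFINE_TF_GETATTR macro above generates when instantiated as
DEFINE_TF_GETATTR(Float, float, float, "float", f). The expansion is written
out by hand and is not itself part of the diff.

void TF_OpKernelConstruction_GetAttrFloatList(TF_OpKernelConstruction* ctx,
                                              const char* attr_name,
                                              float* vals, int max_vals,
                                              TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status);
  if (!status->status.ok()) return;
  if (attr->value_case() != tensorflow::AttrValue::kList) {
    status->status =
        InvalidArgument("Value for '", attr_name, "' is not a list.");
    return;
  }
  // "list(float)" comes from pasting the attr_type macro argument.
  status->status = tensorflow::AttrValueHasType(*attr, "list(float)");
  if (!status->status.ok()) return;
  // list_field expands to the AttrValue list accessor, here `f`.
  const auto len = std::min(max_vals, attr->list().f_size());
  for (int i = 0; i < len; ++i) {
    vals[i] = static_cast<float>(attr->list().f(i));
  }
}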
@ -184,6 +184,24 @@ TF_CAPI_EXPORT extern TF_DataType TF_ExpectedOutputDataType(
|
||||
// Returns the step ID of the given context.
|
||||
TF_CAPI_EXPORT extern int64_t TF_StepId(TF_OpKernelContext* ctx);
|
||||
|
||||
// Get the list_size and total_size of the attribute `attr_name` of the
// kernel construction `ctx`.
// list_size - the length of the list.
// total_size - total size of the list.
//   (1) If attr_type == TF_ATTR_STRING, then total_size is the cumulative
//       byte size of all the strings in the list.
//   (2) If attr_type == TF_ATTR_SHAPE and the attribute is a single shape,
//       then total_size is the number of dimensions of the shape, or -1 if
//       its rank is unknown.
//   (3) If attr_type == TF_ATTR_SHAPE and the attribute is a list of shapes,
//       then total_size is the cumulative number of dimensions of all shapes
//       in the list.
//   (4) Otherwise, total_size is undefined.
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrSize(
    TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* list_size,
    int32_t* total_size, TF_Status* status);
|
||||
|
||||
// Interprets the named kernel construction attribute as a TF_DataType and
|
||||
// places it into *val. *status is set to TF_OK.
|
||||
//
|
||||
@ -202,6 +220,112 @@ TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt32(
|
||||
TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* val,
|
||||
TF_Status* status);
|
||||
|
||||
// Interprets the named kernel construction attribute as int64_t and
|
||||
// places it into *val. *status is set to TF_OK.
|
||||
//
|
||||
// If the attribute could not be found or could not be interpreted as
|
||||
// int64, *status is populated with an error.
|
||||
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt64(
|
||||
TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* val,
|
||||
TF_Status* status);
|
||||
|
||||
// Interprets the named kernel construction attribute as float and
|
||||
// places it into *val. *status is set to TF_OK.
|
||||
//
|
||||
// If the attribute could not be found or could not be interpreted as
|
||||
// float, *status is populated with an error.
|
||||
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrFloat(
|
||||
TF_OpKernelConstruction* ctx, const char* attr_name, float* val,
|
||||
TF_Status* status);
|
||||
|
||||
// Interprets the named kernel construction attribute as bool and
|
||||
// places it into *val. *status is set to TF_OK.
|
||||
//
|
||||
// If the attribute could not be found or could not be interpreted as
|
||||
// bool, *status is populated with an error.
|
||||
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrBool(
|
||||
TF_OpKernelConstruction* ctx, const char* attr_name, TF_Bool* val,
|
||||
TF_Status* status);
|
||||
|
||||
// Interprets the named kernel construction attribute as string and
|
||||
// places it into *val. `val` must
|
||||
// point to an array of length at least `max_length` (ideally set to
|
||||
// total_size from TF_OpKernelConstruction_GetAttrSize(ctx,
|
||||
// attr_name, list_size, total_size)). *status is set to TF_OK.
|
||||
//
|
||||
// If the attribute could not be found or could not be interpreted as
|
||||
// string, *status is populated with an error.
|
||||
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrString(
|
||||
TF_OpKernelConstruction* ctx, const char* attr_name, char* val,
|
||||
size_t max_length, TF_Status* status);
|
||||
|
||||
// Interprets the named kernel construction attribute as a TF_DataType array and
|
||||
// places it into *vals. *status is set to TF_OK.
|
||||
// `vals` must point to an array of length at least `max_values` (ideally set
|
||||
// to list_size from
|
||||
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
|
||||
// total_size)).
|
||||
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrTypeList(
|
||||
TF_OpKernelConstruction* ctx, const char* attr_name, TF_DataType* vals,
|
||||
int max_vals, TF_Status* status);
|
||||
|
||||
// Interprets the named kernel construction attribute as int32_t array and
|
||||
// places it into *vals. *status is set to TF_OK.
|
||||
// `vals` must point to an array of length at least `max_values` (ideally set
|
||||
// to list_size from
|
||||
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
|
||||
// total_size)).
|
||||
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt32List(
|
||||
TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* vals,
|
||||
int max_vals, TF_Status* status);
|
||||
|
||||
// Interprets the named kernel construction attribute as int64_t array and
|
||||
// places it into *vals. *status is set to TF_OK.
|
||||
// `vals` must point to an array of length at least `max_values` (ideally set
|
||||
// to list_size from
|
||||
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
|
||||
// total_size)).
|
||||
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt64List(
|
||||
TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* vals,
|
||||
int max_vals, TF_Status* status);
|
||||
|
||||
// Interprets the named kernel construction attribute as float array and
|
||||
// places it into *vals. *status is set to TF_OK.
|
||||
// `vals` must point to an array of length at least `max_values` (ideally set
|
||||
// to list_size from
|
||||
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
|
||||
// total_size)).
|
||||
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrFloatList(
|
||||
TF_OpKernelConstruction* ctx, const char* attr_name, float* vals,
|
||||
int max_vals, TF_Status* status);
|
||||
|
||||
// Interprets the named kernel construction attribute as bool array and
|
||||
// places it into *vals. *status is set to TF_OK.
|
||||
// `vals` must point to an array of length at least `max_values` (ideally set
|
||||
// to list_size from
|
||||
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
|
||||
// total_size)).
|
||||
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrBoolList(
|
||||
TF_OpKernelConstruction* ctx, const char* attr_name, TF_Bool* vals,
|
||||
int max_vals, TF_Status* status);
|
||||
|
||||
// Interprets the named kernel construction attribute as string array and fills
|
||||
// in `vals` and `lengths`, each of which must point to an array of length at
|
||||
// least `max_values`. *status is set to TF_OK. The elements of values will
|
||||
// point to addresses in `storage` which must be at least `storage_size` bytes
|
||||
// in length. Ideally, max_values would be set to list_size and `storage` would
|
||||
// be at least total_size, obtained from
|
||||
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
|
||||
// total_size).
|
||||
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrStringList(
|
||||
TF_OpKernelConstruction* ctx, const char* attr_name, char** vals,
|
||||
size_t* lengths, int max_values, void* storage, size_t storage_size,
|
||||
TF_Status* status);
|
||||
|
||||
// Returns true if the kernel construction has an attribute named `attr_name`.
|
||||
TF_CAPI_EXPORT extern bool TF_OpKernelConstruction_HasAttr(
|
||||
TF_OpKernelConstruction* ctx, const char* attr_name, TF_Status* status);
|
||||
|
||||
// Returns the unique operation name for this OpKernel.
|
||||
TF_CAPI_EXPORT extern TF_StringView TF_OpKernelConstruction_GetName(
|
||||
TF_OpKernelConstruction* ctx);
|
||||
|
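The comments above describe a two-step pattern: first query list_size and
total_size, then allocate buffers and fetch the values. The sketch below shows
that pattern from a hypothetical kernel create function; the kernel name,
the "labels" attribute, and MyKernelCreate are illustrative only.

#include <vector>

#include "tensorflow/c/kernels.h"
#include "tensorflow/c/tf_status.h"

static void* MyKernelCreate(TF_OpKernelConstruction* ctx) {
  TF_Status* status = TF_NewStatus();
  int32_t list_size = 0, total_size = 0;
  TF_OpKernelConstruction_GetAttrSize(ctx, "labels", &list_size, &total_size,
                                      status);
  if (TF_GetCode(status) == TF_OK) {
    std::vector<char*> values(list_size);
    std::vector<size_t> lengths(list_size);
    std::vector<char> storage(total_size);  // Byte pool the strings point into.
    TF_OpKernelConstruction_GetAttrStringList(ctx, "labels", values.data(),
                                              lengths.data(), list_size,
                                              storage.data(), storage.size(),
                                              status);
    // values[i]/lengths[i] now describe the i-th string inside `storage`.
  }
  TF_DeleteStatus(status);
  return nullptr;  // A real kernel would return its state object here.
}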
@ -161,6 +161,337 @@ TEST(TestKernel, TestRegisterKernelBuilder) {
|
||||
ASSERT_TRUE(delete_called);
|
||||
}
|
||||
|
||||
// REGISTER_OP for TF_OpKernelConstruction_GetAttr* test cases.
|
||||
// Registers two ops, each with a single attribute called 'Attr'.
|
||||
// The attribute in one op will have a type 'type', the other
|
||||
// will have list(type).
|
||||
#define ATTR_TEST_REGISTER_OP(name, type) \
|
||||
REGISTER_OP("TestKernelAttr" #name) \
|
||||
.Attr("Attr: " #type) \
|
||||
.SetShapeFn(tensorflow::shape_inference::UnknownShape); \
|
||||
REGISTER_OP("TestKernelAttr" #name "List") \
|
||||
.Attr("Attr: list(" #type ")") \
|
||||
.SetShapeFn(tensorflow::shape_inference::UnknownShape)
|
||||
ATTR_TEST_REGISTER_OP(String, string);
|
||||
ATTR_TEST_REGISTER_OP(Int, int);
|
||||
ATTR_TEST_REGISTER_OP(Float, float);
|
||||
ATTR_TEST_REGISTER_OP(Bool, bool);
|
||||
ATTR_TEST_REGISTER_OP(Type, type);
|
||||
#undef ATTR_TEST_REGISTER_OP
|
||||
|
||||
// Helper macros for the TF_OpKernelConstruction_GetAttr* tests.
|
||||
#define EXPECT_TF_SIZE(attr_name, expected_list_size, expected_total_size) \
|
||||
do { \
|
||||
int32_t list_size, total_size; \
|
||||
TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, &list_size, \
|
||||
&total_size, status); \
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); \
|
||||
EXPECT_EQ(expected_list_size, list_size); \
|
||||
EXPECT_EQ(expected_total_size, total_size); \
|
||||
} while (0)
|
||||
|
||||
typedef void* (*MyCreateFuncWithAttr)(TF_OpKernelConstruction*);
|
||||
class TestKernelAttr : public ::testing::Test {
|
||||
public:
|
||||
TestKernelAttr() {}
|
||||
~TestKernelAttr() {}
|
||||
|
||||
std::unique_ptr<OpKernel> GetFakeKernelWithAttr(const char* op_name,
|
||||
AttrValue v, Status* status) {
|
||||
NodeDef def;
|
||||
def.set_op(op_name);
|
||||
def.set_name("FakeNode");
|
||||
def.set_device("FakeDevice");
|
||||
(*def.mutable_attr())["Attr"] = v;
|
||||
return CreateOpKernel(DeviceType("FakeDevice"), nullptr, nullptr, def, 1,
|
||||
status);
|
||||
}
|
||||
|
||||
void SetAttr(MyCreateFuncWithAttr MyCreateFuncAttr, const char* op_name,
|
||||
AttrValue& v) {
|
||||
TF_KernelBuilder* builder = TF_NewKernelBuilder(
|
||||
op_name, "FakeDevice", MyCreateFuncAttr, &MyComputeFunc, &MyDeleteFunc);
|
||||
{
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_RegisterKernelBuilder("FakeNode", builder, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status));
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
Status status;
|
||||
std::unique_ptr<OpKernel> kernel =
|
||||
GetFakeKernelWithAttr(op_name, v, &status);
|
||||
TF_EXPECT_OK(status);
|
||||
ASSERT_NE(nullptr, kernel.get());
|
||||
kernel->Compute(nullptr);
|
||||
|
||||
ASSERT_TRUE(delete_called);
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(TestKernelAttr, String) {
|
||||
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
|
||||
struct MyCustomKernel* s = new struct MyCustomKernel;
|
||||
s->created = true;
|
||||
s->compute_called = false;
|
||||
|
||||
std::unique_ptr<char[]> val(new char[5]);
|
||||
TF_Status* status = TF_NewStatus();
|
||||
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
|
||||
/*expected_total_size*/ 5);
|
||||
TF_OpKernelConstruction_GetAttrString(ctx, "Attr", val.get(),
|
||||
/*max_length*/ 5, status);
|
||||
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_EQ("bunny", string(static_cast<const char*>(val.get()), 5));
|
||||
TF_DeleteStatus(status);
|
||||
return static_cast<void*>(s);
|
||||
};
|
||||
|
||||
AttrValue v;
|
||||
v.set_s("bunny");
|
||||
SetAttr(my_create_func, "TestKernelAttrString", v);
|
||||
}
|
||||
|
||||
TEST_F(TestKernelAttr, StringList) {
|
||||
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
|
||||
struct MyCustomKernel* s = new struct MyCustomKernel;
|
||||
s->created = true;
|
||||
s->compute_called = false;
|
||||
|
||||
std::vector<string> list = {"bugs", "bunny", "duck"};
|
||||
int list_total_size = 0;
|
||||
for (const auto& s : list) {
|
||||
list_total_size += s.size();
|
||||
}
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
std::unique_ptr<char*[]> values(new char*[list.size()]);
|
||||
std::unique_ptr<size_t[]> lens(new size_t[list.size()]);
|
||||
std::unique_ptr<char[]> storage(new char[list_total_size]);
|
||||
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list.size(),
|
||||
/*expected_total_size*/ list_total_size);
|
||||
TF_OpKernelConstruction_GetAttrStringList(
|
||||
ctx, "Attr", values.get(), lens.get(), list.size(), storage.get(),
|
||||
list_total_size, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
for (size_t i = 0; i < list.size(); ++i) {
|
||||
EXPECT_EQ(list[i].size(), lens[i]) << i;
|
||||
EXPECT_EQ(list[i], string(static_cast<const char*>(values[i]), lens[i]))
|
||||
<< i;
|
||||
}
|
||||
TF_DeleteStatus(status);
|
||||
return static_cast<void*>(s);
|
||||
};
|
||||
|
||||
AttrValue v;
|
||||
auto attr_in = gtl::ArraySlice<StringPiece>({"bugs", "bunny", "duck"});
|
||||
SetAttrValue(attr_in, &v);
|
||||
SetAttr(my_create_func, "TestKernelAttrStringList", v);
|
||||
}
|
||||
|
||||
TEST_F(TestKernelAttr, Int) {
|
||||
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
|
||||
struct MyCustomKernel* s = new struct MyCustomKernel;
|
||||
s->created = true;
|
||||
s->compute_called = false;
|
||||
|
||||
int64_t val;
|
||||
TF_Status* status = TF_NewStatus();
|
||||
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
|
||||
/*expected_total_size*/ -1);
|
||||
TF_OpKernelConstruction_GetAttrInt64(ctx, "Attr", &val, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_EQ(1234, val);
|
||||
TF_DeleteStatus(status);
|
||||
return static_cast<void*>(s);
|
||||
};
|
||||
|
||||
AttrValue v;
|
||||
v.set_i(1234);
|
||||
SetAttr(my_create_func, "TestKernelAttrInt", v);
|
||||
}
|
||||
|
||||
TEST_F(TestKernelAttr, IntList) {
|
||||
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
|
||||
struct MyCustomKernel* s = new struct MyCustomKernel;
|
||||
s->created = true;
|
||||
s->compute_called = false;
|
||||
|
||||
const int64_t list[] = {1, 2, 3, 4};
|
||||
const size_t list_size = TF_ARRAYSIZE(list);
|
||||
int64_t values[list_size];
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
|
||||
/*expected_total_size*/ -1);
|
||||
TF_OpKernelConstruction_GetAttrInt64List(ctx, "Attr", values, list_size,
|
||||
status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_TRUE(
|
||||
std::equal(std::begin(list), std::end(list), std::begin(values)));
|
||||
TF_DeleteStatus(status);
|
||||
return static_cast<void*>(s);
|
||||
};
|
||||
|
||||
AttrValue v;
|
||||
auto attr_in = gtl::ArraySlice<int64>({1, 2, 3, 4});
|
||||
SetAttrValue(attr_in, &v);
|
||||
SetAttr(my_create_func, "TestKernelAttrIntList", v);
|
||||
}
|
||||
|
||||
TEST_F(TestKernelAttr, Float) {
|
||||
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
|
||||
struct MyCustomKernel* s = new struct MyCustomKernel;
|
||||
s->created = true;
|
||||
s->compute_called = false;
|
||||
|
||||
float val;
|
||||
TF_Status* status = TF_NewStatus();
|
||||
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
|
||||
/*expected_total_size*/ -1);
|
||||
TF_OpKernelConstruction_GetAttrFloat(ctx, "Attr", &val, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_FLOAT_EQ(2.718, val);
|
||||
TF_DeleteStatus(status);
|
||||
return static_cast<void*>(s);
|
||||
};
|
||||
|
||||
AttrValue v;
|
||||
v.set_f(2.718);
|
||||
SetAttr(my_create_func, "TestKernelAttrFloat", v);
|
||||
}
|
||||
|
||||
TEST_F(TestKernelAttr, FloatList) {
|
||||
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
|
||||
struct MyCustomKernel* s = new struct MyCustomKernel;
|
||||
s->created = true;
|
||||
s->compute_called = false;
|
||||
|
||||
const float list[] = {1.414, 2.718, 3.1415};
|
||||
const size_t list_size = TF_ARRAYSIZE(list);
|
||||
float values[list_size];
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
|
||||
/*expected_total_size*/ -1);
|
||||
TF_OpKernelConstruction_GetAttrFloatList(ctx, "Attr", values, list_size,
|
||||
status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_TRUE(
|
||||
std::equal(std::begin(list), std::end(list), std::begin(values)));
|
||||
TF_DeleteStatus(status);
|
||||
return static_cast<void*>(s);
|
||||
};
|
||||
|
||||
AttrValue v;
|
||||
auto attr_in = gtl::ArraySlice<float>({1.414, 2.718, 3.1415});
|
||||
SetAttrValue(attr_in, &v);
|
||||
SetAttr(my_create_func, "TestKernelAttrFloatList", v);
|
||||
}
|
||||
|
||||
TEST_F(TestKernelAttr, Bool) {
|
||||
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
|
||||
struct MyCustomKernel* s = new struct MyCustomKernel;
|
||||
s->created = true;
|
||||
s->compute_called = false;
|
||||
|
||||
unsigned char val;
|
||||
TF_Status* status = TF_NewStatus();
|
||||
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
|
||||
/*expected_total_size*/ -1);
|
||||
TF_OpKernelConstruction_GetAttrBool(ctx, "Attr", &val, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_EQ(1, val);
|
||||
TF_DeleteStatus(status);
|
||||
return static_cast<void*>(s);
|
||||
};
|
||||
|
||||
AttrValue v;
|
||||
v.set_b(1);
|
||||
SetAttr(my_create_func, "TestKernelAttrBool", v);
|
||||
}
|
||||
|
||||
TEST_F(TestKernelAttr, BoolList) {
|
||||
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
|
||||
struct MyCustomKernel* s = new struct MyCustomKernel;
|
||||
s->created = true;
|
||||
s->compute_called = false;
|
||||
|
||||
const unsigned char list[] = {1, 0, 1, 0};
|
||||
const size_t list_size = TF_ARRAYSIZE(list);
|
||||
unsigned char values[list_size];
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
|
||||
/*expected_total_size*/ -1);
|
||||
TF_OpKernelConstruction_GetAttrBoolList(ctx, "Attr", values, list_size,
|
||||
status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_TRUE(
|
||||
std::equal(std::begin(list), std::end(list), std::begin(values)));
|
||||
TF_DeleteStatus(status);
|
||||
return static_cast<void*>(s);
|
||||
};
|
||||
|
||||
AttrValue v;
|
||||
auto attr_in = gtl::ArraySlice<bool>({1, 0, 1, 0});
|
||||
SetAttrValue(attr_in, &v);
|
||||
SetAttr(my_create_func, "TestKernelAttrBoolList", v);
|
||||
}
|
||||
|
||||
TEST_F(TestKernelAttr, Type) {
|
||||
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
|
||||
struct MyCustomKernel* s = new struct MyCustomKernel;
|
||||
s->created = true;
|
||||
s->compute_called = false;
|
||||
|
||||
TF_DataType val;
|
||||
TF_Status* status = TF_NewStatus();
|
||||
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
|
||||
/*expected_total_size*/ -1);
|
||||
TF_OpKernelConstruction_GetAttrType(ctx, "Attr", &val, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_EQ(TF_FLOAT, val);
|
||||
TF_DeleteStatus(status);
|
||||
return static_cast<void*>(s);
|
||||
};
|
||||
|
||||
AttrValue v;
|
||||
v.set_type(DT_FLOAT);
|
||||
SetAttr(my_create_func, "TestKernelAttrType", v);
|
||||
}
|
||||
|
||||
TEST_F(TestKernelAttr, TypeList) {
|
||||
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
|
||||
struct MyCustomKernel* s = new struct MyCustomKernel;
|
||||
s->created = true;
|
||||
s->compute_called = false;
|
||||
|
||||
const TF_DataType list[] = {TF_FLOAT, TF_DOUBLE, TF_HALF, TF_COMPLEX128};
|
||||
const size_t list_size = TF_ARRAYSIZE(list);
|
||||
TF_DataType values[list_size];
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
|
||||
/*expected_total_size*/ -1);
|
||||
TF_OpKernelConstruction_GetAttrTypeList(ctx, "Attr", values, list_size,
|
||||
status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_TRUE(
|
||||
std::equal(std::begin(list), std::end(list), std::begin(values)));
|
||||
TF_DeleteStatus(status);
|
||||
return static_cast<void*>(s);
|
||||
};
|
||||
|
||||
AttrValue v;
|
||||
auto attr_in =
|
||||
gtl::ArraySlice<DataType>({DT_FLOAT, DT_DOUBLE, DT_HALF, DT_COMPLEX128});
|
||||
SetAttrValue(attr_in, &v);
|
||||
SetAttr(my_create_func, "TestKernelAttrTypeList", v);
|
||||
}
|
||||
#undef EXPECT_TF_SIZE
|
||||
|
||||
class DummyDevice : public DeviceBase {
|
||||
public:
|
||||
explicit DummyDevice(Env* env) : DeviceBase(env) {}
|
||||
|
@ -65,6 +65,7 @@ MAP_HLO_TO_LHLO(MinOp);
|
||||
MAP_HLO_TO_LHLO(MulOp);
|
||||
MAP_HLO_TO_LHLO(NegOp);
|
||||
MAP_HLO_TO_LHLO(NotOp);
|
||||
MAP_HLO_TO_LHLO(OrOp);
|
||||
MAP_HLO_TO_LHLO(RealOp);
|
||||
MAP_HLO_TO_LHLO(ReduceOp);
|
||||
MAP_HLO_TO_LHLO(ReshapeOp);
|
||||
@ -81,6 +82,7 @@ MAP_HLO_TO_LHLO(SqrtOp);
|
||||
MAP_HLO_TO_LHLO(SubOp);
|
||||
MAP_HLO_TO_LHLO(TanhOp);
|
||||
MAP_HLO_TO_LHLO(TransposeOp);
|
||||
MAP_HLO_TO_LHLO(XorOp);
|
||||
|
||||
#undef MAP_HLO_TO_LHLO
|
||||
|
||||
|
@ -481,6 +481,15 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NotOp>(Location loc,
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::OrOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::OrOp>{}(
|
||||
loc, result_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
@ -580,6 +589,15 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::TanhOp>(Location loc,
|
||||
loc, result_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::XorOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::XOrOp>{}(
|
||||
loc, result_types, args, b);
|
||||
}
|
||||
|
||||
} // namespace impl
|
||||
|
||||
struct HloOpToStdScalarOp {
|
||||
|
@ -629,6 +629,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
|
||||
HloToLhloOpConverter<mhlo::MulOp>,
|
||||
HloToLhloOpConverter<mhlo::NegOp>,
|
||||
HloToLhloOpConverter<mhlo::NotOp>,
|
||||
HloToLhloOpConverter<mhlo::OrOp>,
|
||||
HloToLhloOpConverter<mhlo::RealOp>,
|
||||
HloToLhloOpConverter<mhlo::RemOp>,
|
||||
HloToLhloOpConverter<mhlo::RsqrtOp>,
|
||||
@ -644,6 +645,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
|
||||
HloToLhloOpConverter<mhlo::SubOp>,
|
||||
HloToLhloOpConverter<mhlo::TanhOp>,
|
||||
HloToLhloOpConverter<mhlo::TransposeOp>,
|
||||
HloToLhloOpConverter<mhlo::XorOp>,
|
||||
HloToLhloReduceOpConverter,
|
||||
HloToLhloReturnOpConverter,
|
||||
HloToLhloTensorLoadOpConverter,
|
||||
|
@ -927,12 +927,14 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||
PointwiseToLinalgConverter<lmhlo::ExpOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::FloorOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::ImagOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::IsFiniteOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::LogOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::MaxOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::MinOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::MulOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::NegOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::NotOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::OrOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::RealOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::RemOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::RsqrtOp>,
|
||||
@ -945,7 +947,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||
PointwiseToLinalgConverter<lmhlo::SqrtOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::SubOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::TanhOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::IsFiniteOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::XorOp>,
|
||||
ReduceConverter,
|
||||
ReshapeOpConverter<lmhlo::ReshapeOp>,
|
||||
ReverseConverter<lmhlo::ReverseOp>,
|
||||
@ -1042,12 +1044,14 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||
PointwiseToLinalgConverter<mhlo::ExpOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::FloorOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::ImagOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::LogOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::MaxOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::MinOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::MulOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::NegOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::NotOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::OrOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::RealOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::RemOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
|
||||
@ -1055,11 +1059,12 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||
PointwiseToLinalgConverter<mhlo::ShiftLeftOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::ShiftRightArithmeticOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::ShiftRightLogicalOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::SignOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::SinOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::SubOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::TanhOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::XorOp, false>,
|
||||
ReshapeOpConverter<mhlo::ReshapeOp, false>,
|
||||
ReverseConverter<mhlo::ReverseOp, false>,
|
||||
TransposeConverter<mhlo::TransposeOp, false>>(context);
|
||||
|
@@ -42,11 +42,11 @@ namespace {
      sep fn(SqrtOp) sep fn(TanhOp)

// TODO(herhut): Generate these out of op definitions.
#define MAP_XLA_OPERATION_CWISE_BINARY(fn, sep)                           \
  fn(AddOp) sep fn(Atan2Op) sep fn(ComplexOp) sep fn(DivOp) sep fn(MaxOp) \
      sep fn(MinOp) sep fn(MulOp) sep fn(PowOp) sep fn(RemOp)             \
          sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp)              \
              sep fn(ShiftRightLogicalOp) sep fn(SubOp)
#define MAP_XLA_OPERATION_CWISE_BINARY(fn, sep)                            \
  fn(AddOp) sep fn(AndOp) sep fn(Atan2Op) sep fn(ComplexOp) sep fn(DivOp)  \
      sep fn(MaxOp) sep fn(MinOp) sep fn(MulOp) sep fn(OrOp) sep fn(PowOp) \
          sep fn(RemOp) sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp) \
              sep fn(ShiftRightLogicalOp) sep fn(SubOp) sep fn(XorOp)

// TODO(herhut): Generate these out of op definitions.
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
@ -316,6 +316,20 @@ func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @and
|
||||
func @and(%operand0: memref<2x2xi32>, %operand1: memref<2x2xi32>,
|
||||
%result: memref<2x2xi32>) {
|
||||
%tensor_operand0 = tensor_load %operand0 : memref<2x2xi32>
|
||||
%tensor_operand1 = tensor_load %operand1 : memref<2x2xi32>
|
||||
%tensor_result = "mhlo.and"(%tensor_operand0, %tensor_operand1)
|
||||
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||
// CHECK: "lmhlo.and"(%{{.*}}, %{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @ceil
|
||||
func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
@ -389,6 +403,20 @@ func @not(%operand: memref<2x2xi32>, %result: memref<2x2xi32>) {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @or
|
||||
func @or(%operand0: memref<2x2xi32>, %operand1: memref<2x2xi32>,
|
||||
%result: memref<2x2xi32>) {
|
||||
%tensor_operand0 = tensor_load %operand0 : memref<2x2xi32>
|
||||
%tensor_operand1 = tensor_load %operand1 : memref<2x2xi32>
|
||||
%tensor_result = "mhlo.or"(%tensor_operand0, %tensor_operand1)
|
||||
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||
// CHECK: "lmhlo.or"(%{{.*}}, %{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @rsqrt
|
||||
func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
@ -480,7 +508,8 @@ func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @remainder
|
||||
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
|
||||
%result: memref<2x2xf32>) {
|
||||
%tensor_lhs = tensor_load %lhs : memref<2x2xf32>
|
||||
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.remainder"(%tensor_lhs, %tensor_rhs)
|
||||
@ -492,6 +521,20 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @xor
|
||||
func @xor(%operand0: memref<2x2xi32>, %operand1: memref<2x2xi32>,
|
||||
%result: memref<2x2xi32>) {
|
||||
%tensor_operand0 = tensor_load %operand0 : memref<2x2xi32>
|
||||
%tensor_operand1 = tensor_load %operand1 : memref<2x2xi32>
|
||||
%tensor_result = "mhlo.xor"(%tensor_operand0, %tensor_operand1)
|
||||
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||
// CHECK: "lmhlo.xor"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Dynamic shape binary element-wise operation.
|
||||
// CHECK-LABEL: func @add_dyn
|
||||
func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
|
||||
|
@ -194,6 +194,30 @@ func @integer_and(%lhs: tensor<2x2xi32>,
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @integer_or
|
||||
func @integer_or(%lhs: tensor<2x2xi32>,
|
||||
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: or
|
||||
%0 = "mhlo.or"(%lhs, %rhs) : (tensor<2x2xi32>,
|
||||
tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||
return %0 : tensor<2x2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @integer_xor
|
||||
func @integer_xor(%lhs: tensor<2x2xi32>,
|
||||
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: xor
|
||||
%0 = "mhlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>,
|
||||
tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||
return %0 : tensor<2x2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @float_cmp
|
||||
func @float_cmp(%lhs: tensor<2x2xf32>,
|
||||
%rhs: tensor<2x2xf32>) -> (tensor<2x2xi1>) {
|
||||
|
@@ -509,7 +509,7 @@ Operation* BuildVariableOp(const tflite::TensorT& tensor,
    return op.getOperation();
  }
  auto op = builder.create<tfl::ConstOp>(loc, value);
  if (!tensor.quantization->min.empty()) {
  if (tensor.quantization && !tensor.quantization->min.empty()) {
    if (auto stats_op =
            ConvertMinMaxToStatsOp(tensor, builder, op.getResult())) {
      return stats_op;
@@ -1977,6 +1977,7 @@ cc_library(
    hdrs = ["utils/bridge_logger.h"],
    deps = [
        ":dump_mlir_util",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
@ -5028,6 +5028,26 @@ Table initializer that takes two tensors for keys and values respectively.
|
||||
TF_DerivedOperandTypeAttr Tkey = TF_DerivedOperandTypeAttr<1>;
|
||||
}
|
||||
|
||||
def TF_InplaceAddOp : TF_Op<"InplaceAdd", [AllTypesMatch<["x", "y"]>, NoSideEffect]> {
|
||||
let summary = "Adds v into specified rows of x.";
|
||||
|
||||
let description = [{
|
||||
Computes y = x; y[i, :] += v; return y.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_Tensor:$x,
|
||||
TF_Int32Tensor:$i,
|
||||
TF_Tensor:$v
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_Tensor:$y
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_InplaceUpdateOp : TF_Op<"InplaceUpdate", [NoSideEffect]> {
|
||||
let summary = "Updates specified rows 'i' with values 'v'.";
|
||||
|
||||
@ -5374,6 +5394,37 @@ def TF_IteratorV2Op : TF_Op<"IteratorV2", []> {
|
||||
);
|
||||
}
|
||||
|
||||
def TF_KthOrderStatisticOp : TF_Op<"KthOrderStatistic", [NoSideEffect]> {
|
||||
let summary = "Computes the Kth order statistic of a data set. The current";
|
||||
|
||||
let description = [{
|
||||
implementation uses a binary search requiring exactly 32 passes over
|
||||
the input data. The running time is linear with respect to input
|
||||
size. The median-of-medians algorithm is probably faster, but is
|
||||
difficult to implement efficiently in XLA. The implementation imposes
|
||||
a total ordering on floats. The ordering is consistent with the usual
|
||||
partial order. Positive NaNs are greater than positive
|
||||
infinity. Negative NaNs are less than negative infinity. NaNs with
|
||||
distinct payloads are treated as distinct. Subnormal numbers are
|
||||
preserved (not flushed to zero). Positive infinity is greater than all
|
||||
numbers. Negative infinity is less than all numbers. Positive zero is
greater than negative zero. There are fewer than k values greater than
the kth order statistic. There are at least k values greater than or
|
||||
equal to the Kth order statistic. The semantics are not the same as
|
||||
top_k_unique.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_Float32Tensor:$input,
|
||||
|
||||
I64Attr:$k
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_Float32Tensor:$output
|
||||
);
|
||||
}
|
||||
|
||||
def TF_L2LossOp : TF_Op<"L2Loss", [NoSideEffect]> {
|
||||
let summary = "L2 Loss.";
|
||||
|
||||
@ -6505,6 +6556,27 @@ iterator in `iterator` to the first element of `dataset`.
|
||||
let results = (outs);
|
||||
}
|
||||
|
||||
def TF_MakeUniqueOp : TF_Op<"MakeUnique", [NoSideEffect]> {
|
||||
let summary = [{
|
||||
Make all elements in the non-Batch dimension unique, but \"close\" to
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
their initial value. Never returns a sub-normal number. Never returns
|
||||
zero. The sign of each input element is always identical to the sign
|
||||
of the corresponding output element. Behavior for infinite elements is
|
||||
undefined. Behavior for subnormal elements is undefined.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_Float32Tensor:$input
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_Float32Tensor:$output
|
||||
);
|
||||
}
|
||||
|
||||
def TF_MatMulOp : TF_Op<"MatMul", [NoSideEffect, TF_SameOperandsAndResultElementTypeResolveRef]> {
|
||||
let summary = [{
|
||||
Multiply the matrix "a" by the matrix "b".
|
||||
@ -15234,6 +15306,36 @@ array([[1, 2, 3, 1, 2, 3],
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TF_TopKUniqueOp : TF_Op<"TopKUnique", [NoSideEffect]> {
|
||||
let summary = "Returns the TopK unique values in the array in sorted order.";
|
||||
|
||||
let description = [{
|
||||
The running time is proportional to the product of K and the input
|
||||
size. Sorting the whole array is more efficient for sufficiently large
|
||||
values of K. The median-of-medians algorithm is probably faster, but
|
||||
difficult to implement efficiently in XLA. If there are fewer than K
|
||||
unique numbers (not NANs), the results are padded with negative
|
||||
infinity. NaNs are never returned. Subnormal numbers are flushed to
|
||||
zero. If an element appears at multiple indices, the highest index is
|
||||
returned. If a TopK element never appears in the input due to padding
|
||||
values, the indices are padded with negative one. If a padding value
|
||||
appears in the input and padding is needed, the highest index of the
|
||||
padding value will be returned. The semantics are not the same as
|
||||
kth_order_statistic.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_Float32Tensor:$input,
|
||||
|
||||
I64Attr:$k
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_Float32Tensor:$topk,
|
||||
TF_Int32Tensor:$topk_indices
|
||||
);
|
||||
}

def TF_TopKV2Op : TF_Op<"TopKV2", [NoSideEffect]> {
  let summary = [{
Finds values and indices of the `k` largest elements for the last dimension.
@ -15269,6 +15371,29 @@ If two elements are equal, the lower-index element appears first.
  let verifier = [{ return Verify(*this); }];
}

def TF_TopKWithUniqueOp : TF_Op<"TopKWithUnique", [NoSideEffect]> {
  let summary = "Returns the TopK values in the array in sorted order.";

  let description = [{
This is a combination of MakeUnique and TopKUnique. The returned top-K will
have its lower bits replaced by iota, thus it will be close to the original
value but not exactly the same. The running time is proportional to the product
of K and the input size. NaNs are never returned. Subnormal numbers are flushed
to zero.
  }];

  let arguments = (ins
    TF_Float32Tensor:$input,

    I64Attr:$k
  );

  let results = (outs
    TF_Float32Tensor:$topk,
    TF_Int32Tensor:$topk_indices
  );
}
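Following the description, a one-line composition of the two hypothetical sketches above (MakeUniqueRef and TopKUniqueRef are the invented reference helpers from this section, not real APIs):

// Reference-only: perturb the low mantissa bits so duplicates become distinct,
// then take the unique top-K of the perturbed row, as the description states.
std::pair<std::vector<float>, std::vector<int>> TopKWithUniqueRef(
    const std::vector<float>& input, int k) {
  return TopKUniqueRef(MakeUniqueRef(input), k);
}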

def TF_TransposeOp : TF_Op<"Transpose", [NoSideEffect]> {
  let summary = "Shuffle dimensions of x according to a permutation.";

@ -532,7 +532,11 @@ OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
  auto ranked_type = type.dyn_cast<RankedTensorType>();
  if (!ranked_type) return {};

  auto output_type = getType().cast<ShapedType>();
  // DenseIntElementsAttr::get requires the output type to be ranked with a
  // static shape.
  auto output_type = getType().dyn_cast<RankedTensorType>();
  if (!output_type || !output_type.hasStaticShape()) return {};

  int32_t rank = ranked_type.getRank();
  return DenseIntElementsAttr::get(output_type, rank);
}
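For context, DenseIntElementsAttr::get aborts on unranked or dynamically shaped result types, which is exactly what the new guard filters out. A condensed sketch of the check (function and variable names are illustrative; headers as in the surrounding file):

// Only build the splat attribute when the result type is ranked with a static
// shape; tensor<*xi32> and tensor<?xi32> results simply fold to nothing.
mlir::Attribute MakeRankAttr(mlir::Type result_type, int32_t rank) {
  auto ranked = result_type.dyn_cast<mlir::RankedTensorType>();
  if (!ranked || !ranked.hasStaticShape()) return {};
  return mlir::DenseIntElementsAttr::get(ranked, rank);
}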
@ -904,6 +904,20 @@ func @testRankOfRankedTensor(%arg0 : tensor<4x3x2xf32>) -> tensor<i32> {
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testRankOfRankedTensorUnrankedOutput
|
||||
func @testRankOfRankedTensorUnrankedOutput(%arg0 : tensor<4x3x2xf32>) -> tensor<*xi32> {
|
||||
// Regression test to make sure we don't crash in this case.
|
||||
%0 = "tf.Rank"(%arg0) : (tensor<4x3x2xf32>) -> tensor<*xi32>
|
||||
return %0 : tensor<*xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testRankOfRankedTensorDynamicShapeOutput
|
||||
func @testRankOfRankedTensorDynamicShapeOutput(%arg0 : tensor<4x3x2xf32>) -> tensor<?xi32> {
|
||||
// Regression test to make sure we don't crash in this case.
|
||||
%0 = "tf.Rank"(%arg0) : (tensor<4x3x2xf32>) -> tensor<?xi32>
|
||||
return %0 : tensor<?xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @foldFill
|
||||
func @foldFill() -> (tensor<3x2x1xf32>, tensor<*xf32>, tensor<*xcomplex<f32>>) {
|
||||
%0 = "tf.Const"() {value = dense<[3, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||
|
@ -1627,8 +1627,10 @@ LogicalResult InferModuleShape(ModuleOp module, int64_t max_iterations) {
|
||||
return success();
|
||||
}
|
||||
int64_t producer = producer_or.ValueOrDie();
|
||||
// TODO(jpienaar): Clean up propagate_caller_callee_constants if it is no
|
||||
// longer needed.
|
||||
ShapeInference context(producer, module.getContext(),
|
||||
/*propagate_caller_callee_constants=*/true);
|
||||
/*propagate_caller_callee_constants=*/false);
|
||||
if (auto main = module.lookupSymbol<mlir::FuncOp>("main"))
|
||||
context.enqueue(main);
|
||||
for (auto func : module.getOps<FuncOp>()) context.enqueue(func);
|
||||
|
@ -15,6 +15,9 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"
|
||||
|
||||
#include <atomic>
|
||||
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
@ -23,17 +26,30 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Counter is used as a prefix for filenames.
|
||||
static std::atomic<int> log_counter(0);
|
||||
|
||||
BridgeLoggerConfig::BridgeLoggerConfig(bool print_module_scope,
|
||||
bool print_after_only_on_change)
|
||||
: mlir::PassManager::IRPrinterConfig(print_module_scope,
|
||||
print_after_only_on_change) {}
|
||||
print_after_only_on_change) {
|
||||
const char* log_pass_patterns = getenv("MLIR_BRIDGE_LOG_PASS_PATTERNS");
|
||||
if (log_pass_patterns) {
|
||||
log_pass_patterns_ =
|
||||
absl::StrSplit(log_pass_patterns, ',', absl::SkipWhitespace());
|
||||
}
|
||||
}
|
||||
|
||||
// Logs op to file with name of format `mlir_bridge-pass_name-file_suffix.mlir`.
|
||||
// Logs op to file with name of format
|
||||
// `<log_counter>_mlir_bridge_<pass_name>_<file_suffix>.mlir`.
|
||||
inline static void Log(BridgeLoggerConfig::PrintCallbackFn print_callback,
|
||||
mlir::Pass* pass, mlir::Operation* op,
|
||||
llvm::StringRef file_suffix) {
|
||||
std::string name =
|
||||
llvm::formatv("mlir_bridge_{0}_{1}", pass->getName(), file_suffix).str();
|
||||
std::string pass_name = pass->getName().str();
|
||||
|
||||
// Add 4-digit counter as prefix so the order of the passes is obvious.
|
||||
std::string name = llvm::formatv("{0,0+4}_mlir_bridge_{1}_{2}", log_counter++,
|
||||
pass_name, file_suffix);
|
||||
|
||||
std::unique_ptr<llvm::raw_ostream> os;
|
||||
std::string filepath;
|
||||
@ -44,13 +60,30 @@ inline static void Log(BridgeLoggerConfig::PrintCallbackFn print_callback,
|
||||
void BridgeLoggerConfig::printBeforeIfEnabled(mlir::Pass* pass,
|
||||
mlir::Operation* operation,
|
||||
PrintCallbackFn print_callback) {
|
||||
Log(print_callback, pass, operation, "before");
|
||||
if (should_print(pass)) Log(print_callback, pass, operation, "before");
|
||||
}
|
||||
|
||||
void BridgeLoggerConfig::printAfterIfEnabled(mlir::Pass* pass,
|
||||
mlir::Operation* operation,
|
||||
PrintCallbackFn print_callback) {
|
||||
Log(print_callback, pass, operation, "after");
|
||||
if (should_print(pass)) Log(print_callback, pass, operation, "after");
|
||||
}
|
||||
|
||||
bool BridgeLoggerConfig::should_print(mlir::Pass* pass) {
|
||||
if (log_pass_patterns_.empty()) return true;
|
||||
|
||||
std::string pass_name = pass->getName().str();
|
||||
for (const auto& pattern : log_pass_patterns_) {
|
||||
if (pass_name.find(pattern) != std::string::npos) {
|
||||
// pattern matches pass
|
||||
return true;
|
||||
}
|
||||
}
|
||||
// no pattern matches pass
|
||||
VLOG(2) << "Not logging pass " << pass_name
|
||||
<< " because it does not match any pattern in "
|
||||
"MLIR_BRIDGE_LOG_PASS_PATTERNS";
|
||||
return false;
|
||||
}
|
||||
|
||||
void BridgeTimingConfig::printTiming(PrintCallbackFn printCallback) {
|
||||
|
@ -23,7 +23,11 @@ limitations under the License.
namespace tensorflow {

// Logger for logging/dumping MLIR modules before and after passes in bridge
// targeting TPUs.
// targeting TPUs. The passes being logged can be restricted via environment
// variable `MLIR_BRIDGE_LOG_PASS_PATTERNS` which is interpreted as a comma-
// separated list of strings, and only passes whose name contains any of those
// strings as a substring are logged (no regex support). If
// `MLIR_BRIDGE_LOG_PASS_PATTERNS` is not defined, then all passes are logged.
class BridgeLoggerConfig : public mlir::PassManager::IRPrinterConfig {
 public:
  explicit BridgeLoggerConfig(bool print_module_scope = false,
@ -42,6 +46,14 @@ class BridgeLoggerConfig : public mlir::PassManager::IRPrinterConfig {
  // with the stream to dump into.
  void printAfterIfEnabled(mlir::Pass *pass, mlir::Operation *operation,
                           PrintCallbackFn print_callback) override;

 private:
  bool should_print(mlir::Pass *pass);

  // Only print passes that match any of these patterns. A pass matches a
  // pattern if its name contains the pattern as a substring. If
  // `log_pass_patterns_` is empty, print all passes.
  std::vector<std::string> log_pass_patterns_;
};
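A hedged usage sketch: attach the logger to a pass manager and restrict logging to passes whose names contain the given substrings. The pass-manager setup and the pattern strings ("ShapeInference", "TPURewrite") are illustrative assumptions, not part of this change; the environment variable must be set before the config is constructed, since it is read in the constructor.

#include <cstdlib>
#include <memory>

#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"

// Illustrative only: dump IR before/after passes whose names contain
// "ShapeInference" or "TPURewrite"; other passes are skipped by should_print().
void EnableBridgeLogging(mlir::PassManager& pm) {
  setenv("MLIR_BRIDGE_LOG_PASS_PATTERNS", "ShapeInference,TPURewrite",
         /*overwrite=*/1);
  pm.enableIRPrinting(std::make_unique<tensorflow::BridgeLoggerConfig>(
      /*print_module_scope=*/true, /*print_after_only_on_change=*/true));
}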
|
||||
|
||||
// Logger for logging/dumping pass pipeline timings after completion.
|
||||
|
@ -516,6 +516,11 @@ class ConvertToHloModule {
|
||||
//
|
||||
// TODO(hinsu): Check for dynamic shapes and exit instead of crashing.
|
||||
LogicalResult Run() {
|
||||
auto main = module_.lookupSymbol<mlir::FuncOp>("main");
|
||||
if (!main)
|
||||
return module_.emitError(
|
||||
"conversion requires module with `main` function");
|
||||
|
||||
for (auto func : module_.getOps<FuncOp>()) {
|
||||
if (func.empty()) continue;
|
||||
if (failed(RunOnFunction(func))) return failure();
|
||||
@ -539,8 +544,11 @@ class ConvertToHloModule {
|
||||
xla::XlaComputation* result);
|
||||
|
||||
::xla::HloModuleProto ConsumeMainProto() {
|
||||
return lowered_computation_[module_.lookupSymbol<mlir::FuncOp>("main")]
|
||||
.proto();
|
||||
auto main = module_.lookupSymbol<mlir::FuncOp>("main");
|
||||
// This is an invariant check as Run returns failure if there is no main
|
||||
// function and so the main proto shouldn't be consumed in that case.
|
||||
CHECK(main) << "requires module to have main function"; // Crash Ok.
|
||||
return lowered_computation_[main].proto();
|
||||
}
|
||||
|
||||
// Lower function call to HLO call instruction
|
||||
|
@ -174,6 +174,13 @@ func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
|
||||
return %0: tensor<4xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @bitwise_xor
|
||||
func @bitwise_xor(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
|
||||
// CHECK-NEXT: mhlo.xor
|
||||
%0 = "tf.BitwiseXor"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
return %0: tensor<4xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @bitwise_and
|
||||
func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
|
||||
// CHECK-NEXT: mhlo.and
|
||||
|
@ -835,7 +835,14 @@ func @floordiv_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @floordiv_unranked
|
||||
func @floordiv_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
|
||||
func @floordiv_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK-NOT: tf.FloorDiv
|
||||
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0: tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @floordiv_int
|
||||
func @floordiv_int(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
|
||||
// CHECK: tf.FloorDiv
|
||||
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
|
||||
return %0: tensor<*xi32>
|
||||
@ -894,7 +901,7 @@ func @floormod_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?
|
||||
|
||||
// CHECK-LABEL: func @floormod_unranked
|
||||
func @floormod_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
|
||||
// CHECK: tf.FloorMod
|
||||
// CHECK-NOT: tf.FloorMod
|
||||
%0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
|
||||
return %0: tensor<*xi32>
|
||||
}
|
||||
|
@ -0,0 +1,7 @@
|
||||
// RUN: not tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s 2>&1 | FileCheck %s
|
||||
|
||||
// CHECK: conversion requires module with `main`
|
||||
func @non_main() {
|
||||
%0 = "mhlo.constant"() {value = opaque<"mhlo", "0x0123456789ABCDEF"> : tensor<4xf32>} : () -> tensor<4xf32>
|
||||
return
|
||||
}
|
@ -114,7 +114,7 @@ def : Pat<(TF_ComplexOp $r, $i), (HLO_ComplexOp $r, $i)>;
|
||||
// Performs a substitution of FloorDiv, pseudo code below:
|
||||
//
|
||||
// return floor(div(x, y))
|
||||
def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r),
|
||||
def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r),
|
||||
(HLO_FloorOp
|
||||
(HLOClient_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r))),
|
||||
[(IEEEFloatTensor $l)]>;
|
||||
@ -166,7 +166,7 @@ def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r),
|
||||
// return trunc_mod != 0 && (y < 0 != trunc_mod < 0) ? trunc_mod + y
|
||||
// Requires static shaped inputs to create constant splats and computation of
|
||||
// broadcast attributes.
|
||||
def : Pat<(TF_FloorModOp AnyRankedTensor:$l, AnyRankedTensor:$r),
|
||||
def : Pat<(TF_FloorModOp AnyTensor:$l, AnyTensor:$r),
|
||||
(HLO_SelectOp
|
||||
(HLOClient_BroadcastAndOp
|
||||
(HLOClient_BroadcastCompareOp
|
||||
@ -193,14 +193,15 @@ def : Pat<(TF_FloorModOp AnyRankedTensor:$l, AnyRankedTensor:$r),
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class DirectLogicalBinaryPat<Op FromOp, Op ToOp>
|
||||
: Pat<(FromOp AnyRankedTensor:$l, AnyRankedTensor:$r),
|
||||
: Pat<(FromOp AnyTensor:$l, AnyTensor:$r),
|
||||
(ToOp $l, $r, (BinBroadcastDimensions $l, $r)),
|
||||
[(SignedIntTensor $l)]>;
|
||||
|
||||
foreach fromToBinPair = [[TF_LogicalAndOp, HLOClient_BroadcastAndOp],
|
||||
[TF_LogicalOrOp, HLOClient_BroadcastOrOp],
|
||||
[TF_BitwiseAndOp, HLOClient_BroadcastAndOp],
|
||||
[TF_BitwiseOrOp, HLOClient_BroadcastOrOp],
|
||||
[TF_BitwiseAndOp, HLOClient_BroadcastAndOp]] in
|
||||
[TF_BitwiseXorOp, HLOClient_BroadcastXorOp]] in
|
||||
def : DirectLogicalBinaryPat<fromToBinPair[0], fromToBinPair[1]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -154,9 +154,11 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
|
||||
TypeID::get<TF::IgammaOp>(),
|
||||
TypeID::get<TF::IgammacOp>(),
|
||||
TypeID::get<TF::IgammaGradAOp>(),
|
||||
TypeID::get<TF::InplaceAddOp>(),
|
||||
TypeID::get<TF::InTopKV2Op>(),
|
||||
TypeID::get<TF::InvertOp>(),
|
||||
TypeID::get<TF::InvOp>(),
|
||||
TypeID::get<TF::KthOrderStatisticOp>(),
|
||||
TypeID::get<TF::LRNOp>(),
|
||||
TypeID::get<TF::LRNGradOp>(),
|
||||
TypeID::get<TF::LeakyReluGradOp>(),
|
||||
@ -170,6 +172,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
|
||||
TypeID::get<TF::LogicalOrOp>(),
|
||||
TypeID::get<TF::LogOp>(),
|
||||
TypeID::get<TF::LowerBoundOp>(),
|
||||
TypeID::get<TF::MakeUniqueOp>(),
|
||||
TypeID::get<TF::MatMulOp>(),
|
||||
TypeID::get<TF::MatrixDiagV3Op>(),
|
||||
TypeID::get<TF::MatrixInverseOp>(),
|
||||
@ -248,6 +251,8 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
|
||||
TypeID::get<TF::TensorScatterAddOp>(),
|
||||
TypeID::get<TF::TensorScatterSubOp>(),
|
||||
TypeID::get<TF::TPUEmbeddingActivationsOp>(),
|
||||
TypeID::get<TF::TopKUniqueOp>(),
|
||||
TypeID::get<TF::TopKWithUniqueOp>(),
|
||||
TypeID::get<TF::TransposeOp>(),
|
||||
TypeID::get<TF::TridiagonalSolveOp>(),
|
||||
TypeID::get<TF::TruncateDivOp>(),
|
||||
|
@ -121,7 +121,7 @@ Status ConvertModule(std::unique_ptr<HloModule> hlo_module, ModuleOp module,
|
||||
// Run all HLO passes to produce an optimized module.
|
||||
auto result_or = backend->compiler()->RunHloPassesAndBufferAssignement(
|
||||
std::move(hlo_module), backend->default_stream_executor(),
|
||||
backend->memory_allocator(), optimize_xla_hlo);
|
||||
optimize_xla_hlo, {backend->memory_allocator()});
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(result_or.status(),
|
||||
"running XLA pass pipeline");
|
||||
std::unique_ptr<HloModule> optimized_hlo_module =
|
||||
|
@ -115,6 +115,16 @@ class ExecutableBuildOptions {
|
||||
return *this;
|
||||
}
|
||||
|
||||
  // Thread pool for parallel compilation.
  tensorflow::thread::ThreadPool* compile_thread_pool() const {
    return compile_thread_pool_;
  }
  ExecutableBuildOptions& set_compile_thread_pool(
      tensorflow::thread::ThreadPool* compile_thread_pool) {
    compile_thread_pool_ = compile_thread_pool;
    return *this;
  }
|
||||
|
||||
private:
|
||||
int device_ordinal_ = -1;
|
||||
Shape result_layout_;
|
||||
@ -128,6 +138,7 @@ class ExecutableBuildOptions {
|
||||
absl::optional<DeviceAssignment> device_assignment_;
|
||||
bool alias_passthrough_params_ = false;
|
||||
bool run_backend_only_ = false;
|
||||
tensorflow::thread::ThreadPool* compile_thread_pool_ = nullptr;
|
||||
};
|
||||
|
||||
} // namespace xla
|
@ -3347,6 +3347,8 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
|
||||
// constant False if dimension is static.
|
||||
// - Reduce: Convert to reduce or.
|
||||
// - Constant: Convert to constant False.
|
||||
// - Reshape, slice, transpose, pad:
|
||||
// Convert into predicate type with same opcode.
|
||||
// - Other ops: Not supported.
|
||||
// Create the instruction for the new handle.
|
||||
TF_ASSIGN_OR_RETURN(HloOpcode opcode,
|
||||
@ -3449,6 +3451,7 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
|
||||
case HloOpcode::kBroadcast:
|
||||
case HloOpcode::kConcatenate:
|
||||
case HloOpcode::kReshape:
|
||||
case HloOpcode::kPad:
|
||||
break;
|
||||
case HloOpcode::kGetDimensionSize: {
|
||||
int64 dimension = instr_proto->dimensions(0);
|
||||
|
@ -26,8 +26,8 @@ static const char kCpuPlatformName[] = "cpu";
|
||||
|
||||
CpuDevice::CpuDevice(int id,
|
||||
std::unique_ptr<LocalDeviceState> local_device_state)
|
||||
: PjRtDevice(id, std::move(local_device_state),
|
||||
/*device_kind=*/kCpuPlatformName) {}
|
||||
: PjRtStreamExecutorDevice(id, std::move(local_device_state),
|
||||
/*device_kind=*/kCpuPlatformName) {}
|
||||
|
||||
StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
|
||||
TF_ASSIGN_OR_RETURN(se::Platform * platform,
|
||||
@ -40,7 +40,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
|
||||
TF_ASSIGN_OR_RETURN(LocalClient * client,
|
||||
ClientLibrary::GetOrCreateLocalClient(options));
|
||||
|
||||
std::vector<std::unique_ptr<PjRtDevice>> devices;
|
||||
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
|
||||
for (int i = 0; i < client->device_count(); ++i) {
|
||||
se::StreamExecutorConfig config;
|
||||
config.ordinal = i;
|
||||
@ -57,11 +57,11 @@ StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
|
||||
devices.push_back(std::move(device));
|
||||
}
|
||||
|
||||
return std::make_unique<PjRtClient>(
|
||||
return std::unique_ptr<PjRtClient>(std::make_unique<PjRtStreamExecutorClient>(
|
||||
kCpuName, client, std::move(devices), /*host_id=*/0,
|
||||
/*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
|
||||
/*should_stage_host_to_device_transfers=*/false,
|
||||
/*gpu_run_options=*/nullptr);
|
||||
/*gpu_run_options=*/nullptr));
|
||||
}
|
||||
|
||||
} // namespace xla
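For context, a minimal hypothetical caller of the factory above (header path and logging calls are assumptions): it only touches the abstract PjRtClient / PjRtDevice interface, so the switch to PjRtStreamExecutorClient under the hood is invisible to it.

#include <memory>

#include "tensorflow/compiler/xla/pjrt/cpu_device.h"
#include "tensorflow/core/platform/logging.h"

// Illustrative only: enumerate CPU devices through the platform-agnostic API.
void ListCpuDevices() {
  auto client_or = xla::GetCpuClient(/*asynchronous=*/true);
  if (!client_or.ok()) {
    LOG(ERROR) << client_or.status();
    return;
  }
  std::unique_ptr<xla::PjRtClient> client = client_or.ConsumeValueOrDie();
  for (xla::PjRtDevice* device : client->devices()) {
    LOG(INFO) << device->DebugString()
              << " addressable=" << device->IsAddressable();
  }
}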
|
||||
|
@ -23,7 +23,7 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
class CpuDevice : public PjRtDevice {
|
||||
class CpuDevice : public PjRtStreamExecutorDevice {
|
||||
public:
|
||||
CpuDevice(int id, std::unique_ptr<LocalDeviceState> local_device_state);
|
||||
};
|
||||
|
@ -35,9 +35,9 @@ namespace xla {
|
||||
namespace {
|
||||
|
||||
// A custom PjRtClient that overrides the device assignment method.
|
||||
class GpuClient : public xla::PjRtClient {
|
||||
class GpuClient : public xla::PjRtStreamExecutorClient {
|
||||
public:
|
||||
using xla::PjRtClient::PjRtClient;
|
||||
using xla::PjRtStreamExecutorClient::PjRtStreamExecutorClient;
|
||||
|
||||
xla::StatusOr<xla::DeviceAssignment> GetDefaultDeviceAssignment(
|
||||
int num_replicas, int num_partitions) const override;
|
||||
@ -55,7 +55,8 @@ xla::StatusOr<xla::DeviceAssignment> GpuClient::GetDefaultDeviceAssignment(
|
||||
return assignment;
|
||||
}
|
||||
// Fallback to default global device assignment if we can't run locally.
|
||||
return PjRtClient::GetDefaultDeviceAssignment(num_replicas, num_partitions);
|
||||
return PjRtStreamExecutorClient::GetDefaultDeviceAssignment(num_replicas,
|
||||
num_partitions);
|
||||
}
|
||||
|
||||
// Builds an xla::LocalClient for the GPU platform.
|
||||
@ -225,9 +226,9 @@ StatusOr<std::string> NcclIdStore::GetNcclUniqueId(
|
||||
return result.first->second;
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<PjRtDevice>> BuildLocalDevices(
|
||||
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> BuildLocalDevices(
|
||||
std::vector<std::unique_ptr<LocalDeviceState>> local_device_states) {
|
||||
std::vector<std::unique_ptr<PjRtDevice>> devices;
|
||||
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
|
||||
for (auto& local_device : local_device_states) {
|
||||
int device_ordinal = local_device->device_ordinal();
|
||||
const se::DeviceDescription& description =
|
||||
@ -243,7 +244,7 @@ std::vector<std::unique_ptr<PjRtDevice>> BuildLocalDevices(
|
||||
Status BuildDistributedDevices(
|
||||
std::vector<std::unique_ptr<LocalDeviceState>> local_device_states,
|
||||
std::shared_ptr<DistributedRuntimeClient> distributed_client, int node_id,
|
||||
std::vector<std::unique_ptr<PjRtDevice>>* devices,
|
||||
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>* devices,
|
||||
gpu::GpuExecutableRunOptions* gpu_executable_run_options) {
|
||||
LocalTopologyProto local_topology;
|
||||
local_topology.set_node_id(node_id);
|
||||
@ -306,8 +307,8 @@ Status BuildDistributedDevices(
|
||||
GpuDevice::GpuDevice(int id,
|
||||
std::unique_ptr<LocalDeviceState> local_device_state,
|
||||
std::string device_kind, int node_id)
|
||||
: PjRtDevice(id, std::move(local_device_state), std::move(device_kind),
|
||||
node_id) {}
|
||||
: PjRtStreamExecutorDevice(id, std::move(local_device_state),
|
||||
std::move(device_kind), node_id) {}
|
||||
|
||||
StatusOr<std::unique_ptr<PjRtClient>> GetGpuClient(
|
||||
bool asynchronous, const GpuAllocatorConfig& allocator_config,
|
||||
@ -322,7 +323,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetGpuClient(
|
||||
auto host_memory_allocator =
|
||||
GetGpuHostAllocator(local_device_states.front()->executor());
|
||||
|
||||
std::vector<std::unique_ptr<PjRtDevice>> devices;
|
||||
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
|
||||
auto gpu_run_options = absl::make_unique<gpu::GpuExecutableRunOptions>();
|
||||
if (distributed_client) {
|
||||
TF_RETURN_IF_ERROR(BuildDistributedDevices(
|
||||
|
@ -25,7 +25,7 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
class GpuDevice : public PjRtDevice {
|
||||
class GpuDevice : public PjRtStreamExecutorDevice {
|
||||
public:
|
||||
GpuDevice(int id, std::unique_ptr<LocalDeviceState> local_device_state,
|
||||
std::string device_kind, int node_id);
|
||||
|
@ -26,8 +26,8 @@ static const char kInterpreterPlatformName[] = "interpreter";
|
||||
|
||||
InterpreterDevice::InterpreterDevice(
|
||||
int id, std::unique_ptr<LocalDeviceState> local_device_state)
|
||||
: PjRtDevice(id, std::move(local_device_state),
|
||||
/*device_kind=*/kInterpreterPlatformName) {}
|
||||
: PjRtStreamExecutorDevice(id, std::move(local_device_state),
|
||||
/*device_kind=*/kInterpreterPlatformName) {}
|
||||
|
||||
StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() {
|
||||
TF_ASSIGN_OR_RETURN(se::Platform * platform,
|
||||
@ -41,7 +41,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() {
|
||||
TF_ASSIGN_OR_RETURN(LocalClient * client,
|
||||
ClientLibrary::GetOrCreateLocalClient(options));
|
||||
|
||||
std::vector<std::unique_ptr<PjRtDevice>> devices;
|
||||
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
|
||||
se::StreamExecutor* executor =
|
||||
client->backend().stream_executor(0).ValueOrDie();
|
||||
auto device_state = absl::make_unique<LocalDeviceState>(
|
||||
@ -51,11 +51,11 @@ StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() {
|
||||
absl::make_unique<InterpreterDevice>(0, std::move(device_state));
|
||||
devices.push_back(std::move(device));
|
||||
|
||||
return std::make_unique<PjRtClient>(
|
||||
return std::unique_ptr<PjRtClient>(std::make_unique<PjRtStreamExecutorClient>(
|
||||
"interpreter", client, std::move(devices), /*host_id=*/0,
|
||||
/*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
|
||||
/*should_stage_host_to_device_transfers=*/false,
|
||||
/*gpu_run_options=*/nullptr);
|
||||
/*gpu_run_options=*/nullptr));
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -23,7 +23,7 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
class InterpreterDevice : public PjRtDevice {
|
||||
class InterpreterDevice : public PjRtStreamExecutorDevice {
|
||||
public:
|
||||
InterpreterDevice(int id,
|
||||
std::unique_ptr<LocalDeviceState> local_device_state);
|
||||
|
@ -114,21 +114,22 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
PjRtPlatformId PjRtDevice::platform_id() const {
|
||||
PjRtPlatformId PjRtStreamExecutorDevice::platform_id() const {
|
||||
return client_->platform_id();
|
||||
}
|
||||
const std::string& PjRtDevice::platform_name() const {
|
||||
const std::string& PjRtStreamExecutorDevice::platform_name() const {
|
||||
return client_->platform_name();
|
||||
}
|
||||
|
||||
StatusOr<LocalDeviceState*> PjRtDevice::GetLocalDeviceState() const {
|
||||
StatusOr<LocalDeviceState*> PjRtStreamExecutorDevice::GetLocalDeviceState()
|
||||
const {
|
||||
if (local_device_state_) {
|
||||
return local_device_state_.get();
|
||||
}
|
||||
return InvalidArgument("Device %s is not a local device.", DebugString());
|
||||
}
|
||||
|
||||
std::string PjRtDevice::DebugString() const {
|
||||
std::string PjRtStreamExecutorDevice::DebugString() const {
|
||||
return absl::StrCat(platform_name(), ":", id());
|
||||
}
|
||||
|
||||
@ -153,14 +154,15 @@ StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
|
||||
devices[replica].size(), replica, devices[0].size());
|
||||
}
|
||||
for (int partition = 0; partition < devices[replica].size(); ++partition) {
|
||||
if (devices[0][0]->platform_id() !=
|
||||
devices[replica][partition]->platform_id()) {
|
||||
if (devices[0][0]->client()->platform_id() !=
|
||||
devices[replica][partition]->client()->platform_id()) {
|
||||
return InvalidArgument(
|
||||
"Device assignment passed to Compile() must have devices of a "
|
||||
"single kind, got %s for replica 0 partition 0 and %s for replica "
|
||||
"%d partition %d.",
|
||||
devices[0][0]->platform_name(),
|
||||
devices[replica][partition]->platform_name(), replica, partition);
|
||||
devices[0][0]->client()->platform_name(),
|
||||
devices[replica][partition]->client()->platform_name(), replica,
|
||||
partition);
|
||||
}
|
||||
xla_assignment(replica, partition) = devices[replica][partition]->id();
|
||||
}
|
||||
@ -182,9 +184,9 @@ class CpuAllocator : public tensorflow::Allocator {
|
||||
}
|
||||
};
|
||||
|
||||
PjRtClient::PjRtClient(
|
||||
PjRtStreamExecutorClient::PjRtStreamExecutorClient(
|
||||
std::string platform_name, LocalClient* client,
|
||||
std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id,
|
||||
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int host_id,
|
||||
std::unique_ptr<se::DeviceMemoryAllocator> allocator,
|
||||
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
|
||||
bool should_stage_host_to_device_transfers,
|
||||
@ -193,7 +195,7 @@ PjRtClient::PjRtClient(
|
||||
platform_name_(std::move(platform_name)),
|
||||
client_(client),
|
||||
host_memory_allocator_(std::move(host_memory_allocator)),
|
||||
devices_(std::move(devices)),
|
||||
owned_devices_(std::move(devices)),
|
||||
host_id_(host_id),
|
||||
owned_allocator_(std::move(allocator)),
|
||||
should_stage_host_to_device_transfers_(
|
||||
@ -211,12 +213,14 @@ PjRtClient::PjRtClient(
|
||||
host_memory_allocator_ = std::make_unique<CpuAllocator>();
|
||||
}
|
||||
|
||||
for (const std::unique_ptr<PjRtDevice>& device : devices_) {
|
||||
for (const std::unique_ptr<PjRtStreamExecutorDevice>& device :
|
||||
owned_devices_) {
|
||||
devices_.push_back(device.get());
|
||||
CHECK(id_to_device_.insert({device->id(), device.get()}).second)
|
||||
<< "Duplicate device id: " << device->id();
|
||||
|
||||
if (device->IsLocalDevice()) {
|
||||
int idx = device->local_device_id();
|
||||
if (device->IsAddressable()) {
|
||||
int idx = device->local_hardware_id();
|
||||
if (idx >= local_devices_.size()) {
|
||||
local_devices_.resize(idx + 1);
|
||||
}
|
||||
@ -230,13 +234,14 @@ PjRtClient::PjRtClient(
|
||||
}
|
||||
}
|
||||
|
||||
StatusOr<DeviceAssignment> PjRtClient::GetDefaultDeviceAssignment(
|
||||
StatusOr<DeviceAssignment> PjRtStreamExecutorClient::GetDefaultDeviceAssignment(
|
||||
int num_replicas, int num_partitions) const {
|
||||
return client_->backend().computation_placer()->AssignDevices(num_replicas,
|
||||
num_partitions);
|
||||
}
|
||||
|
||||
std::unique_ptr<HloCostAnalysis> PjRtClient::GetHloCostAnalysis() {
|
||||
std::unique_ptr<HloCostAnalysis>
|
||||
PjRtStreamExecutorClient::GetHloCostAnalysis() {
|
||||
return absl::make_unique<HloCostAnalysis>(
|
||||
client_->backend().compiler()->ShapeSizeBytesFunction());
|
||||
}
|
||||
@ -346,12 +351,13 @@ StatusOr<std::unique_ptr<PjRtBuffer>> AllocateDestinationBuffer(
|
||||
return InvalidArgument("Can't make a buffer from an empty tuple");
|
||||
}
|
||||
|
||||
auto* se_client = tensorflow::down_cast<PjRtStreamExecutorClient*>(client);
|
||||
TransferManager* transfer_manager =
|
||||
client->client()->backend().transfer_manager();
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
ScopedShapedBuffer dst_buffer,
|
||||
transfer_manager->AllocateScopedShapedBuffer(
|
||||
on_host_shape, client->allocator(), local_device->device_ordinal()));
|
||||
se_client->client()->backend().transfer_manager();
|
||||
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer dst_buffer,
|
||||
transfer_manager->AllocateScopedShapedBuffer(
|
||||
on_host_shape, se_client->allocator(),
|
||||
local_device->device_ordinal()));
|
||||
if (local_device->allocation_model() ==
|
||||
LocalDeviceState::kComputeSynchronized) {
|
||||
if (copy_stream == nullptr) {
|
||||
@ -543,18 +549,21 @@ void PjRtBuffer::ScopedHold::AddToInput(
|
||||
|
||||
bool PjRtBuffer::IsOnCpu() const { return client()->platform_id() == kCpuId; }
|
||||
|
||||
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostBuffer(
|
||||
StatusOr<std::unique_ptr<PjRtBuffer>>
|
||||
PjRtStreamExecutorClient::BufferFromHostBuffer(
|
||||
const void* data, const Shape& shape,
|
||||
HostBufferSemantics host_buffer_semantics,
|
||||
std::shared_ptr<void> buffer_reference, PjRtDevice* device) {
|
||||
tensorflow::profiler::TraceMe traceme("PjRtClient::BufferFromHostBuffer");
|
||||
VLOG(2) << "PjRtClient::BufferFromHostBuffer: shape: " << shape.ToString()
|
||||
<< " device: " << device->DebugString();
|
||||
tensorflow::profiler::TraceMe traceme(
|
||||
"PjRtStreamExecutorClient::BufferFromHostBuffer");
|
||||
VLOG(2) << "PjRtStreamExecutorClient::BufferFromHostBuffer: shape: "
|
||||
<< shape.ToString() << " device: " << device->DebugString();
|
||||
if (shape.IsTuple()) {
|
||||
return InvalidArgument("Use BufferFromHostLiteral to transfer a tuple");
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
|
||||
device->GetLocalDeviceState());
|
||||
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
|
||||
->GetLocalDeviceState());
|
||||
int64 size = ShapeUtil::ByteSizeOf(shape);
|
||||
|
||||
TransferManager* transfer_manager = client()->backend().transfer_manager();
|
||||
@ -708,20 +717,23 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostBuffer(
|
||||
return py_buffer;
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
|
||||
const Shape& shape, PjRtDevice* device) {
|
||||
StatusOr<std::unique_ptr<PjRtBuffer>>
|
||||
PjRtStreamExecutorClient::CreateUninitializedBuffer(const Shape& shape,
|
||||
PjRtDevice* device) {
|
||||
return CreateUninitializedBuffer(shape, device, nullptr);
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
|
||||
StatusOr<std::unique_ptr<PjRtBuffer>>
|
||||
PjRtStreamExecutorClient::CreateUninitializedBuffer(
|
||||
const Shape& shape, PjRtDevice* device,
|
||||
std::shared_ptr<BufferSequencingEvent> definition_event) {
|
||||
tensorflow::profiler::TraceMe traceme(
|
||||
"PjRtClient::CreateUninitializedBuffer");
|
||||
VLOG(2) << "PjRtClient::CreateUninitializedBuffer: shape: "
|
||||
"PjRtStreamExecutorClient::CreateUninitializedBuffer");
|
||||
VLOG(2) << "PjRtStreamExecutorClient::CreateUninitializedBuffer: shape: "
|
||||
<< shape.ToString() << " device: " << device->DebugString();
|
||||
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
|
||||
device->GetLocalDeviceState());
|
||||
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
|
||||
->GetLocalDeviceState());
|
||||
|
||||
TransferManager* transfer_manager = client()->backend().transfer_manager();
|
||||
TF_ASSIGN_OR_RETURN(Shape compact_shape,
|
||||
@ -733,13 +745,16 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
|
||||
definition_event);
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostLiteral(
|
||||
const LiteralSlice& literal, PjRtDevice* device) {
|
||||
tensorflow::profiler::TraceMe traceme("PjRtClient::BufferFromHostLiteral");
|
||||
VLOG(2) << "PjRtClient::BufferFromHostLiteral: shape: "
|
||||
StatusOr<std::unique_ptr<PjRtBuffer>>
|
||||
PjRtStreamExecutorClient::BufferFromHostLiteral(const LiteralSlice& literal,
|
||||
PjRtDevice* device) {
|
||||
tensorflow::profiler::TraceMe traceme(
|
||||
"PjRtStreamExecutorClient::BufferFromHostLiteral");
|
||||
VLOG(2) << "PjRtStreamExecutorClient::BufferFromHostLiteral: shape: "
|
||||
<< literal.shape().ToString() << " device: " << device->DebugString();
|
||||
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
|
||||
device->GetLocalDeviceState());
|
||||
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
|
||||
->GetLocalDeviceState());
|
||||
|
||||
TransferManager* transfer_manager = client()->backend().transfer_manager();
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
@ -792,7 +807,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostLiteral(
|
||||
return py_buffer;
|
||||
}
|
||||
|
||||
void PjRtClient::MakeCrossHostReceiveBuffers(
|
||||
void PjRtStreamExecutorClient::MakeCrossHostReceiveBuffers(
|
||||
absl::Span<const Shape> shapes, PjRtDevice* device,
|
||||
PjRtCrossHostRecvNotifier&& notifier) {
|
||||
if (shapes.empty()) {
|
||||
@ -801,7 +816,9 @@ void PjRtClient::MakeCrossHostReceiveBuffers(
|
||||
return;
|
||||
}
|
||||
|
||||
auto local_device_or = device->GetLocalDeviceState();
|
||||
auto local_device_or =
|
||||
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
|
||||
->GetLocalDeviceState();
|
||||
if (!local_device_or.ok()) {
|
||||
notifier(local_device_or.status());
|
||||
return;
|
||||
@ -828,27 +845,30 @@ void PjRtClient::MakeCrossHostReceiveBuffers(
|
||||
}
|
||||
|
||||
// Transfer the given literal to the infeed queue of the given local device.
|
||||
Status PjRtDevice::TransferToInfeed(const LiteralSlice& literal) const {
|
||||
Status PjRtStreamExecutorDevice::TransferToInfeed(
|
||||
const LiteralSlice& literal) const {
|
||||
// Only support infeed to local device.
|
||||
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
|
||||
return local_device->client()->TransferToInfeedLocal(
|
||||
literal, local_device->device_ordinal());
|
||||
}
|
||||
|
||||
StatusOr<Literal> PjRtDevice::TransferFromOutfeed(const Shape& shape) const {
|
||||
StatusOr<Literal> PjRtStreamExecutorDevice::TransferFromOutfeed(
|
||||
const Shape& shape) const {
|
||||
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
|
||||
return local_device->client()->TransferFromOutfeedLocal(
|
||||
shape, local_device->device_ordinal());
|
||||
}
|
||||
|
||||
StatusOr<PjRtDevice*> PjRtClient::LookupLocalDevice(int local_device_id) const {
|
||||
StatusOr<PjRtDevice*> PjRtStreamExecutorClient::LookupAddressableDevice(
|
||||
int local_hardware_id) const {
|
||||
for (auto* device : local_devices_) {
|
||||
if (local_device_id == device->local_device_id()) {
|
||||
if (local_hardware_id == device->local_hardware_id()) {
|
||||
return device;
|
||||
}
|
||||
}
|
||||
return InvalidArgument("No matching device found for local_device_id %d",
|
||||
local_device_id);
|
||||
return InvalidArgument("No matching device found for local_hardware_id %d",
|
||||
local_hardware_id);
|
||||
}
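A small illustrative wrapper around the renamed query (the function name is invented): id() is the global DeviceAssignment id, while local_hardware_id() is the platform ordinal, e.g. the CUDA device number, and the lookup above resolves the latter.

// Illustrative only: map a platform ordinal to its addressable PjRtDevice.
xla::StatusOr<xla::PjRtDevice*> FindDeviceForOrdinal(xla::PjRtClient* client,
                                                     int ordinal) {
  return client->LookupAddressableDevice(/*local_hardware_id=*/ordinal);
}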
|
||||
|
||||
PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
|
||||
@ -873,7 +893,8 @@ PjRtBuffer::~PjRtBuffer() {
|
||||
}
|
||||
|
||||
int64 PjRtBuffer::OnDeviceSizeInBytes() const {
|
||||
return client_->client()
|
||||
return tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
|
||||
->client()
|
||||
->backend()
|
||||
.transfer_manager()
|
||||
->GetByteSizeRequirement(on_device_shape_);
|
||||
@ -919,7 +940,9 @@ StatusOr<std::shared_ptr<TrackedDeviceBuffer>> PjRtBuffer::Release(
|
||||
// the final set of usage events.
|
||||
events = device_buffer->LockUseAndTransferUsageEvents();
|
||||
}
|
||||
LocalDeviceState* local_device_state = device_->local_device_state();
|
||||
LocalDeviceState* local_device_state =
|
||||
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
|
||||
->local_device_state();
|
||||
if (wait_for_operations_to_complete) {
|
||||
// Block the host until all usage events have completed. Usage events
|
||||
// dominate definition events, so this also waits for the buffer to be
|
||||
@ -1080,7 +1103,9 @@ PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy,
|
||||
}
|
||||
ScopedHold device_buffer(this, ScopedHold::kUsage);
|
||||
std::shared_ptr<HostValue> host_value;
|
||||
LocalDeviceState* local_device = device_->local_device_state();
|
||||
LocalDeviceState* local_device =
|
||||
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
|
||||
->local_device_state();
|
||||
se::Stream* stream = local_device->GetDeviceToHostStream();
|
||||
const xla::Layout& host_layout =
|
||||
layout.has_value() ? layout.value() : on_host_shape_.layout();
|
||||
@ -1122,12 +1147,16 @@ PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy,
|
||||
host_value->value = std::make_shared<Literal>(host_shape);
|
||||
ShapedBuffer shaped_buffer =
|
||||
device_buffer->AsShapedBuffer(host_shape, on_device_shape_);
|
||||
client_->client()->backend().transfer_manager()->TransferLiteralFromDevice(
|
||||
stream, shaped_buffer, host_value->value.get(),
|
||||
[host_value](Status done_status) {
|
||||
host_value->status = done_status;
|
||||
host_value->ready.Notify();
|
||||
});
|
||||
tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
|
||||
->client()
|
||||
->backend()
|
||||
.transfer_manager()
|
||||
->TransferLiteralFromDevice(stream, shaped_buffer,
|
||||
host_value->value.get(),
|
||||
[host_value](Status done_status) {
|
||||
host_value->status = done_status;
|
||||
host_value->ready.Notify();
|
||||
});
|
||||
|
||||
auto usage_event = std::make_shared<BufferSequencingEvent>();
|
||||
StatusOr<EventPool::Handle> event_or =
|
||||
@ -1156,7 +1185,7 @@ PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy,
|
||||
|
||||
StatusOr<std::shared_ptr<Literal>> PjRtBuffer::ToLiteral(
|
||||
const bool discard_cached_copy, absl::optional<xla::Layout> layout) {
|
||||
tensorflow::profiler::TraceMe traceme("PjRtClient::ToLiteral");
|
||||
tensorflow::profiler::TraceMe traceme("PjRtStreamExecutorClient::ToLiteral");
|
||||
TF_ASSIGN_OR_RETURN(std::shared_ptr<HostValue> host_value,
|
||||
CopyToHostAsyncInternal(discard_cached_copy, layout));
|
||||
if (host_value == nullptr) {
|
||||
@ -1241,8 +1270,9 @@ PjRtBuffer::CopyToDeviceHelper(
|
||||
// StallStreamOnError only makes sure the destination device is ok, so
|
||||
// make sure that the src buffer remains valid until after any transfers
|
||||
// have completed.
|
||||
device_->local_device_state()->ThenRelease(transfer_stream,
|
||||
src_device_buffer);
|
||||
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
|
||||
->local_device_state()
|
||||
->ThenRelease(transfer_stream, src_device_buffer);
|
||||
}
|
||||
return copy_event_or.status();
|
||||
}
|
||||
@ -1265,14 +1295,20 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CopyToDevice(
|
||||
TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> literal, ToLiteral());
|
||||
return dst_device->client()->BufferFromHostBuffer(
|
||||
literal->untyped_data(), literal->shape(),
|
||||
PjRtClient::HostBufferSemantics::kZeroCopy, nullptr, dst_device);
|
||||
PjRtStreamExecutorClient::HostBufferSemantics::kZeroCopy, nullptr,
|
||||
dst_device);
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(LocalDeviceState * dst_local_device,
|
||||
dst_device->GetLocalDeviceState());
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
LocalDeviceState * dst_local_device,
|
||||
tensorflow::down_cast<PjRtStreamExecutorDevice*>(dst_device)
|
||||
->GetLocalDeviceState());
|
||||
LocalDeviceState* transfer_local_device =
|
||||
client_->EnqueueD2DTransfersOnSrcStream() ? device_->local_device_state()
|
||||
: dst_local_device;
|
||||
tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
|
||||
->EnqueueD2DTransfersOnSrcStream()
|
||||
? tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
|
||||
->local_device_state()
|
||||
: dst_local_device;
|
||||
CHECK_EQ(dst_local_device->allocation_model(),
|
||||
transfer_local_device->allocation_model());
|
||||
|
||||
@ -1310,7 +1346,9 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CopyToDevice(
|
||||
// alternative is to ensure, before freeing the buffer, that the compute
|
||||
// stream is synchronized past the transfer, but it seems better to hold onto
|
||||
// the buffer too long than to stall the compute stream.
|
||||
RecordUsage(std::move(src_device_buffer), device_->local_device_state(),
|
||||
RecordUsage(std::move(src_device_buffer),
|
||||
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
|
||||
->local_device_state(),
|
||||
transfer_local_device, event, transfer_stream,
|
||||
/*prefer_to_retain_reference=*/true);
|
||||
|
||||
@ -1318,7 +1356,8 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CopyToDevice(
|
||||
}
|
||||
|
||||
Status PjRtBuffer::CopyToRemoteDevice(absl::string_view serialized_descriptor) {
|
||||
return client_->CopyToRemoteDevice(this, serialized_descriptor);
|
||||
return tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
|
||||
->CopyToRemoteDevice(this, serialized_descriptor);
|
||||
}
|
||||
|
||||
Status PjRtBuffer::BlockHostUntilReady() {
|
||||
@ -1332,7 +1371,9 @@ Status PjRtBuffer::BlockHostUntilReady() {
|
||||
}
|
||||
device_buffer = device_buffer_;
|
||||
}
|
||||
LocalDeviceState* local_device_state = device_->local_device_state();
|
||||
LocalDeviceState* local_device_state =
|
||||
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
|
||||
->local_device_state();
|
||||
std::unique_ptr<se::Stream> stream;
|
||||
for (auto& event : device_buffer->definition_events()) {
|
||||
if (!event->IsComplete()) {
|
||||
@ -1378,9 +1419,13 @@ StatusOr<TupleHandle> MakeTupleHelper(
|
||||
Shape on_host_shape = ShapeUtil::MakeTupleShape(host_shapes);
|
||||
Shape on_device_shape = ShapeUtil::MakeTupleShape(device_shapes);
|
||||
|
||||
se::DeviceMemoryAllocator* allocator = client->allocator();
|
||||
se::DeviceMemoryAllocator* allocator =
|
||||
tensorflow::down_cast<PjRtStreamExecutorClient*>(client)->allocator();
|
||||
TransferManager* transfer_manager =
|
||||
client->client()->backend().transfer_manager();
|
||||
tensorflow::down_cast<PjRtStreamExecutorClient*>(client)
|
||||
->client()
|
||||
->backend()
|
||||
.transfer_manager();
|
||||
se::Stream* stream = local_device->host_to_device_stream();
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
se::OwningDeviceMemory root_table_memory,
|
||||
@ -1444,14 +1489,6 @@ std::unique_ptr<PjRtBuffer> OutputBufferHelper(
|
||||
/*prefer_to_retain_reference=*/false);
|
||||
return pjrt_buffer;
|
||||
}
|
||||
|
||||
static PjRtDevice* LookupDevice(const PjRtClient& client, int device_id) {
|
||||
auto it = client.id_to_device().find(device_id);
|
||||
CHECK(it != client.id_to_device().end())
|
||||
<< "Unknown device id: " << device_id;
|
||||
return it->second;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
|
||||
@ -1459,7 +1496,8 @@ PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
|
||||
bool parameter_is_tupled_arguments,
|
||||
std::shared_ptr<DeviceAssignment> device_assignment,
|
||||
std::vector<LogicalDeviceIds> addressable_device_logical_ids,
|
||||
std::vector<PjRtDevice*> addressable_devices, PjRtClient* client)
|
||||
std::vector<PjRtDevice*> addressable_devices,
|
||||
PjRtStreamExecutorClient* client)
|
||||
: client_(client),
|
||||
device_assignment_(std::move(device_assignment)),
|
||||
parameter_is_tupled_arguments_(parameter_is_tupled_arguments),
|
||||
@ -1482,7 +1520,7 @@ PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
|
||||
VLOG(1) << "PjRtStreamExecutorExecutable device_assignment:\n"
|
||||
<< device_assignment_->ToString();
|
||||
CHECK_GE(addressable_devices_.size(), 1) << device_assignment_->ToString();
|
||||
CHECK_LE(addressable_devices_.size(), client_->local_device_count())
|
||||
CHECK_LE(addressable_devices_.size(), client_->addressable_device_count())
|
||||
<< "Inconsistent local device count.";
|
||||
num_partitions = device_assignment_->computation_count();
|
||||
}
|
||||
@ -1584,7 +1622,7 @@ PjRtStreamExecutorExecutable::MakeExecutionInputsAndWaitForEvents(
|
||||
absl::Span<const PjRtBuffer::ScopedHold> device_buffers,
|
||||
absl::flat_hash_set<BufferSequencingEvent*>& events) const {
|
||||
std::vector<ExecutionInput> execution_inputs;
|
||||
LocalDeviceState* device_state = &client_->device_state(device_ordinal);
|
||||
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
|
||||
// Lift tuple_handle outside the conditional so that the event it returns is
|
||||
// not destroyed until after the loop below that waits on events.
|
||||
absl::optional<TupleHandle> tuple_handle;
|
||||
@ -1607,8 +1645,10 @@ PjRtStreamExecutorExecutable::MakeExecutionInputsAndWaitForEvents(
|
||||
execution_input.MutableBuffers()->begin();
|
||||
ShapeTree<MaybeOwningDeviceMemory>::iterator iterator_end =
|
||||
execution_input.MutableBuffers()->end();
|
||||
device_buffers[i].AddToInput(&input_iterator, iterator_end,
|
||||
&execution_input, client_->allocator());
|
||||
device_buffers[i].AddToInput(
|
||||
&input_iterator, iterator_end, &execution_input,
|
||||
tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
|
||||
->allocator());
|
||||
CHECK(input_iterator == iterator_end);
|
||||
}
|
||||
}
|
||||
@ -1628,8 +1668,10 @@ StatusOr<ScopedShapedBuffer> PjRtStreamExecutorExecutable::EnqueueExecution(
|
||||
int executable_idx, const RunId& run_id, const ExecuteOptions& options,
|
||||
PjRtDevice* device, std::vector<PjRtBuffer::ScopedHold>* device_buffers,
|
||||
std::shared_ptr<DeviceAssignment> device_assignment) const {
|
||||
int device_ordinal = device->local_device_state()->device_ordinal();
|
||||
LocalDeviceState* device_state = &client_->device_state(device_ordinal);
|
||||
int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
|
||||
->local_device_state()
|
||||
->device_ordinal();
|
||||
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
|
||||
tensorflow::profiler::TraceMeConsumer activity(
|
||||
"LocalExecutable::Execute", tensorflow::profiler::ContextType::kPjRt,
|
||||
run_id.ToInt());
|
||||
@ -1765,7 +1807,7 @@ PjRtStreamExecutorExecutable::MakeOutputBuffers(
|
||||
std::shared_ptr<BufferSequencingEvent> definition_event,
|
||||
PjRtDevice* device) const {
|
||||
std::vector<std::unique_ptr<PjRtBuffer>> outputs;
|
||||
LocalDeviceState* device_state = &client_->device_state(device_ordinal);
|
||||
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
|
||||
if (options.untuple_result && result_buffer.on_host_shape().IsTuple()) {
|
||||
int tuple_count = result_buffer.on_host_shape().tuple_shapes_size();
|
||||
outputs.reserve(tuple_count);
|
||||
@ -1802,7 +1844,7 @@ PjRtStreamExecutorExecutable::ExecuteHelper(
|
||||
if (device == nullptr) {
|
||||
CHECK(device_assignment_ != nullptr);
|
||||
const int device_id = (*device_assignment_)(replica, partition);
|
||||
device = LookupDevice(*client_, device_id);
|
||||
TF_ASSIGN_OR_RETURN(device, client_->LookupDevice(device_id));
|
||||
device_assignment = device_assignment_;
|
||||
} else {
|
||||
CHECK(device_assignment_ == nullptr);
|
||||
@ -1814,7 +1856,9 @@ PjRtStreamExecutorExecutable::ExecuteHelper(
|
||||
}
|
||||
|
||||
CHECK_EQ(device->host_id(), client_->host_id());
|
||||
int device_ordinal = device->local_device_state()->device_ordinal();
|
||||
int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
|
||||
->local_device_state()
|
||||
->device_ordinal();
|
||||
tensorflow::profiler::TraceMe traceme("LocalExecutable::Execute");
|
||||
VLOG(3) << "Replica " << replica << ", partition " << partition
|
||||
<< " mapped to device ordinal for execution: " << device_ordinal;
|
||||
@ -1836,7 +1880,7 @@ PjRtStreamExecutorExecutable::ExecuteHelper(
|
||||
ScopedShapedBuffer result_buffer =
|
||||
result_buffer_or_status.ConsumeValueOrDie();
|
||||
|
||||
LocalDeviceState* device_state = &client_->device_state(device_ordinal);
|
||||
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
|
||||
se::Stream* stream = device_state->compute_stream();
|
||||
StatusOr<EventPool::Handle> event_or =
|
||||
device_state->event_pool().ThenAllocateAndRecordEvent(stream);
|
||||
@ -1922,7 +1966,9 @@ PjRtStreamExecutorExecutable::Execute(
|
||||
const int replica = addressable_device_logical_ids_[i].replica;
|
||||
const int partition = addressable_device_logical_ids_[i].partition;
|
||||
PjRtDevice* device = addressable_devices_[i];
|
||||
const LocalDeviceState& device_state = *device->local_device_state();
|
||||
const LocalDeviceState& device_state =
|
||||
*tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
|
||||
->local_device_state();
|
||||
device_state.execute_thread()->Schedule([&, replica, partition, i] {
|
||||
results[i] = ExecuteHelper(argument_handles[i], replica, partition,
|
||||
run_id, options);
|
||||
@ -2131,9 +2177,9 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
|
||||
|
||||
} // namespace
|
||||
|
||||
StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
|
||||
StatusOr<std::unique_ptr<PjRtExecutable>> PjRtStreamExecutorClient::Compile(
|
||||
const XlaComputation& computation, CompileOptions options) {
|
||||
tensorflow::profiler::TraceMe traceme("PjRtClient::Compile");
|
||||
tensorflow::profiler::TraceMe traceme("PjRtStreamExecutorClient::Compile");
|
||||
|
||||
ExecutableBuildOptions& build_options = options.executable_build_options;
|
||||
if (!build_options.device_allocator()) {
|
||||
@ -2153,14 +2199,15 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
|
||||
num_partitions = 1;
|
||||
} else {
|
||||
if (!build_options.has_device_assignment()) {
|
||||
VLOG(2) << "PjRtClient::Compile using default device_assignment.";
|
||||
VLOG(2) << "PjRtStreamExecutorClient::Compile using default "
|
||||
"device_assignment.";
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
DeviceAssignment device_assignment,
|
||||
GetDefaultDeviceAssignment(build_options.num_replicas(),
|
||||
build_options.num_partitions()));
|
||||
build_options.set_device_assignment(device_assignment);
|
||||
}
|
||||
VLOG(2) << "PjRtClient::Compile device_assignment:\n"
|
||||
VLOG(2) << "PjRtStreamExecutorClient::Compile device_assignment:\n"
|
||||
<< build_options.device_assignment().ToString();
|
||||
num_replicas = build_options.device_assignment().replica_count();
|
||||
num_partitions = build_options.device_assignment().computation_count();
|
||||
@ -2234,7 +2281,7 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
|
||||
for (int replica = 0; replica < num_replicas; ++replica) {
|
||||
for (int partition = 0; partition < num_partitions; ++partition) {
|
||||
int device_id = (*device_assignment)(replica, partition);
|
||||
PjRtDevice* device = LookupDevice(*this, device_id);
|
||||
TF_ASSIGN_OR_RETURN(PjRtDevice * device, LookupDevice(device_id));
|
||||
if (device->host_id() != host_id()) {
|
||||
VLOG(3) << "Non-local device: " << device_id;
|
||||
continue;
|
||||
@ -2254,7 +2301,7 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
|
||||
|
||||
if (build_options.device_ordinal() < 0) {
|
||||
build_options.set_device_ordinal(
|
||||
addressable_devices.front()->local_device_state()->device_ordinal());
|
||||
addressable_devices.front()->local_hardware_id());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -45,6 +45,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/fingerprint.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
@ -67,16 +68,56 @@ class PjRtClient;
|
||||
|
||||
class PjRtDevice {
|
||||
public:
|
||||
explicit PjRtDevice(int id,
|
||||
std::unique_ptr<LocalDeviceState> local_device_state,
|
||||
std::string device_kind, int host_id = 0)
|
||||
virtual ~PjRtDevice() {}
|
||||
|
||||
// Return the client that owns this device.
|
||||
virtual PjRtClient* client() const = 0;
|
||||
|
||||
// Whether client can issue command to this device.
|
||||
virtual bool IsAddressable() const = 0;
|
||||
|
||||
// The ID of this device. IDs are unique among devices of this type
|
||||
// (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all
|
||||
// hosts' devices. This is the ID that should be used in a DeviceAssignment.
|
||||
virtual int id() const = 0;
|
||||
|
||||
// The task ID of this device according to TpuTopology. This is not the same
|
||||
// as PjRtClient::host_id() in a multi-task setting, where each client can see
|
||||
// devices from all tasks, but only a subset of them are addressable and have
|
||||
// the same task_id as the client.
|
||||
virtual int host_id() const = 0;
|
||||
|
||||
// Opaque hardware ID, e.g., the CUDA device number, useful for identifying
|
||||
// which GPU when interacting with non-JAX code. In general, not guaranteed to
|
||||
// be dense, and -1 if undefined.
|
||||
virtual int local_hardware_id() const = 0;
|
||||
|
||||
// A vendor-dependent string that uniquely identifies the kind of device,
|
||||
// e.g., "Tesla V100-SXM2-16GB". May be used to determine whether two GPUs are
|
||||
// compatible compilation targets.
|
||||
virtual const std::string& device_kind() const = 0;
|
||||
|
||||
virtual std::string DebugString() const = 0;
|
||||
|
||||
// Transfer the given literal to the infeed queue.
|
||||
virtual Status TransferToInfeed(const LiteralSlice& literal) const = 0;
|
||||
|
||||
// Transfer and return a value of the given shape from the outfeed queue.
|
||||
virtual StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const = 0;
|
||||
};
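Because the abstract base no longer exposes LocalDeviceState, implementation code in this change reaches it by downcasting, as seen in the pjrt_client.cc hunks earlier. A condensed sketch of that pattern (the helper name is invented):

#include "tensorflow/core/platform/casts.h"

// Stream-executor-specific state is reached by downcasting from the abstract
// device; this mirrors the tensorflow::down_cast calls introduced above.
inline xla::StatusOr<xla::LocalDeviceState*> GetLocalDeviceState(
    xla::PjRtDevice* device) {
  return tensorflow::down_cast<xla::PjRtStreamExecutorDevice*>(device)
      ->GetLocalDeviceState();
}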
|
||||
|
||||
class PjRtStreamExecutorDevice : public PjRtDevice {
|
||||
public:
|
||||
explicit PjRtStreamExecutorDevice(
|
||||
int id, std::unique_ptr<LocalDeviceState> local_device_state,
|
||||
std::string device_kind, int host_id = 0)
|
||||
: id_(id),
|
||||
local_device_id_(
|
||||
device_ordinal_(
|
||||
local_device_state ? local_device_state->device_ordinal() : -1),
|
||||
local_device_state_(std::move(local_device_state)),
|
||||
host_id_(host_id),
|
||||
device_kind_(std::move(device_kind)) {}
|
||||
virtual ~PjRtDevice() {}
|
||||
~PjRtStreamExecutorDevice() override {}
|
||||
|
||||
// Must set client exactly once.
|
||||
void SetClient(PjRtClient* client) {
|
||||
@ -84,14 +125,25 @@ class PjRtDevice {
|
||||
client_ = client;
|
||||
}
|
||||
|
||||
// Task ID. This is always 0 in a single-task setup.
|
||||
int host_id() const override { return host_id_; }
|
||||
|
||||
// Return `platform_id` from client.
|
||||
PjRtPlatformId platform_id() const;
|
||||
|
||||
// Return `platform_name` from client.
|
||||
const std::string& platform_name() const;
|
||||
|
||||
PjRtClient* client() const override { return client_; }
|
||||
|
||||
// The ID of this device. IDs are unique among devices of this type
|
||||
// (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all
|
||||
// hosts' devices. This is the ID that should be used in a DeviceAssignment.
|
||||
int id() const { return id_; }
|
||||
int id() const override { return id_; }
|
||||
|
||||
bool IsLocalDevice() const { return local_device_id_ != -1; }
|
||||
bool IsAddressable() const override { return device_ordinal_ != -1; }
|
||||
|
||||
int local_device_id() const { return local_device_id_; }
|
||||
int local_hardware_id() const override { return device_ordinal_; }
|
||||
|
||||
// If this is a device local to this host, returns a LocalDeviceState object
|
||||
// that can be used to manipulate the device. Returns nullptr if the device is
|
||||
@ -105,32 +157,21 @@ class PjRtDevice {
|
||||
// is not local to this host.
|
||||
StatusOr<LocalDeviceState*> GetLocalDeviceState() const;
|
||||
|
||||
// The ID of this device's host. This is always 0 on single-host platforms.
|
||||
int host_id() const { return host_id_; }
|
||||
|
||||
// Return `platform_id` from client.
|
||||
PjRtPlatformId platform_id() const;
|
||||
|
||||
// Return `platform_name` from client.
|
||||
const std::string& platform_name() const;
|
||||
|
||||
// A vendor-dependent string that uniquely identifies the kind of device.
|
||||
const std::string& device_kind() const { return device_kind_; }
|
||||
const std::string& device_kind() const override { return device_kind_; }
|
||||
|
||||
virtual std::string DebugString() const;
|
||||
|
||||
PjRtClient* client() const { return client_; }
|
||||
std::string DebugString() const override;
|
||||
|
||||
// Transfer the given literal to the infeed queue of the given localdevice.
|
||||
virtual Status TransferToInfeed(const LiteralSlice& literal) const;
|
||||
Status TransferToInfeed(const LiteralSlice& literal) const override;
|
||||
|
||||
// Transfer and return a value of the given shape from the outfeed of the
|
||||
// given device.
|
||||
virtual StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const;
|
||||
StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const override;
|
||||
|
||||
private:
|
||||
const int id_;
|
||||
const int local_device_id_; // -1 means not local.
|
||||
const int device_ordinal_; // -1 means not local.
|
||||
const std::unique_ptr<LocalDeviceState> local_device_state_;
|
||||
const int host_id_;
|
||||
const std::string device_kind_;
|
||||
@ -178,86 +219,62 @@ class PjRtExecutable;
|
||||
// alive as long as any of the other runtime objects are alive.
class PjRtClient {
public:
// `allocator` may be null, in which case the platform default allocator is used.
explicit PjRtClient(
std::string platform_name, LocalClient* client,
std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id,
std::unique_ptr<se::DeviceMemoryAllocator> allocator,
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
bool should_stage_host_to_device_transfers,
std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options);
virtual ~PjRtClient() = default;

// TODO(zhangqiaorjc): Rename to task_id.
// Return the task id of this client. In single-task setting, always 0.
virtual int host_id() const = 0;

// Return the number of devices in the entire computation. In multi-headed
// client setting, some are addressable by this client, some are not. In a
// single-client setting, this is equal to the number of addressable devices.
virtual int device_count() const = 0;

// Return number of addressable devices. Addressable devices are those that
// the client can issue commands to.
virtual int addressable_device_count() const = 0;

// Return all devices in the entire computation, including addressable and
// non-addressable devices.
virtual absl::Span<PjRtDevice* const> devices() const = 0;

// TODO(zhangqiaorjc): Rename to addressable_devices.
// Return only addressable devices.
virtual absl::Span<PjRtDevice* const> local_devices() const = 0;

// Lookup any PjRtDevice for a given PjRtDevice::id().
virtual StatusOr<PjRtDevice*> LookupDevice(int device_id) const = 0;

// Return an addressable PjRtDevice for a given
// PjRtDevice::local_hardware_id().
virtual StatusOr<PjRtDevice*> LookupAddressableDevice(
int local_hardware_id) const = 0;

// Return an ID that identifies the platform (CPU/GPU/TPU).
virtual PjRtPlatformId platform_id() const = 0;

// Returns a string that identifies the platform (CPU/GPU/TPU).
virtual const std::string& platform_name() const = 0;

// Return a device-specific default device assignment, e.g., GPU and TPU may
// be different.
virtual StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const;

int device_count() const { return devices_.size(); }
int local_device_count() const { return local_devices_.size(); }
const std::vector<std::unique_ptr<PjRtDevice>>& devices() const {
return devices_;
}
const std::vector<PjRtDevice*>& local_devices() const {
return local_devices_;
}
const std::map<int, PjRtDevice*>& id_to_device() const {
return id_to_device_;
}
int host_id() const { return host_id_; }
PjRtPlatformId platform_id() const { return platform_id_; }
const std::string& platform_name() const { return platform_name_; }

LocalDeviceState& device_state(int device_ordinal) const {
return *local_devices_.at(device_ordinal)->local_device_state();
}

// Return a local PjRtDevice for a given `local_device_id`.
virtual StatusOr<PjRtDevice*> LookupLocalDevice(int local_device_id) const;

LocalClient* client() const { return client_; }
se::DeviceMemoryAllocator* allocator() const { return allocator_; }
tensorflow::Allocator* host_memory_allocator() const {
return host_memory_allocator_.get();
}
bool should_stage_host_to_device_transfers() const {
return should_stage_host_to_device_transfers_;
}

gpu::GpuExecutableRunOptions* gpu_run_options() const {
return gpu_run_options_.get();
}

tensorflow::thread::ThreadPool* h2d_transfer_pool() {
return &h2d_transfer_pool_;
}

// Most platforms expect device-to-device transfers to be enqueued on the
// source d2d stream, but some platforms use the destination d2d stream. This
// function specifies which one the platform expects.
virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }

// Generates a unique fingerprint for `executable`.
virtual StatusOr<absl::optional<std::string>> ExecutableFingerprint(
const PjRtExecutable& executable) const {
return absl::optional<std::string>();
}
int num_replicas, int num_partitions) const = 0;

// Returns a backend-specific HLO cost analysis visitor.
virtual std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis();
virtual std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis() = 0;

// Compile `computation` with given `options`.
virtual StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
const XlaComputation& computation, CompileOptions options);
const XlaComputation& computation, CompileOptions options) = 0;

// Generates a unique fingerprint for `executable`, may be absl::nullopt.
virtual StatusOr<absl::optional<std::string>> ExecutableFingerprint(
const PjRtExecutable& executable) const = 0;

// Creates a buffer on the device without initializing or copying any data.
// An optional `definition_event` may be specified that can be used to
// ensure the buffer isn't referenced until some external mechanism has
// initialized the data.
// NOTE: The sequencing mechanism is not guaranteed to be supported by all
// future backends and so callers should avoid it wherever possible.
virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
const Shape& shape, PjRtDevice* device);
virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
const Shape& shape, PjRtDevice* device,
std::shared_ptr<BufferSequencingEvent> definition_event);
const Shape& shape, PjRtDevice* device) = 0;

// Describes the semantics the caller of BufferFromHostBuffer expects from the
// runtime, in a total order from most restrictive to least restrictive.
@ -289,13 +306,13 @@ class PjRtClient {
virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
const void* data, const Shape& shape,
HostBufferSemantics host_buffer_semantics,
std::shared_ptr<void> buffer_reference, PjRtDevice* device);
std::shared_ptr<void> buffer_reference, PjRtDevice* device) = 0;

// Note that literal must remain in scope until the transfer has completed, so
// the caller should, for example, wait for BlockHostUntilReady() to complete on
// the return value before letting literal go out of scope.
virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
const LiteralSlice& literal, PjRtDevice* device);
const LiteralSlice& literal, PjRtDevice* device) = 0;

// Asynchronously makes a vector of PjRtBuffers that can be used to receive
// cross host transfers using `client` on `device`. `shapes` must be the exact
@ -308,18 +325,140 @@ class PjRtClient {
// buffers will become ready until *all* of the sends have completed.
virtual void MakeCrossHostReceiveBuffers(
absl::Span<const Shape> shapes, PjRtDevice* device,
PjRtCrossHostRecvNotifier&& notifier);
PjRtCrossHostRecvNotifier&& notifier) = 0;

virtual StatusOr<ChannelHandle> CreateChannelHandle() {
// Create ChannelHandles for XLA send/recv.
virtual StatusOr<ChannelHandle> CreateChannelHandle() = 0;
virtual StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() = 0;
virtual StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() = 0;
};

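A minimal usage sketch of the abstract API above, assuming an already-constructed client; only methods declared in this header are used, and the logging and error handling are illustrative, not part of this change:

#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/core/platform/logging.h"

// Sketch only: describe every device visible to `client` and push a literal
// to the infeed of the first addressable device.
xla::Status DescribeAndFeed(xla::PjRtClient* client,
                            const xla::LiteralSlice& literal) {
  // devices() covers the whole computation, addressable or not.
  for (xla::PjRtDevice* device : client->devices()) {
    LOG(INFO) << device->DebugString() << " kind=" << device->device_kind();
  }
  // local_devices() are the devices this client can issue commands to.
  if (client->local_devices().empty()) {
    return xla::InvalidArgument("client %s has no addressable devices",
                                client->platform_name());
  }
  return client->local_devices().front()->TransferToInfeed(literal);
}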
class PjRtStreamExecutorClient : public PjRtClient {
|
||||
public:
|
||||
// `allocator` may be null, in which case the platform default allocator is used.
|
||||
explicit PjRtStreamExecutorClient(
|
||||
std::string platform_name, LocalClient* client,
|
||||
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
|
||||
int host_id, std::unique_ptr<se::DeviceMemoryAllocator> allocator,
|
||||
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
|
||||
bool should_stage_host_to_device_transfers,
|
||||
std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options);
|
||||
~PjRtStreamExecutorClient() override = default;
|
||||
|
||||
int host_id() const override { return host_id_; }
|
||||
|
||||
int device_count() const override { return devices_.size(); }
|
||||
int addressable_device_count() const override {
|
||||
return local_devices_.size();
|
||||
}
|
||||
absl::Span<PjRtDevice* const> devices() const override { return devices_; }
|
||||
absl::Span<PjRtDevice* const> local_devices() const override {
|
||||
return local_devices_;
|
||||
}
|
||||
|
||||
StatusOr<PjRtDevice*> LookupDevice(int device_id) const override {
|
||||
auto it = id_to_device_.find(device_id);
|
||||
if (it != id_to_device_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
return InvalidArgument("No matching device found for device_id %d",
|
||||
device_id);
|
||||
}
|
||||
|
||||
StatusOr<PjRtDevice*> LookupAddressableDevice(
|
||||
int local_hardware_id) const override;
|
||||
|
||||
PjRtPlatformId platform_id() const override { return platform_id_; }
|
||||
const std::string& platform_name() const override { return platform_name_; }
|
||||
|
||||
// Most platforms expect device-to-device transfers to be enqueued on the
|
||||
// source d2d stream, but some platforms use the destination d2d stream. This
|
||||
// function specifies which one the platform expects.
|
||||
virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }
|
||||
|
||||
StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
|
||||
int num_replicas, int num_partitions) const override;
|
||||
|
||||
StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
|
||||
const XlaComputation& computation, CompileOptions options) override;
|
||||
|
||||
// Generates a unique fingerprint for `executable`.
|
||||
StatusOr<absl::optional<std::string>> ExecutableFingerprint(
|
||||
const PjRtExecutable& executable) const override {
|
||||
return absl::optional<std::string>();
|
||||
}
|
||||
|
||||
// Returns a backend-specific HLO cost analysis visitor.
|
||||
std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis() override;
|
||||
|
||||
// Creates a buffer on the device without initializing or copying any data.
|
||||
// An optional `definition_event` may be specified that can be used to
|
||||
// ensure the buffer isn't referenced until some external mechanism has
|
||||
// initialized the data.
|
||||
// NOTE: The sequencing mechanism is not guaranteed to be supported by all
|
||||
// future backends and so callers should avoid it wherever possible.
|
||||
StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
|
||||
const Shape& shape, PjRtDevice* device) override;
|
||||
virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
|
||||
const Shape& shape, PjRtDevice* device,
|
||||
std::shared_ptr<BufferSequencingEvent> definition_event);
|
||||
|
||||
StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
|
||||
const void* data, const Shape& shape,
|
||||
HostBufferSemantics host_buffer_semantics,
|
||||
std::shared_ptr<void> buffer_reference, PjRtDevice* device) override;
|
||||
|
||||
// Note that literal must remain in scope until the transfer has completed, so
|
||||
// the caller should, for example, wait for BlockHostUntilReady() to complete on
|
||||
// the return value before letting literal go out of scope.
|
||||
StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
|
||||
const LiteralSlice& literal, PjRtDevice* device) override;
|
||||
|
||||
// Asynchronously makes a vector of PjRtBuffers that can be used to receive
|
||||
// cross host transfers using `client` on `device`. `shapes` must be the exact
|
||||
// shapes, with identical layouts, corresponding to the buffers that will be
|
||||
// sent. When resources for the transfer are available, notifier will be
|
||||
// called with a vector of PjRtCrossHostRecvBuffer structs, one for each
|
||||
// shape in `shapes`. Each struct contains a buffer that will contain the
|
||||
// received value, and an opaque string that should be transmitted to the
|
||||
// sending host and used in a call to CopyToRemoteDevice. None of the recv
|
||||
// buffers will become ready until *all* of the sends have completed.
|
||||
void MakeCrossHostReceiveBuffers(
|
||||
absl::Span<const Shape> shapes, PjRtDevice* device,
|
||||
PjRtCrossHostRecvNotifier&& notifier) override;
|
||||
|
||||
StatusOr<ChannelHandle> CreateChannelHandle() override {
|
||||
return client()->CreateChannelHandle();
|
||||
}
|
||||
virtual StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() {
|
||||
StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() override {
|
||||
return client()->CreateDeviceToHostChannelHandle();
|
||||
}
|
||||
virtual StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() {
|
||||
StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() override {
|
||||
return client()->CreateHostToDeviceChannelHandle();
|
||||
}
|
||||
|
||||
LocalDeviceState& device_state(int device_ordinal) const {
|
||||
return *tensorflow::down_cast<PjRtStreamExecutorDevice*>(
|
||||
local_devices_.at(device_ordinal))
|
||||
->local_device_state();
|
||||
}
|
||||
LocalClient* client() const { return client_; }
|
||||
se::DeviceMemoryAllocator* allocator() const { return allocator_; }
|
||||
tensorflow::Allocator* host_memory_allocator() const {
|
||||
return host_memory_allocator_.get();
|
||||
}
|
||||
bool should_stage_host_to_device_transfers() const {
|
||||
return should_stage_host_to_device_transfers_;
|
||||
}
|
||||
|
||||
gpu::GpuExecutableRunOptions* gpu_run_options() const {
|
||||
return gpu_run_options_.get();
|
||||
}
|
||||
|
||||
tensorflow::thread::ThreadPool* h2d_transfer_pool() {
|
||||
return &h2d_transfer_pool_;
|
||||
}
|
||||
|
||||
protected:
|
||||
friend class PjRtBuffer;
|
||||
virtual void EnqueueCrossHostReceive(
|
||||
@ -342,7 +481,9 @@ class PjRtClient {
|
||||
std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;
|
||||
|
||||
// Includes all devices, including non-local devices on multi-host platforms.
|
||||
std::vector<std::unique_ptr<PjRtDevice>> devices_;
|
||||
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> owned_devices_;
|
||||
// Pointers to `owned_devices_`.
|
||||
std::vector<PjRtDevice*> devices_;
|
||||
// Maps Device::id() to the corresponding Device. Includes all devices.
|
||||
std::map<int, PjRtDevice*> id_to_device_;
|
||||
// Local devices indexed by local device ordinal.
|
||||
@ -509,7 +650,7 @@ class PjRtBuffer {
|
||||
|
||||
private:
|
||||
friend class PjRtBuffer;
|
||||
friend class PjRtClient;
|
||||
friend class PjRtStreamExecutorClient;
|
||||
|
||||
// Helper struct that makes it possible to move a ScopedHold through a
|
||||
// closure.
|
||||
@ -769,7 +910,7 @@ class PjRtExecutable {
|
||||
virtual PjRtClient* client() const = 0;
|
||||
|
||||
// Unique name for this executable, e.g., HloModule name.
|
||||
virtual const string& name() const = 0;
|
||||
virtual const std::string& name() const = 0;
|
||||
|
||||
virtual int num_replicas() const = 0;
|
||||
|
||||
@ -791,6 +932,7 @@ class PjRtExecutable {
|
||||
virtual absl::Span<const LogicalDeviceIds> addressable_device_logical_ids()
|
||||
const = 0;
|
||||
|
||||
// An addressable_device is one which the client can issue commands to.
|
||||
// addressable_devices()[i] is the Device to which
|
||||
// addressable_device_logical_ids()[i] is assigned.
|
||||
virtual absl::Span<PjRtDevice* const> addressable_devices() const = 0;
|
||||
@ -833,13 +975,14 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable {
|
||||
bool parameter_is_tupled_arguments,
|
||||
std::shared_ptr<DeviceAssignment> device_assignment,
|
||||
std::vector<LogicalDeviceIds> addressable_device_logical_ids,
|
||||
std::vector<PjRtDevice*> addressable_devices, PjRtClient* client);
|
||||
std::vector<PjRtDevice*> addressable_devices,
|
||||
PjRtStreamExecutorClient* client);
|
||||
|
||||
~PjRtStreamExecutorExecutable() override = default;
|
||||
|
||||
PjRtClient* client() const override { return client_; }
|
||||
PjRtStreamExecutorClient* client() const override { return client_; }
|
||||
|
||||
const string& name() const override;
|
||||
const std::string& name() const override;
|
||||
|
||||
int num_replicas() const override {
|
||||
return executables_[0]->build_options().num_replicas();
|
||||
@ -898,7 +1041,7 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable {
|
||||
}
|
||||
|
||||
private:
|
||||
friend class PjRtClient;
|
||||
friend class PjRtStreamExecutorClient;
|
||||
// Initializes information about which arguments to which executables must be
|
||||
// donated due to aliases that were specified by the computation.
|
||||
Status SetUpDonation(bool tuple_inputs);
|
||||
@ -933,7 +1076,7 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable {
|
||||
// Create shared pointers so we can free them after the execution: with
|
||||
// asynchronous execution, the process being executed can outlive the
|
||||
// executable itself.
|
||||
PjRtClient* const client_;
|
||||
PjRtStreamExecutorClient* const client_;
|
||||
// One executable per partition.
|
||||
std::vector<std::shared_ptr<LocalExecutable>> executables_;
|
||||
// Per-executable set of parameters that have any aliased buffers and thus
|
||||
|
@ -94,10 +94,11 @@ Status TpuDeviceState::ThenMemcpyDeviceToDevice(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
class PjRtTpuClient : public PjRtClient {
|
||||
class PjRtTpuClient : public PjRtStreamExecutorClient {
|
||||
public:
|
||||
PjRtTpuClient(LocalClient* client,
|
||||
std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id);
|
||||
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
|
||||
int host_id);
|
||||
|
||||
StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
|
||||
int num_replicas, int num_partitions) const override;
|
||||
@ -108,14 +109,14 @@ class PjRtTpuClient : public PjRtClient {
|
||||
const PjRtExecutable& executable) const override;
|
||||
};
|
||||
|
||||
PjRtTpuClient::PjRtTpuClient(LocalClient* client,
|
||||
std::vector<std::unique_ptr<PjRtDevice>> devices,
|
||||
int host_id)
|
||||
: PjRtClient(kTpuName, client, std::move(devices), host_id,
|
||||
/*allocator=*/nullptr,
|
||||
/*host_memory_allocator=*/nullptr,
|
||||
/*should_stage_host_to_device_transfers=*/false,
|
||||
/*gpu_run_options=*/nullptr) {}
|
||||
PjRtTpuClient::PjRtTpuClient(
|
||||
LocalClient* client,
|
||||
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int host_id)
|
||||
: PjRtStreamExecutorClient(kTpuName, client, std::move(devices), host_id,
|
||||
/*allocator=*/nullptr,
|
||||
/*host_memory_allocator=*/nullptr,
|
||||
/*should_stage_host_to_device_transfers=*/false,
|
||||
/*gpu_run_options=*/nullptr) {}
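The TPU client above is the first user of this subclassing pattern; a hypothetical client for another StreamExecutor-backed platform (names below are invented for illustration) would forward to the same base constructor:

#include <memory>
#include <utility>
#include <vector>
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"

// Hypothetical sketch, not part of this change.
class PjRtMyAcceleratorClient : public xla::PjRtStreamExecutorClient {
 public:
  PjRtMyAcceleratorClient(
      xla::LocalClient* client,
      std::vector<std::unique_ptr<xla::PjRtStreamExecutorDevice>> devices,
      int host_id)
      : xla::PjRtStreamExecutorClient(
            /*platform_name=*/"my_accelerator", client, std::move(devices),
            host_id, /*allocator=*/nullptr,
            /*host_memory_allocator=*/nullptr,
            /*should_stage_host_to_device_transfers=*/false,
            /*gpu_run_options=*/nullptr) {}
};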
|
||||
|
||||
StatusOr<DeviceAssignment> PjRtTpuClient::GetDefaultDeviceAssignment(
|
||||
int num_replicas, int num_partitions) const {
|
||||
@ -128,7 +129,8 @@ StatusOr<DeviceAssignment> PjRtTpuClient::GetDefaultDeviceAssignment(
|
||||
num_partitions);
|
||||
}
|
||||
// Fallback to default global device assignment if we can't run locally.
|
||||
return PjRtClient::GetDefaultDeviceAssignment(num_replicas, num_partitions);
|
||||
return PjRtStreamExecutorClient::GetDefaultDeviceAssignment(num_replicas,
|
||||
num_partitions);
|
||||
}
|
||||
|
||||
StatusOr<absl::optional<std::string>> PjRtTpuClient::ExecutableFingerprint(
|
||||
@ -152,10 +154,10 @@ StatusOr<absl::optional<std::string>> PjRtTpuClient::ExecutableFingerprint(
|
||||
return absl::optional<std::string>(tpu_executable->fingerprint());
|
||||
}
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<PjRtDevice>>> GetTpuDevices(
|
||||
StatusOr<std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>> GetTpuDevices(
|
||||
LocalClient* client,
|
||||
std::vector<std::unique_ptr<LocalDeviceState>> local_device_states) {
|
||||
std::vector<std::unique_ptr<PjRtDevice>> devices;
|
||||
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
|
||||
tf_tpu::TpuTopologyExternal topology =
|
||||
tf_tpu::TpuPlatformInterface::GetRegisteredPlatform()->topology();
|
||||
|
||||
|
@ -26,14 +26,14 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
class PjRtTpuDevice : public PjRtDevice {
|
||||
class PjRtTpuDevice : public PjRtStreamExecutorDevice {
|
||||
public:
|
||||
PjRtTpuDevice(const tensorflow::tpu::TpuCoreLocationExternal core,
|
||||
std::unique_ptr<LocalDeviceState> local_device_state,
|
||||
int host_id, const std::array<int, 3>& coords,
|
||||
std::string device_kind)
|
||||
: PjRtDevice(core.Id(), std::move(local_device_state),
|
||||
std::move(device_kind), host_id),
|
||||
: PjRtStreamExecutorDevice(core.Id(), std::move(local_device_state),
|
||||
std::move(device_kind), host_id),
|
||||
core_(core),
|
||||
coords_(coords) {}
|
||||
|
||||
|
@ -97,13 +97,13 @@ cc_library(
|
||||
name = "types",
|
||||
srcs = ["types.cc"],
|
||||
hdrs = ["types.h"],
|
||||
compatible_with = [],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
deps = [
|
||||
":bfloat16",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
@ -113,6 +113,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/pjrt:pjrt_client",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/python:bfloat16_lib",
|
||||
"//third_party/py/numpy:headers",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
@ -158,42 +159,6 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "bfloat16",
|
||||
srcs = ["bfloat16.cc"],
|
||||
hdrs = ["bfloat16.h"],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/core/platform:bfloat16",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"//third_party/py/numpy:headers",
|
||||
"//third_party/python_runtime:headers", # buildcleaner: keep
|
||||
"@com_google_absl//absl/strings",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "bfloat16_test",
|
||||
srcs = ["bfloat16_test.py"],
|
||||
main = "bfloat16_test.py",
|
||||
python_version = "PY3",
|
||||
tags = ["no_oss"],
|
||||
deps = [
|
||||
":xla_client",
|
||||
":xla_extension",
|
||||
"@absl_py//absl/testing:absltest",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
] + xla_py_test_deps(),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "py_client",
|
||||
srcs = [
|
||||
@ -206,6 +171,7 @@ cc_library(
|
||||
"py_client.h",
|
||||
"py_executable.h",
|
||||
],
|
||||
compatible_with = [],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
@ -232,6 +198,7 @@ cc_library(
|
||||
name = "dlpack",
|
||||
srcs = ["dlpack.cc"],
|
||||
hdrs = ["dlpack.h"],
|
||||
compatible_with = [],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
@ -263,6 +230,7 @@ cc_library(
|
||||
name = "jax_jit",
|
||||
srcs = ["jax_jit.cc"],
|
||||
hdrs = ["jax_jit.h"],
|
||||
compatible_with = [],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
@ -292,6 +260,7 @@ cc_library(
|
||||
name = "ops",
|
||||
srcs = ["ops.cc"],
|
||||
hdrs = ["ops.h"],
|
||||
compatible_with = [],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
@ -356,6 +325,7 @@ cc_library(
|
||||
name = "outfeed_receiver_py",
|
||||
srcs = ["outfeed_receiver_py.cc"],
|
||||
hdrs = ["outfeed_receiver_py.h"],
|
||||
compatible_with = [],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
@ -379,12 +349,14 @@ cc_library(
|
||||
name = "pytree",
|
||||
srcs = ["pytree.cc"],
|
||||
hdrs = ["pytree.h"],
|
||||
compatible_with = [],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
deps = [
|
||||
":types",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/hash",
|
||||
@ -435,6 +407,7 @@ cc_library(
|
||||
name = "xla_compiler",
|
||||
srcs = ["xla_compiler.cc"],
|
||||
hdrs = ["xla_compiler.h"],
|
||||
compatible_with = [],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
@ -481,7 +454,6 @@ pybind_extension(
|
||||
features = ["-use_header_modules"],
|
||||
module_name = "xla_extension",
|
||||
deps = [
|
||||
":bfloat16",
|
||||
":dlpack",
|
||||
":jax_jit",
|
||||
":ops",
|
||||
@ -534,6 +506,7 @@ pybind_extension(
|
||||
# without any TF dependencies as "jaxlib" on Pypi, and "jaxlib" does
|
||||
# not require Tensorflow.
|
||||
"//tensorflow/core:lib_internal_impl", # buildcleaner: keep
|
||||
"//tensorflow/python:bfloat16_lib",
|
||||
"//tensorflow/stream_executor:device_memory_allocator",
|
||||
"//tensorflow/stream_executor:platform",
|
||||
] + select({
|
||||
|
@ -1,440 +0,0 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Test cases for the bfloat16 Python type."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import itertools
|
||||
import math
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.xla.python import xla_client
|
||||
|
||||
bfloat16 = xla_client.bfloat16
|
||||
|
||||
|
||||
def numpy_assert_allclose(a, b, **kwargs):
|
||||
a = a.astype(np.float32) if a.dtype == bfloat16 else a
|
||||
b = b.astype(np.float32) if b.dtype == bfloat16 else b
|
||||
return np.testing.assert_allclose(a, b, **kwargs)
|
||||
|
||||
|
||||
epsilon = float.fromhex("1.0p-7")
|
||||
|
||||
# Values that should round trip exactly to float and back.
|
||||
FLOAT_VALUES = [
|
||||
0.0, 1.0, -1, 0.5, -0.5, epsilon, 1.0 + epsilon, 1.0 - epsilon,
|
||||
-1.0 - epsilon, -1.0 + epsilon, 3.5, 42.0, 255.0, 256.0,
|
||||
float("inf"),
|
||||
float("-inf"),
|
||||
float("nan")
|
||||
]
|
||||
|
||||
|
||||
class Bfloat16Test(parameterized.TestCase):
|
||||
"""Tests the non-numpy Python methods of the bfloat16 type."""
|
||||
|
||||
def testRoundTripToFloat(self):
|
||||
for v in FLOAT_VALUES:
|
||||
np.testing.assert_equal(v, float(bfloat16(v)))
|
||||
|
||||
def testRoundTripNumpyTypes(self):
|
||||
for dtype in [np.float16, np.float32, np.float64]:
|
||||
np.testing.assert_equal(-3.75, dtype(bfloat16(dtype(-3.75))))
|
||||
np.testing.assert_equal(1.5, float(bfloat16(dtype(1.5))))
|
||||
np.testing.assert_equal(4.5, dtype(bfloat16(np.array(4.5, dtype))))
|
||||
np.testing.assert_equal(
|
||||
np.array([2, 5, -1], bfloat16), bfloat16(np.array([2, 5, -1], dtype)))
|
||||
|
||||
def testRoundTripToInt(self):
|
||||
for v in [-256, -255, -34, -2, -1, 0, 1, 2, 10, 47, 128, 255, 256, 512]:
|
||||
self.assertEqual(v, int(bfloat16(v)))
|
||||
|
||||
# pylint: disable=g-complex-comprehension
|
||||
@parameterized.named_parameters(({
|
||||
"testcase_name": "_" + dtype.__name__,
|
||||
"dtype": dtype
|
||||
} for dtype in [bfloat16, np.float16, np.float32, np.float64]))
|
||||
def testRoundTripToNumpy(self, dtype):
|
||||
for v in FLOAT_VALUES:
|
||||
np.testing.assert_equal(v, bfloat16(dtype(v)))
|
||||
np.testing.assert_equal(v, dtype(bfloat16(dtype(v))))
|
||||
np.testing.assert_equal(v, dtype(bfloat16(np.array(v, dtype))))
|
||||
if dtype != bfloat16:
|
||||
np.testing.assert_equal(
|
||||
np.array(FLOAT_VALUES, dtype),
|
||||
bfloat16(np.array(FLOAT_VALUES, dtype)).astype(dtype))
|
||||
|
||||
def testStr(self):
|
||||
self.assertEqual("0", str(bfloat16(0.0)))
|
||||
self.assertEqual("1", str(bfloat16(1.0)))
|
||||
self.assertEqual("-3.5", str(bfloat16(-3.5)))
|
||||
self.assertEqual("0.0078125", str(bfloat16(float.fromhex("1.0p-7"))))
|
||||
self.assertEqual("inf", str(bfloat16(float("inf"))))
|
||||
self.assertEqual("-inf", str(bfloat16(float("-inf"))))
|
||||
self.assertEqual("nan", str(bfloat16(float("nan"))))
|
||||
|
||||
def testRepr(self):
|
||||
self.assertEqual("0", repr(bfloat16(0)))
|
||||
self.assertEqual("1", repr(bfloat16(1)))
|
||||
self.assertEqual("-3.5", repr(bfloat16(-3.5)))
|
||||
self.assertEqual("0.0078125", repr(bfloat16(float.fromhex("1.0p-7"))))
|
||||
self.assertEqual("inf", repr(bfloat16(float("inf"))))
|
||||
self.assertEqual("-inf", repr(bfloat16(float("-inf"))))
|
||||
self.assertEqual("nan", repr(bfloat16(float("nan"))))
|
||||
|
||||
def testHash(self):
|
||||
self.assertEqual(0, hash(bfloat16(0.0)))
|
||||
self.assertEqual(0x3f80, hash(bfloat16(1.0)))
|
||||
self.assertEqual(0x7fc0, hash(bfloat16(float("nan"))))
|
||||
|
||||
# Tests for Python operations
|
||||
def testNegate(self):
|
||||
for v in FLOAT_VALUES:
|
||||
np.testing.assert_equal(-v, float(-bfloat16(v)))
|
||||
|
||||
def testAdd(self):
|
||||
np.testing.assert_equal(0, float(bfloat16(0) + bfloat16(0)))
|
||||
np.testing.assert_equal(1, float(bfloat16(1) + bfloat16(0)))
|
||||
np.testing.assert_equal(0, float(bfloat16(1) + bfloat16(-1)))
|
||||
np.testing.assert_equal(5.5, float(bfloat16(2) + bfloat16(3.5)))
|
||||
np.testing.assert_equal(1.25, float(bfloat16(3.5) + bfloat16(-2.25)))
|
||||
np.testing.assert_equal(
|
||||
float("inf"), float(bfloat16(float("inf")) + bfloat16(-2.25)))
|
||||
np.testing.assert_equal(
|
||||
float("-inf"), float(bfloat16(float("-inf")) + bfloat16(-2.25)))
|
||||
self.assertTrue(math.isnan(float(bfloat16(3.5) + bfloat16(float("nan")))))
|
||||
|
||||
# Test type promotion against Numpy scalar values.
|
||||
self.assertEqual(np.float32, type(bfloat16(3.5) + np.float16(2.25)))
|
||||
self.assertEqual(np.float32, type(np.float16(3.5) + bfloat16(2.25)))
|
||||
self.assertEqual(np.float32, type(bfloat16(3.5) + np.float32(2.25)))
|
||||
self.assertEqual(np.float32, type(np.float32(3.5) + bfloat16(2.25)))
|
||||
self.assertEqual(np.float64, type(bfloat16(3.5) + np.float64(2.25)))
|
||||
self.assertEqual(np.float64, type(np.float64(3.5) + bfloat16(2.25)))
|
||||
self.assertEqual(np.float64, type(bfloat16(3.5) + float(2.25)))
|
||||
self.assertEqual(np.float64, type(float(3.5) + bfloat16(2.25)))
|
||||
self.assertEqual(np.float32,
|
||||
type(bfloat16(3.5) + np.array(2.25, np.float32)))
|
||||
self.assertEqual(np.float32,
|
||||
type(np.array(3.5, np.float32) + bfloat16(2.25)))
|
||||
|
||||
def testSub(self):
|
||||
np.testing.assert_equal(0, float(bfloat16(0) - bfloat16(0)))
|
||||
np.testing.assert_equal(1, float(bfloat16(1) - bfloat16(0)))
|
||||
np.testing.assert_equal(2, float(bfloat16(1) - bfloat16(-1)))
|
||||
np.testing.assert_equal(-1.5, float(bfloat16(2) - bfloat16(3.5)))
|
||||
np.testing.assert_equal(5.75, float(bfloat16(3.5) - bfloat16(-2.25)))
|
||||
np.testing.assert_equal(
|
||||
float("-inf"), float(bfloat16(-2.25) - bfloat16(float("inf"))))
|
||||
np.testing.assert_equal(
|
||||
float("inf"), float(bfloat16(-2.25) - bfloat16(float("-inf"))))
|
||||
self.assertTrue(math.isnan(float(bfloat16(3.5) - bfloat16(float("nan")))))
|
||||
|
||||
def testMul(self):
|
||||
np.testing.assert_equal(0, float(bfloat16(0) * bfloat16(0)))
|
||||
np.testing.assert_equal(0, float(bfloat16(1) * bfloat16(0)))
|
||||
np.testing.assert_equal(-1, float(bfloat16(1) * bfloat16(-1)))
|
||||
np.testing.assert_equal(-7.875, float(bfloat16(3.5) * bfloat16(-2.25)))
|
||||
np.testing.assert_equal(
|
||||
float("-inf"), float(bfloat16(float("inf")) * bfloat16(-2.25)))
|
||||
np.testing.assert_equal(
|
||||
float("inf"), float(bfloat16(float("-inf")) * bfloat16(-2.25)))
|
||||
self.assertTrue(math.isnan(float(bfloat16(3.5) * bfloat16(float("nan")))))
|
||||
|
||||
def testDiv(self):
|
||||
self.assertTrue(math.isnan(float(bfloat16(0) / bfloat16(0))))
|
||||
np.testing.assert_equal(float("inf"), float(bfloat16(1) / bfloat16(0)))
|
||||
np.testing.assert_equal(-1, float(bfloat16(1) / bfloat16(-1)))
|
||||
np.testing.assert_equal(-1.75, float(bfloat16(3.5) / bfloat16(-2)))
|
||||
np.testing.assert_equal(
|
||||
float("-inf"), float(bfloat16(float("inf")) / bfloat16(-2.25)))
|
||||
np.testing.assert_equal(
|
||||
float("inf"), float(bfloat16(float("-inf")) / bfloat16(-2.25)))
|
||||
self.assertTrue(math.isnan(float(bfloat16(3.5) / bfloat16(float("nan")))))
|
||||
|
||||
def testLess(self):
|
||||
for v in FLOAT_VALUES:
|
||||
for w in FLOAT_VALUES:
|
||||
self.assertEqual(v < w, bfloat16(v) < bfloat16(w))
|
||||
|
||||
def testLessEqual(self):
|
||||
for v in FLOAT_VALUES:
|
||||
for w in FLOAT_VALUES:
|
||||
self.assertEqual(v <= w, bfloat16(v) <= bfloat16(w))
|
||||
|
||||
def testGreater(self):
|
||||
for v in FLOAT_VALUES:
|
||||
for w in FLOAT_VALUES:
|
||||
self.assertEqual(v > w, bfloat16(v) > bfloat16(w))
|
||||
|
||||
def testGreaterEqual(self):
|
||||
for v in FLOAT_VALUES:
|
||||
for w in FLOAT_VALUES:
|
||||
self.assertEqual(v >= w, bfloat16(v) >= bfloat16(w))
|
||||
|
||||
def testEqual(self):
|
||||
for v in FLOAT_VALUES:
|
||||
for w in FLOAT_VALUES:
|
||||
self.assertEqual(v == w, bfloat16(v) == bfloat16(w))
|
||||
|
||||
def testNotEqual(self):
|
||||
for v in FLOAT_VALUES:
|
||||
for w in FLOAT_VALUES:
|
||||
self.assertEqual(v != w, bfloat16(v) != bfloat16(w))
|
||||
|
||||
def testNan(self):
|
||||
a = np.isnan(bfloat16(float("nan")))
|
||||
self.assertTrue(a)
|
||||
numpy_assert_allclose(np.array([1.0, a]), np.array([1.0, a]))
|
||||
|
||||
a = np.array([bfloat16(1.34375),
|
||||
bfloat16(1.4375),
|
||||
bfloat16(float("nan"))],
|
||||
dtype=bfloat16)
|
||||
b = np.array(
|
||||
[bfloat16(1.3359375),
|
||||
bfloat16(1.4375),
|
||||
bfloat16(float("nan"))],
|
||||
dtype=bfloat16)
|
||||
numpy_assert_allclose(
|
||||
a, b, rtol=0.1, atol=0.1, equal_nan=True, err_msg="", verbose=True)
|
||||
|
||||
def testSort(self):
|
||||
values_to_sort = np.float32(FLOAT_VALUES)
|
||||
sorted_f32 = np.sort(values_to_sort)
|
||||
sorted_bf16 = np.sort(values_to_sort.astype(bfloat16))
|
||||
np.testing.assert_equal(sorted_f32, np.float32(sorted_bf16))
|
||||
|
||||
|
||||
BinaryOp = collections.namedtuple("BinaryOp", ["op"])
|
||||
|
||||
UNARY_UFUNCS = [
|
||||
np.negative, np.positive, np.absolute, np.fabs, np.rint, np.sign,
|
||||
np.conjugate, np.exp, np.exp2, np.expm1, np.log, np.log10, np.log1p,
|
||||
np.log2, np.sqrt, np.square, np.cbrt, np.reciprocal, np.sin, np.cos, np.tan,
|
||||
np.arcsin, np.arccos, np.arctan, np.sinh, np.cosh, np.tanh, np.arcsinh,
|
||||
np.arccosh, np.arctanh, np.deg2rad, np.rad2deg, np.floor, np.ceil, np.trunc
|
||||
]
|
||||
|
||||
BINARY_UFUNCS = [
|
||||
np.add, np.subtract, np.multiply, np.divide, np.logaddexp, np.logaddexp2,
|
||||
np.floor_divide, np.power, np.remainder, np.fmod, np.heaviside, np.arctan2,
|
||||
np.hypot, np.maximum, np.minimum, np.fmax, np.fmin, np.copysign
|
||||
]
|
||||
|
||||
BINARY_PREDICATE_UFUNCS = [
|
||||
np.equal, np.not_equal, np.less, np.greater, np.less_equal,
|
||||
np.greater_equal, np.logical_and, np.logical_or, np.logical_xor
|
||||
]
|
||||
|
||||
|
||||
class Bfloat16NumPyTest(parameterized.TestCase):
|
||||
"""Tests the NumPy integration of the bfloat16 type."""
|
||||
|
||||
def testDtype(self):
|
||||
self.assertEqual(bfloat16, np.dtype(bfloat16))
|
||||
|
||||
def testDeepCopyDoesNotAlterHash(self):
|
||||
# For context, see https://github.com/google/jax/issues/4651. If the hash
|
||||
# value of the type descriptor is not initialized correctly, a deep copy
|
||||
# can change the type hash.
|
||||
dtype = np.dtype(bfloat16)
|
||||
h = hash(dtype)
|
||||
_ = copy.deepcopy(dtype)
|
||||
self.assertEqual(h, hash(dtype))
|
||||
|
||||
def testArray(self):
|
||||
x = np.array([[1, 2, 3]], dtype=bfloat16)
|
||||
self.assertEqual(bfloat16, x.dtype)
|
||||
self.assertEqual("[[1 2 3]]", str(x))
|
||||
np.testing.assert_equal(x, x)
|
||||
numpy_assert_allclose(x, x)
|
||||
self.assertTrue((x == x).all())
|
||||
|
||||
def testComparisons(self):
|
||||
x = np.array([401408, 7, -32], dtype=np.float32)
|
||||
bx = x.astype(bfloat16)
|
||||
y = np.array([82432, 7, 0], dtype=np.float32)
|
||||
by = y.astype(bfloat16)
|
||||
np.testing.assert_equal(x == y, bx == by)
|
||||
np.testing.assert_equal(x != y, bx != by)
|
||||
np.testing.assert_equal(x < y, bx < by)
|
||||
np.testing.assert_equal(x > y, bx > by)
|
||||
np.testing.assert_equal(x <= y, bx <= by)
|
||||
np.testing.assert_equal(x >= y, bx >= by)
|
||||
|
||||
def testEqual2(self):
|
||||
a = np.array([401408], bfloat16)
|
||||
b = np.array([82432], bfloat16)
|
||||
self.assertFalse(a.__eq__(b))
|
||||
|
||||
def testCasts(self):
|
||||
for dtype in [
|
||||
np.float16, np.float32, np.float64, np.int8, np.int16, np.int32,
|
||||
np.int64, np.complex64, np.complex128, np.uint8, np.uint16, np.uint32,
|
||||
np.uint64, np.intc, np.int_, np.longlong, np.uintc, np.ulonglong
|
||||
]:
|
||||
x = np.array([[1, 2, 3]], dtype=dtype)
|
||||
y = x.astype(bfloat16)
|
||||
z = y.astype(dtype)
|
||||
self.assertTrue(np.all(x == y))
|
||||
self.assertEqual(bfloat16, y.dtype)
|
||||
self.assertTrue(np.all(x == z))
|
||||
self.assertEqual(dtype, z.dtype)
|
||||
|
||||
def testConformNumpyComplex(self):
|
||||
for dtype in [np.complex64, np.complex128]:
|
||||
x = np.array([1.1, 2.2 + 2.2j, 3.3], dtype=dtype)
|
||||
y_np = x.astype(np.float32)
|
||||
y_tf = x.astype(bfloat16)
|
||||
numpy_assert_allclose(y_np, y_tf, atol=2e-2)
|
||||
|
||||
z_np = y_np.astype(dtype)
|
||||
z_tf = y_tf.astype(dtype)
|
||||
numpy_assert_allclose(z_np, z_tf, atol=2e-2)
|
||||
|
||||
def testArange(self):
|
||||
np.testing.assert_equal(
|
||||
np.arange(100, dtype=np.float32).astype(bfloat16),
|
||||
np.arange(100, dtype=bfloat16))
|
||||
np.testing.assert_equal(
|
||||
np.arange(-10.5, 7.8, 0.5, dtype=np.float32).astype(bfloat16),
|
||||
np.arange(-10.5, 7.8, 0.5, dtype=bfloat16))
|
||||
np.testing.assert_equal(
|
||||
np.arange(-0., -7., -0.25, dtype=np.float32).astype(bfloat16),
|
||||
np.arange(-0., -7., -0.25, dtype=bfloat16))
|
||||
np.testing.assert_equal(
|
||||
np.arange(-16384., 16384., 64., dtype=np.float32).astype(bfloat16),
|
||||
np.arange(-16384., 16384., 64., dtype=bfloat16))
|
||||
|
||||
# pylint: disable=g-complex-comprehension
|
||||
@parameterized.named_parameters(({
|
||||
"testcase_name": "_" + op.__name__,
|
||||
"op": op
|
||||
} for op in UNARY_UFUNCS))
|
||||
def testUnaryUfunc(self, op):
|
||||
rng = np.random.RandomState(seed=42)
|
||||
x = rng.randn(3, 7, 10).astype(bfloat16)
|
||||
numpy_assert_allclose(
|
||||
op(x).astype(np.float32), op(x.astype(np.float32)), rtol=1e-2)
|
||||
|
||||
@parameterized.named_parameters(({
|
||||
"testcase_name": "_" + op.__name__,
|
||||
"op": op
|
||||
} for op in BINARY_UFUNCS))
|
||||
def testBinaryUfunc(self, op):
|
||||
rng = np.random.RandomState(seed=42)
|
||||
x = rng.randn(3, 7, 10).astype(bfloat16)
|
||||
y = rng.randn(4, 1, 7, 10).astype(bfloat16)
|
||||
numpy_assert_allclose(
|
||||
op(x, y).astype(np.float32),
|
||||
op(x.astype(np.float32), y.astype(np.float32)),
|
||||
rtol=1e-2)
|
||||
|
||||
@parameterized.named_parameters(({
|
||||
"testcase_name": "_" + op.__name__,
|
||||
"op": op
|
||||
} for op in BINARY_PREDICATE_UFUNCS))
|
||||
def testBinaryPredicateUfunc(self, op):
|
||||
rng = np.random.RandomState(seed=42)
|
||||
x = rng.randn(3, 7).astype(bfloat16)
|
||||
y = rng.randn(4, 1, 7).astype(bfloat16)
|
||||
np.testing.assert_equal(
|
||||
op(x, y), op(x.astype(np.float32), y.astype(np.float32)))
|
||||
|
||||
@parameterized.named_parameters(({
|
||||
"testcase_name": "_" + op.__name__,
|
||||
"op": op
|
||||
} for op in [np.isfinite, np.isinf, np.isnan, np.signbit, np.logical_not]))
|
||||
def testPredicateUfunc(self, op):
|
||||
rng = np.random.RandomState(seed=42)
|
||||
shape = (3, 7, 10)
|
||||
posinf_flips = rng.rand(*shape) < 0.1
|
||||
neginf_flips = rng.rand(*shape) < 0.1
|
||||
nan_flips = rng.rand(*shape) < 0.1
|
||||
vals = rng.randn(*shape)
|
||||
vals = np.where(posinf_flips, np.inf, vals)
|
||||
vals = np.where(neginf_flips, -np.inf, vals)
|
||||
vals = np.where(nan_flips, np.nan, vals)
|
||||
vals = vals.astype(bfloat16)
|
||||
np.testing.assert_equal(op(vals), op(vals.astype(np.float32)))
|
||||
|
||||
def testDivmod(self):
|
||||
rng = np.random.RandomState(seed=42)
|
||||
x = rng.randn(3, 7).astype(bfloat16)
|
||||
y = rng.randn(4, 1, 7).astype(bfloat16)
|
||||
o1, o2 = np.divmod(x, y)
|
||||
e1, e2 = np.divmod(x.astype(np.float32), y.astype(np.float32))
|
||||
numpy_assert_allclose(o1, e1, rtol=1e-2)
|
||||
numpy_assert_allclose(o2, e2, rtol=1e-2)
|
||||
|
||||
def testModf(self):
|
||||
rng = np.random.RandomState(seed=42)
|
||||
x = rng.randn(3, 7).astype(bfloat16)
|
||||
o1, o2 = np.modf(x)
|
||||
e1, e2 = np.modf(x.astype(np.float32))
|
||||
numpy_assert_allclose(o1.astype(np.float32), e1, rtol=1e-2)
|
||||
numpy_assert_allclose(o2.astype(np.float32), e2, rtol=1e-2)
|
||||
|
||||
def testLdexp(self):
|
||||
rng = np.random.RandomState(seed=42)
|
||||
x = rng.randn(3, 7).astype(bfloat16)
|
||||
y = rng.randint(-50, 50, (1, 7))
|
||||
numpy_assert_allclose(
|
||||
np.ldexp(x, y).astype(np.float32),
|
||||
np.ldexp(x.astype(np.float32), y),
|
||||
rtol=1e-2,
|
||||
atol=1e-6)
|
||||
|
||||
def testFrexp(self):
|
||||
rng = np.random.RandomState(seed=42)
|
||||
x = rng.randn(3, 7).astype(bfloat16)
|
||||
mant1, exp1 = np.frexp(x)
|
||||
mant2, exp2 = np.frexp(x.astype(np.float32))
|
||||
np.testing.assert_equal(exp1, exp2)
|
||||
numpy_assert_allclose(mant1, mant2, rtol=1e-2)
|
||||
|
||||
def testNextAfter(self):
|
||||
one = np.array(1., dtype=bfloat16)
|
||||
two = np.array(2., dtype=bfloat16)
|
||||
zero = np.array(0., dtype=bfloat16)
|
||||
nan = np.array(np.nan, dtype=bfloat16)
|
||||
np.testing.assert_equal(np.nextafter(one, two) - one, epsilon)
|
||||
np.testing.assert_equal(np.nextafter(one, zero) - one, -epsilon / 2)
|
||||
np.testing.assert_equal(np.isnan(np.nextafter(nan, one)), True)
|
||||
np.testing.assert_equal(np.isnan(np.nextafter(one, nan)), True)
|
||||
np.testing.assert_equal(np.nextafter(one, one), one)
|
||||
smallest_denormal = float.fromhex("1.0p-133")
|
||||
np.testing.assert_equal(np.nextafter(zero, one), smallest_denormal)
|
||||
np.testing.assert_equal(np.nextafter(zero, -one), -smallest_denormal)
|
||||
for a, b in itertools.permutations([0., -0., nan], 2):
|
||||
np.testing.assert_equal(
|
||||
np.nextafter(
|
||||
np.array(a, dtype=np.float32), np.array(b, dtype=np.float32)),
|
||||
np.nextafter(
|
||||
np.array(a, dtype=bfloat16), np.array(b, dtype=bfloat16)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
@ -214,11 +214,9 @@ StatusOr<std::vector<int64>> StridesToLayout(absl::Span<int64 const> dims,
|
||||
}
|
||||
|
||||
StatusOr<DLDeviceType> DLDeviceTypeForDevice(const PjRtDevice& device) {
|
||||
const se::Platform* platform =
|
||||
device.local_device_state()->executor()->platform();
|
||||
if (platform->id() == se::host::kHostPlatformId) {
|
||||
if (device.client()->platform_id() == kCpuId) {
|
||||
return kDLCPU;
|
||||
} else if (platform->id() == se::cuda::kCudaPlatformId) {
|
||||
} else if (device.client()->platform_id() == kGpuId) {
|
||||
return kDLGPU;
|
||||
}
|
||||
return InvalidArgument("Device %s cannot be used as a DLPack device.",
|
||||
@ -228,7 +226,7 @@ StatusOr<DLDeviceType> DLDeviceTypeForDevice(const PjRtDevice& device) {
|
||||
StatusOr<DLContext> DLContextForDevice(const PjRtDevice& device) {
|
||||
DLContext context;
|
||||
TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device));
|
||||
context.device_id = device.local_device_id();
|
||||
context.device_id = device.local_hardware_id();
|
||||
return context;
|
||||
}
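After this change the DLPack context is derived purely from the client's platform id and the device's local_hardware_id(); a condensed sketch of the GPU case, mirroring the helpers above with error handling elided:

// Sketch only: what DLContextForDevice produces for an addressable GPU device.
DLContext GpuContextFor(const xla::PjRtDevice& device) {
  DLContext context;
  context.device_type = kDLGPU;  // kDLCPU when platform_id() == kCpuId.
  context.device_id = device.local_hardware_id();
  return context;
}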
|
||||
|
||||
@ -241,14 +239,14 @@ StatusOr<PjRtDevice*> DeviceForDLContext(const PjRtClient& client,
|
||||
"DLPack CPU device type mismatch with PjRtClient platform %s",
|
||||
client.platform_name());
|
||||
}
|
||||
return client.LookupLocalDevice(context.device_id);
|
||||
return client.LookupAddressableDevice(context.device_id);
|
||||
case kDLGPU:
|
||||
if (client.platform_id() != kGpuId) {
|
||||
return InvalidArgument(
|
||||
"DLPack GPU device type mismatch with PjRtClient platform %s",
|
||||
client.platform_name());
|
||||
}
|
||||
return client.LookupLocalDevice(context.device_id);
|
||||
return client.LookupAddressableDevice(context.device_id);
|
||||
default:
|
||||
return InvalidArgument("Unknown/unsupported DLPack device type %d",
|
||||
context.device_type);
|
||||
@ -297,7 +295,7 @@ StatusOr<py::capsule> BufferToDLPackManagedTensor(py::handle py_buffer,
|
||||
pack->tensor.manager_ctx = pack.get();
|
||||
pack->tensor.deleter = DLPackTensorDeleter;
|
||||
TF_ASSIGN_OR_RETURN(dt.ctx, DLContextForDevice(*buffer->buffer()->device()));
|
||||
dt.ctx.device_id = buffer->buffer()->device()->local_device_id();
|
||||
dt.ctx.device_id = buffer->buffer()->device()->local_hardware_id();
|
||||
dt.ndim = buffer->buffer()->on_host_shape().dimensions_size();
|
||||
TF_ASSIGN_OR_RETURN(dt.dtype,
|
||||
PrimitiveTypeToDLDataType(
|
||||
|
@ -230,8 +230,8 @@ OutfeedReceiverImpl::OutfeedReceiverImpl(
|
||||
callback_ = callback;
|
||||
max_callback_queue_size_bytes_ = max_callback_queue_size_bytes;
|
||||
for (const auto& client : clients) {
|
||||
for (const auto& device : client->devices()) {
|
||||
devices_.push_back(device.get());
|
||||
for (auto device : client->devices()) {
|
||||
devices_.push_back(device);
|
||||
}
|
||||
}
|
||||
CHECK_GT(devices_.size(), 0);
|
||||
@ -342,11 +342,7 @@ StatusOr<std::unique_ptr<Literal>> OutfeedReceiverImpl::ReceiveRawFromOutfeed(
|
||||
const PjRtDevice* device, const Shape& shape) {
|
||||
std::shared_ptr<Literal> literal_shared;
|
||||
|
||||
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
|
||||
device->GetLocalDeviceState());
|
||||
TF_ASSIGN_OR_RETURN(Literal literal,
|
||||
local_device->client()->TransferFromOutfeedLocal(
|
||||
shape, local_device->device_ordinal()));
|
||||
TF_ASSIGN_OR_RETURN(Literal literal, device->TransferFromOutfeed(shape));
|
||||
|
||||
return absl::make_unique<Literal>(std::move(literal));
|
||||
}
|
||||
|
@ -86,8 +86,8 @@ StatusOr<std::uintptr_t> PyBuffer::UnsafeBufferPointer() const {
|
||||
}
|
||||
|
||||
StatusOr<py::dict> PyBuffer::CudaArrayInterface() const {
|
||||
if (buffer_->device()->local_device_state()->executor()->platform_kind() !=
|
||||
se::PlatformKind::kCuda) {
|
||||
// TODO(zhangqiaorjc): Differentiate between NVidia and other GPUs.
|
||||
if (buffer_->client()->platform_id() != kGpuId) {
|
||||
return InvalidArgument(
|
||||
"__cuda_array_interface__ is only defined for NVidia GPU buffers.");
|
||||
}
|
||||
|
@ -37,9 +37,10 @@ PyClient::PyClient(std::shared_ptr<PjRtClient> pjrt_client)
|
||||
|
||||
std::vector<ClientAndPtr<PjRtDevice>> PyClient::Devices() {
|
||||
std::vector<ClientAndPtr<PjRtDevice>> devices;
|
||||
devices.reserve(pjrt_client_->devices().size());
|
||||
for (const auto& device : pjrt_client_->devices()) {
|
||||
devices.push_back(WrapWithClient(shared_from_this(), device.get()));
|
||||
auto span = pjrt_client_->devices();
|
||||
devices.reserve(span.size());
|
||||
for (PjRtDevice* device : span) {
|
||||
devices.push_back(WrapWithClient(shared_from_this(), device));
|
||||
}
|
||||
return devices;
|
||||
}
|
||||
@ -64,9 +65,9 @@ PyClient::GetDefaultDeviceAssignment(int num_replicas, int num_partitions) {
|
||||
result[r].resize(num_partitions);
|
||||
for (int p = 0; p < num_partitions; ++p) {
|
||||
int device_id = device_assignment(r, p);
|
||||
auto iter = pjrt_client_->id_to_device().find(device_id);
|
||||
CHECK(iter != pjrt_client_->id_to_device().end()) << device_id;
|
||||
result[r][p] = WrapWithClient(shared_from_this(), iter->second);
|
||||
TF_ASSIGN_OR_RETURN(PjRtDevice * device,
|
||||
pjrt_client_->LookupDevice(device_id));
|
||||
result[r][p] = WrapWithClient(shared_from_this(), device);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
@ -80,9 +81,9 @@ PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
|
||||
std::vector<ClientAndPtr<PjRtDevice>> result;
|
||||
for (int i = 0; i < num_replicas; ++i) {
|
||||
int device_id = device_assignment(i, 0);
|
||||
auto iter = pjrt_client_->id_to_device().find(device_id);
|
||||
CHECK(iter != pjrt_client_->id_to_device().end()) << device_id;
|
||||
result.push_back(WrapWithClient(shared_from_this(), iter->second));
|
||||
TF_ASSIGN_OR_RETURN(PjRtDevice * device,
|
||||
pjrt_client_->LookupDevice(device_id));
|
||||
result.push_back(WrapWithClient(shared_from_this(), device));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@ -95,8 +96,9 @@ StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
|
||||
device = pjrt_client_->local_devices().front();
|
||||
}
|
||||
CHECK(device != nullptr);
|
||||
auto iter = pjrt_client_->id_to_device().find(device->id());
|
||||
if (iter->second != device) {
|
||||
TF_ASSIGN_OR_RETURN(PjRtDevice * found_device,
|
||||
pjrt_client_->LookupDevice(device->id()));
|
||||
if (found_device != device) {
|
||||
return InvalidArgument("Cannot copy value to device '%s' with '%s' backend",
|
||||
device->DebugString(),
|
||||
pjrt_client_->platform_name());
|
||||
|
@ -97,7 +97,9 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
|
||||
const std::string& platform_name() const {
|
||||
return pjrt_client_->platform_name();
|
||||
}
|
||||
int local_device_count() const { return pjrt_client_->local_device_count(); }
|
||||
int addressable_device_count() const {
|
||||
return pjrt_client_->addressable_device_count();
|
||||
}
|
||||
int device_count() const { return pjrt_client_->device_count(); }
|
||||
int host_id() const { return pjrt_client_->host_id(); }
|
||||
|
||||
|
@ -32,6 +32,7 @@ limitations under the License.
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/pytypes.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "tensorflow/compiler/xla/python/types.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
@ -106,59 +107,66 @@ bool PyTreeDef::operator==(const PyTreeDef& other) const {
|
||||
}
|
||||
}
|
||||
|
||||
void PyTreeDef::FlattenInto(py::handle handle,
|
||||
std::vector<py::object>& leaves) {
|
||||
void PyTreeDef::FlattenInto(py::handle handle, std::vector<py::object>& leaves,
|
||||
absl::optional<py::function> leaf_predicate) {
|
||||
Node node;
|
int start_num_nodes = traversal_.size();
int start_num_leaves = leaves.size();
node.kind = GetKind(handle, &node.custom);
if (node.kind == Kind::kNone) {
// Nothing to do.
} else if (node.kind == Kind::kTuple) {
py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
node.arity = tuple.size();
for (py::handle entry : tuple) {
FlattenInto(entry, leaves);
}
} else if (node.kind == Kind::kList) {
py::list list = py::reinterpret_borrow<py::list>(handle);
node.arity = list.size();
for (py::handle entry : list) {
FlattenInto(entry, leaves);
}
} else if (node.kind == Kind::kDict) {
py::dict dict = py::reinterpret_borrow<py::dict>(handle);
py::list keys = py::reinterpret_steal<py::list>(PyDict_Keys(dict.ptr()));
if (PyList_Sort(keys.ptr())) {
throw std::runtime_error("Dictionary key sort failed.");
}
for (py::handle key : keys) {
FlattenInto(dict[key], leaves);
}
node.arity = dict.size();
node.node_data = std::move(keys);
} else if (node.kind == Kind::kCustom) {
py::tuple out = py::cast<py::tuple>(node.custom->to_iterable(handle));
if (out.size() != 2) {
throw std::runtime_error(
"PyTree custom to_iterable function should return a pair");
}
node.node_data = out[1];
node.arity = 0;
for (py::handle entry : py::cast<py::iterable>(out[0])) {
++node.arity;
FlattenInto(entry, leaves);
}
} else if (node.kind == Kind::kNamedTuple) {
py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
node.arity = tuple.size();
node.node_data = py::reinterpret_borrow<py::object>(tuple.get_type());
for (py::handle entry : tuple) {
FlattenInto(entry, leaves);
}
if (leaf_predicate && (*leaf_predicate)(handle).cast<bool>()) {
leaves.push_back(py::reinterpret_borrow<py::object>(handle));
} else {
assert(node.kind == Kind::kLeaf);
leaves.push_back(pybind11::reinterpret_borrow<py::object>(handle));
node.kind = GetKind(handle, &node.custom);
auto recurse = [this, &leaf_predicate, &leaves](py::handle child) {
FlattenInto(child, leaves, leaf_predicate);
};
if (node.kind == Kind::kNone) {
// Nothing to do.
} else if (node.kind == Kind::kTuple) {
py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
node.arity = tuple.size();
for (py::handle entry : tuple) {
recurse(entry);
}
} else if (node.kind == Kind::kList) {
py::list list = py::reinterpret_borrow<py::list>(handle);
node.arity = list.size();
for (py::handle entry : list) {
recurse(entry);
}
} else if (node.kind == Kind::kDict) {
py::dict dict = py::reinterpret_borrow<py::dict>(handle);
py::list keys = py::reinterpret_steal<py::list>(PyDict_Keys(dict.ptr()));
if (PyList_Sort(keys.ptr())) {
throw std::runtime_error("Dictionary key sort failed.");
}
for (py::handle key : keys) {
recurse(dict[key]);
}
node.arity = dict.size();
node.node_data = std::move(keys);
} else if (node.kind == Kind::kCustom) {
py::tuple out = py::cast<py::tuple>(node.custom->to_iterable(handle));
if (out.size() != 2) {
throw std::runtime_error(
"PyTree custom to_iterable function should return a pair");
}
node.node_data = out[1];
node.arity = 0;
for (py::handle entry : py::cast<py::iterable>(out[0])) {
++node.arity;
recurse(entry);
}
} else if (node.kind == Kind::kNamedTuple) {
py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
node.arity = tuple.size();
node.node_data = py::reinterpret_borrow<py::object>(tuple.get_type());
for (py::handle entry : tuple) {
recurse(entry);
}
} else {
assert(node.kind == Kind::kLeaf);
leaves.push_back(py::reinterpret_borrow<py::object>(handle));
}
}
node.num_nodes = traversal_.size() - start_num_nodes + 1;
node.num_leaves = leaves.size() - start_num_leaves;
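
For illustration only (not part of this commit): the kCustom branch above requires a registered type's to_iterable callback to return exactly two elements, an iterable of children (which FlattenInto recurses into and counts toward node.arity) and opaque aux data stored in node.node_data. The Point type and its x/y attributes below are hypothetical; only the pair contract comes from the code above.

#include <pybind11/pybind11.h>
namespace py = pybind11;

// Hypothetical to_iterable callback compatible with the kCustom branch.
py::tuple PointToIterable(py::handle point) {
  // Children: the values the flattener will visit.
  py::tuple children = py::make_tuple(point.attr("x"), point.attr("y"));
  // Aux data: anything needed later to rebuild the node (here just a tag).
  py::object aux_data = py::str("Point");
  return py::make_tuple(children, aux_data);  // exactly two elements
}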
@ -166,10 +174,10 @@ void PyTreeDef::FlattenInto(py::handle handle,
}

/*static*/ std::pair<std::vector<py::object>, std::unique_ptr<PyTreeDef>>
PyTreeDef::Flatten(py::handle x) {
PyTreeDef::Flatten(py::handle x, absl::optional<py::function> leaf_predicate) {
std::vector<py::object> leaves;
auto tree = absl::make_unique<PyTreeDef>();
tree->FlattenInto(x, leaves);
tree->FlattenInto(x, leaves, leaf_predicate);
return std::make_pair(std::move(leaves), std::move(tree));
}

@ -618,7 +626,8 @@ std::string PyTreeDef::ToString() const {

void BuildPytreeSubmodule(py::module& m) {
py::module pytree = m.def_submodule("pytree", "Python tree library");
pytree.def("flatten", &PyTreeDef::Flatten);
pytree.def("flatten", &PyTreeDef::Flatten, py::arg("tree"),
py::arg("leaf_predicate") = absl::nullopt);
pytree.def("tuple", &PyTreeDef::Tuple);
pytree.def("all_leaves", &PyTreeDef::AllLeaves);

@ -85,11 +85,13 @@ class PyTreeDef {

// Flattens a Pytree into a list of leaves and a PyTreeDef.
static std::pair<std::vector<pybind11::object>, std::unique_ptr<PyTreeDef>>
Flatten(pybind11::handle x);
Flatten(pybind11::handle x,
absl::optional<pybind11::function> leaf_predicate = absl::nullopt);

// Recursive helper used to implement Flatten().
void FlattenInto(pybind11::handle handle,
std::vector<pybind11::object>& leaves);
void FlattenInto(
pybind11::handle handle, std::vector<pybind11::object>& leaves,
absl::optional<pybind11::function> leaf_predicate = absl::nullopt);

// Tests whether the given list is a flat list of leaves.
static bool AllLeaves(const pybind11::iterable& x);
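
For illustration only (not part of this commit): a hedged sketch of how the new leaf_predicate parameter could be called from C++. The obj variable and the "treat dicts as leaves" predicate are made up for illustration; only the Flatten signatures come from the hunk above.

// Flatten obj, treating any dict as an opaque leaf instead of recursing.
py::function leaf_predicate = py::cpp_function(
    [](py::handle h) -> bool { return py::isinstance<py::dict>(h); });
auto [leaves, treedef] = PyTreeDef::Flatten(obj, leaf_predicate);

// Omitting the predicate keeps the previous behavior unchanged.
auto [all_leaves, full_treedef] = PyTreeDef::Flatten(obj);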
@ -37,6 +37,7 @@ cc_library(
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/core/framework:allocator",
"//tensorflow/core/platform:casts",
"//tensorflow/core/platform:env",
"//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/memory",

@ -37,8 +37,8 @@ namespace xla {

TpuDevice::TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
int core_on_chip)
: xla::PjRtDevice(id, /*local_device_state=*/nullptr,
/*device_kind=*/"Cloud TPU", host_id),
: xla::PjRtStreamExecutorDevice(id, /*local_device_state=*/nullptr,
/*device_kind=*/"Cloud TPU", host_id),
coords_(coords),
core_on_chip_(core_on_chip) {}

@ -531,7 +531,7 @@ PyTpuExecutable::PyTpuExecutable(
<< "Inserting duplicate replica:" << replica;
executables_[replica] =
client_->driver()->LoadProgram(device_id, compiled_program.get(), {});
addressable_device_logical_ids_.emplace_back(replica, partition);
local_logical_device_ids_.emplace_back(replica, partition);
local_devices_.push_back(device);
}
}

@ -711,8 +711,8 @@ PyTpuExecutable::ExecuteOnLocalDevices(
// long time and we want all cores to be scheduled in parallel.
thread_pool->Schedule([this, i, argument_handles, &results, &results_lock,
&execute_semaphore]() {
const int replica = addressable_device_logical_ids_[i].first;
const int partition = addressable_device_logical_ids_[i].second;
const int replica = local_logical_device_ids_[i].first;
const int partition = local_logical_device_ids_[i].second;
RunId run_id;
auto result = ExecuteHelper(argument_handles, argument_handles[i],
replica, partition, run_id);

@ -32,13 +32,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/threadpool.h"

namespace xla {

constexpr char kTpuPlatform[] = "tpu";

class TpuDevice : public PjRtDevice {
class TpuDevice : public PjRtStreamExecutorDevice {
public:
TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
int core_on_chip);

@ -298,9 +299,8 @@ class PyTpuExecutable {
return device_assignment_;
}

const std::vector<std::pair<int, int>>& addressable_device_logical_ids()
const {
return addressable_device_logical_ids_;
const std::vector<std::pair<int, int>>& local_logical_device_ids() const {
return local_logical_device_ids_;
}

const std::vector<std::shared_ptr<PjRtDevice>>& local_devices() const {

@ -341,14 +341,16 @@ class PyTpuExecutable {

// The replica and partition indices of device_assignment_ to be run by this
// client. On single-host platforms without partitioning, this is all replicas
// (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may not be the
// case on multi-host platforms. If there are 4 replicas and 2 partitions on a
// single host platform, size of addressable_device_logical_ids_ is 4*2 = 8.
std::vector<std::pair<int, int>> addressable_device_logical_ids_;
// (i.e. local_logical_device_ids_[i] = (i, 0)), but this may not be the case
// on multi-host platforms.
// If there are 4 replicas and 2 partitions on a single host platform, size of
// local_logical_device_ids_ is 4*2 = 8.
std::vector<std::pair<int, int>> local_logical_device_ids_;

// local_devices_[i] is the Device to which addressable_device_logical_ids_[i]
// is assigned. shared_ptrs instead of unique_ptrs to play well with the
// Python bindings (see xla.cc).
// local_devices_[i] is the Device to which local_logical_device_ids_[i] is
// assigned.
// shared_ptrs instead of unique_ptrs to play well with the Python bindings
// (see xla.cc).
std::vector<std::shared_ptr<PjRtDevice>> local_devices_;

xla::Shape result_shape_;
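
For illustration only (not part of this commit): the comment above says local_logical_device_ids_ and local_devices_ are parallel vectors, so entry i of each describes the same execution slot; with 4 replicas and 2 partitions on one host there are 4 * 2 = 8 entries. A minimal sketch of walking that pairing (the logging is hypothetical):

for (size_t i = 0; i < local_logical_device_ids_.size(); ++i) {
  const int replica = local_logical_device_ids_[i].first;
  const int partition = local_logical_device_ids_[i].second;
  VLOG(1) << "replica " << replica << ", partition " << partition
          << " -> device id " << local_devices_[i]->id();
}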
@ -186,7 +186,7 @@ PYBIND11_MODULE(tpu_client_extension, m) {

py::class_<PyTpuExecutable>(m, "TpuExecutable")
.def("local_logical_device_ids",
&PyTpuExecutable::addressable_device_logical_ids)
&PyTpuExecutable::local_logical_device_ids)
.def("local_devices", &PyTpuExecutable::local_devices)
.def_property_readonly("client", &PyTpuExecutable::client)
.def("size_of_generated_code_in_bytes",

@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/python/types.h"

#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/python/bfloat16.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/python/lib/core/bfloat16.h"

namespace xla {

@ -81,8 +81,8 @@ xla::StatusOr<py::dtype> PrimitiveTypeToDtype(PrimitiveType type) {
case U64:
return py::dtype::of<uint64>();
case BF16: {
TF_ASSIGN_OR_RETURN(py::object bfloat16, Bfloat16Dtype());
return py::dtype::from_args(bfloat16);
py::handle bfloat16(tensorflow::Bfloat16Dtype());
return py::dtype::from_args(py::reinterpret_borrow<py::object>(bfloat16));
}
case F16:
return py::dtype("e"); // PEP 3118 code for "float16

@ -237,10 +237,11 @@ StatusOr<py::object> LiteralToPython(std::shared_ptr<xla::Literal> literal) {
// We requested an array of uint16 since NumPy doesn't know how
// to produce our custom bfloat16 type. Reinterpret the array as bfloat16
// before handing it back to the caller.
TF_ASSIGN_OR_RETURN(py::object bfloat16, Bfloat16Dtype());
py::handle bfloat16(tensorflow::Bfloat16Dtype());
bfloat16.inc_ref();
array = py::reinterpret_steal<py::array>(
PyArray_View(reinterpret_cast<PyArrayObject*>(array.ptr()),
reinterpret_cast<PyArray_Descr*>(bfloat16.release().ptr()),
reinterpret_cast<PyArray_Descr*>(bfloat16.ptr()),
static_cast<PyTypeObject*>(nullptr)));
}
return array;

@ -40,7 +40,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/pjrt/interpreter_device.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/tpu_client.h"
#include "tensorflow/compiler/xla/python/bfloat16.h"
#include "tensorflow/compiler/xla/python/dlpack.h"
#include "tensorflow/compiler/xla/python/jax_jit.h"
#include "tensorflow/compiler/xla/python/ops.h"

@ -59,6 +58,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/python/lib/core/bfloat16.h"
#include "tensorflow/stream_executor/platform.h"

namespace xla {

@ -110,6 +110,8 @@ PYBIND11_MODULE(xla_extension, m) {
throw std::runtime_error("Unable to initialize Numpy API");
}

CHECK(tensorflow::RegisterNumpyBfloat16());

// Types
py::enum_<PrimitiveType>(m, "PrimitiveType")
.value("PRIMITIVE_TYPE_INVALID", PRIMITIVE_TYPE_INVALID)

@ -132,7 +134,8 @@ PYBIND11_MODULE(xla_extension, m) {
.value("OPAQUE_TYPE", OPAQUE_TYPE)
.value("TOKEN", TOKEN);

m.def("bfloat16_dtype", Bfloat16Dtype);
m.def("bfloat16_dtype",
[]() { return py::handle(tensorflow::Bfloat16Dtype()); });

// Must be before PyClient.compile.
BuildXlaCompilerSubmodule(m);

@ -149,7 +152,10 @@ PYBIND11_MODULE(xla_extension, m) {
.def_property_readonly("host_id", &PjRtDevice::host_id,
"Integer ID of this device's host.\n\n"
"This is always 0 except on multi-host platforms.")
.def_property_readonly("platform", &PjRtDevice::platform_name)
.def_property_readonly("platform",
[](const PjRtDevice& device) {
return device.client()->platform_name();
})
.def_property_readonly("device_kind", &PjRtDevice::device_kind)
.def_property_readonly(
"client",

@ -234,7 +240,7 @@ PYBIND11_MODULE(xla_extension, m) {
py::class_<PyClient, std::shared_ptr<PyClient>> py_local_client(m, "Client");
py_local_client.def_property_readonly("platform", &PyClient::platform_name)
.def("device_count", &PyClient::device_count)
.def("local_device_count", &PyClient::local_device_count)
.def("local_device_count", &PyClient::addressable_device_count)
.def("devices", &PyClient::Devices)
.def("local_devices", &PyClient::LocalDevices)
.def("host_id", &PyClient::host_id)

@ -381,10 +387,10 @@ PYBIND11_MODULE(xla_extension, m) {
[](PyExecutable* exec) {
auto span = exec->addressable_device_logical_ids();
// Not on dispatch critical path, so ok to have heap allocation.
std::vector<std::pair<int, int>> addressable_device_logical_ids;
addressable_device_logical_ids.reserve(span.size());
std::vector<std::pair<int, int>> addressable_device_logic_ids;
addressable_device_logic_ids.reserve(span.size());
for (const auto& logical_device_id : span) {
addressable_device_logical_ids.push_back(std::make_pair(
addressable_device_logic_ids.push_back(std::make_pair(
logical_device_id.replica, logical_device_id.partition));
}
})

@ -42,6 +42,7 @@ limitations under the License.
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/threadpool.h"

namespace xla {

@ -158,6 +159,18 @@ class AotCompilationMetadata {
// platform.
class Compiler {
public:
struct CompileOptions {
// If device_allocator is not null, the compiler may use it to allocate temp
// space on the device for use during compilation. For example, the
// compiler may allocate buffers on the device and then run variants of a
// given algorithm over those buffers, to see which variant is fastest. Any
// space allocated will be deallocated before the compilation returns.
se::DeviceMemoryAllocator* device_allocator = nullptr;

// An optional thread pool for parallel compilation.
tensorflow::thread::ThreadPool* thread_pool = nullptr;
};

virtual ~Compiler() {}

// Returns the ID of the platform that this compiler targets.

@ -165,31 +178,24 @@ class Compiler {

// Runs Hlo passes to optimize the given Hlo module, returns the optimized
// module.
//
// If device_allocator is not null, the compiler may use it to allocate temp
// space on the device for use during compilation. For example, the compiler
// may allocate buffers on the device and then run variants of a given
// algorithm over those buffers, to see which variant is fastest. Any space
// allocated should be deallocated before this function returns.
virtual StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
se::DeviceMemoryAllocator* device_allocator) = 0;
const CompileOptions& options) = 0;
StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
se::DeviceMemoryAllocator* device_allocator) {
return RunHloPasses(std::move(module), executor,
CompileOptions{device_allocator});
}

// Runs HLO passes to optimize the given HloModule, perform scheduling and
// buffer assignment, returns the optimized module and the buffer assignments.
// This interface is intentionally narrow.
//
// If device_allocator is not null, the compiler may use it to allocate temp
// space on the device for use during compilation. For example, the compiler
// may allocate buffers on the device and then run variants of a given
// algorithm over those buffers, to see which variant is fastest. Any space
// allocated should be deallocated before this function returns.
virtual StatusOr<
std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module,
se::StreamExecutor* executor,
se::DeviceMemoryAllocator* device_allocator,
bool optimize) {
se::StreamExecutor* executor, bool optimize,
const CompileOptions& options) {
return Unimplemented("This compiler does not support this method");
}

@ -201,24 +207,33 @@ class Compiler {
//
// The compiler may optionally specialize to the individual device
// (not just type of device) indicated by the executor.
//
// device_allocator is optional; see RunHloPasses.
virtual StatusOr<std::unique_ptr<Executable>> RunBackend(
std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
se::DeviceMemoryAllocator* device_allocator) = 0;
const CompileOptions& options) = 0;
StatusOr<std::unique_ptr<Executable>> RunBackend(
std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
se::DeviceMemoryAllocator* device_allocator) {
return RunBackend(std::move(module), executor,
CompileOptions{device_allocator});
}

// Compiles a set of HLO modules that can run in parallel, potentially
// communicating data between the modules, and returns a corresponding
// sequence of executable objects.
//
// device_allocator is optional; see RunHloPasses.
//
// TODO(b/68666782): Remove this method after adding support for multiple
// modules to RunHloPasses and RunBackends.
virtual StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
std::unique_ptr<HloModuleGroup> module_group,
std::vector<std::vector<se::StreamExecutor*>> stream_exec,
se::DeviceMemoryAllocator* device_allocator) = 0;
const CompileOptions& options) = 0;
StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
std::unique_ptr<HloModuleGroup> module_group,
std::vector<std::vector<se::StreamExecutor*>> stream_exec,
se::DeviceMemoryAllocator* device_allocator) {
return Compile(std::move(module_group), stream_exec,
CompileOptions{device_allocator});
}

// Returns the backend configurations that the backend will consider for the
// given HLO. Returns no configurations if the backend does not support
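
For illustration only (not part of this commit): a hedged sketch of how a caller might drive the new CompileOptions-based entry points above. The compiler, module, stream_exec, allocator, and compile_pool variables are assumed to exist; only the CompileOptions fields and the forwarding overloads come from the hunks above.

// Old-style call sites keep compiling: the non-virtual overload wraps the
// allocator into CompileOptions{device_allocator}.
TF_ASSIGN_OR_RETURN(module,
                    compiler->RunHloPasses(std::move(module), stream_exec,
                                           allocator));

// New-style call sites can also hand the backend a thread pool, which
// backends may use for parallel compilation (e.g. the GPU changes later in
// this commit split the LLVM module and compile the pieces concurrently).
Compiler::CompileOptions options;
options.device_allocator = allocator;
options.thread_pool = &compile_pool;  // a tensorflow::thread::ThreadPool
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
                    compiler->RunBackend(std::move(module), stream_exec,
                                         options));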
@ -553,7 +553,7 @@ Status CreateHloProfilingArtifacts(
|
||||
|
||||
StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses(
|
||||
std::unique_ptr<HloModule> module, se::StreamExecutor* /*stream_exec*/,
|
||||
se::DeviceMemoryAllocator* /*device_allocator*/) {
|
||||
const CompileOptions& /*options*/) {
|
||||
std::unique_ptr<llvm::TargetMachine> jit_target_machine =
|
||||
SimpleOrcJIT::InferTargetMachineForJIT(
|
||||
CompilerTargetOptions(module->config()),
|
||||
@ -566,12 +566,13 @@ StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses(
|
||||
|
||||
StatusOr<
|
||||
std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
|
||||
CpuCompiler::RunHloPassesAndBufferAssignement(
|
||||
std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
|
||||
se::DeviceMemoryAllocator* device_allocator, bool optimize) {
|
||||
CpuCompiler::RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module,
|
||||
se::StreamExecutor* executor,
|
||||
bool optimize,
|
||||
const CompileOptions& options) {
|
||||
if (optimize) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
module, RunHloPasses(std::move(module), executor, device_allocator));
|
||||
TF_ASSIGN_OR_RETURN(module,
|
||||
RunHloPasses(std::move(module), executor, options));
|
||||
}
|
||||
|
||||
// Select an order for emitting the HLO instructions for each computation.
|
||||
@ -632,7 +633,7 @@ struct OrcJITPostCompilationHook {
|
||||
|
||||
StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
|
||||
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
|
||||
se::DeviceMemoryAllocator* /*device_allocator*/) {
|
||||
const CompileOptions& options) {
|
||||
VLOG(1) << "Compiling: " << module->name();
|
||||
XLA_SCOPED_LOGGING_TIMER(
|
||||
absl::StrFormat("Compiling [%s] for CPU using JIT", module->name()));
|
||||
|
@ -134,18 +134,17 @@ class CpuCompiler : public LLVMCompiler {
|
||||
|
||||
StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
|
||||
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
|
||||
se::DeviceMemoryAllocator* device_allocator) override;
|
||||
const CompileOptions& options) override;
|
||||
|
||||
StatusOr<
|
||||
std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
|
||||
RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module,
|
||||
se::StreamExecutor* executor,
|
||||
se::DeviceMemoryAllocator* device_allocator,
|
||||
bool optimize) override;
|
||||
se::StreamExecutor* executor, bool optimize,
|
||||
const CompileOptions& options) override;
|
||||
|
||||
StatusOr<std::unique_ptr<Executable>> RunBackend(
|
||||
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
|
||||
se::DeviceMemoryAllocator* device_allocator) override;
|
||||
const CompileOptions& options) override;
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||
CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
|
||||
|
@ -440,7 +440,7 @@ filegroup(
|
||||
name = "nccl_collective_thunk_src",
|
||||
srcs = if_nccl(
|
||||
["nccl_collective_thunk.cc"],
|
||||
["dummy_collective_thunk.cc"],
|
||||
["nccl_collective_thunk_dummy.cc"],
|
||||
),
|
||||
)
|
||||
|
||||
@ -448,7 +448,7 @@ tf_cuda_library(
|
||||
name = "nccl_collective_thunk",
|
||||
srcs = if_cuda_or_rocm(
|
||||
[":nccl_collective_thunk_src"],
|
||||
["dummy_collective_thunk.cc"],
|
||||
["nccl_collective_thunk_dummy.cc"],
|
||||
),
|
||||
hdrs = ["nccl_collective_thunk.h"],
|
||||
deps = [
|
||||
@ -480,7 +480,7 @@ filegroup(
|
||||
name = "nccl_all_gather_thunk_src",
|
||||
srcs = if_nccl(
|
||||
["nccl_all_gather_thunk.cc"],
|
||||
["dummy_all_gather_thunk.cc"],
|
||||
["nccl_all_gather_thunk_dummy.cc"],
|
||||
),
|
||||
)
|
||||
|
||||
@ -488,7 +488,7 @@ tf_cuda_library(
|
||||
name = "nccl_all_gather_thunk",
|
||||
srcs = if_cuda_or_rocm(
|
||||
[":nccl_all_gather_thunk_src"],
|
||||
["dummy_all_gather_thunk.cc"],
|
||||
["nccl_all_gather_thunk_dummy.cc"],
|
||||
),
|
||||
hdrs = ["nccl_all_gather_thunk.h"],
|
||||
deps = [
|
||||
@ -520,7 +520,7 @@ filegroup(
|
||||
name = "nccl_all_reduce_thunk_src",
|
||||
srcs = if_nccl(
|
||||
["nccl_all_reduce_thunk.cc"],
|
||||
["dummy_all_reduce_thunk.cc"],
|
||||
["nccl_all_reduce_thunk_dummy.cc"],
|
||||
),
|
||||
)
|
||||
|
||||
@ -528,7 +528,7 @@ tf_cuda_library(
|
||||
name = "nccl_all_reduce_thunk",
|
||||
srcs = if_cuda_or_rocm(
|
||||
[":nccl_all_reduce_thunk_src"],
|
||||
["dummy_all_reduce_thunk.cc"],
|
||||
["nccl_all_reduce_thunk_dummy.cc"],
|
||||
),
|
||||
hdrs = ["nccl_all_reduce_thunk.h"],
|
||||
deps = [
|
||||
@ -560,7 +560,7 @@ filegroup(
|
||||
name = "nccl_all_to_all_thunk_src",
|
||||
srcs = if_nccl(
|
||||
["nccl_all_to_all_thunk.cc"],
|
||||
["dummy_all_to_all_thunk.cc"],
|
||||
["nccl_all_to_all_thunk_dummy.cc"],
|
||||
),
|
||||
)
|
||||
|
||||
@ -568,7 +568,7 @@ tf_cuda_library(
|
||||
name = "nccl_all_to_all_thunk",
|
||||
srcs = if_cuda_or_rocm(
|
||||
[":nccl_all_to_all_thunk_src"],
|
||||
["dummy_all_to_all_thunk.cc"],
|
||||
["nccl_all_to_all_thunk_dummy.cc"],
|
||||
),
|
||||
hdrs = ["nccl_all_to_all_thunk.h"],
|
||||
deps = [
|
||||
@ -600,7 +600,7 @@ filegroup(
|
||||
name = "nccl_test_utils_src",
|
||||
srcs = if_nccl(
|
||||
["nccl_test_utils.cc"],
|
||||
["dummy_nccl_test_utils.cc"],
|
||||
["nccl_test_utils_dummy.cc"],
|
||||
),
|
||||
)
|
||||
|
||||
@ -608,7 +608,7 @@ tf_cuda_library(
|
||||
name = "nccl_test_utils",
|
||||
srcs = if_cuda_or_rocm(
|
||||
[":nccl_test_utils_src"],
|
||||
["dummy_nccl_test_utils.cc"],
|
||||
["nccl_test_utils_dummy.cc"],
|
||||
),
|
||||
hdrs = ["nccl_test_utils.h"],
|
||||
deps = [
|
||||
@ -1452,7 +1452,11 @@ cc_library(
|
||||
"//tensorflow/stream_executor:stream_executor_headers",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:AsmParser",
|
||||
"@llvm-project//llvm:BitReader",
|
||||
"@llvm-project//llvm:BitWriter",
|
||||
"@llvm-project//llvm:Core",
|
||||
"@llvm-project//llvm:TransformUtils",
|
||||
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
|
||||
"@llvm-project//mlir:IR",
|
||||
],
|
||||
@ -1517,7 +1521,7 @@ cc_library(
|
||||
"//tensorflow/stream_executor:stream_executor_headers",
|
||||
"//tensorflow/stream_executor/cuda:cuda_diagnostics",
|
||||
"//tensorflow/stream_executor/gpu:asm_compiler",
|
||||
]),
|
||||
]) + ["//tensorflow/stream_executor/gpu:gpu_driver_header"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
|
@ -108,12 +108,17 @@ StatusOr<std::pair<std::string, std::vector<uint8>>>
|
||||
AMDGPUCompiler::CompileTargetBinary(const HloModule* module,
|
||||
llvm::Module* llvm_module,
|
||||
GpuVersion gpu_version,
|
||||
se::StreamExecutor* stream_exec) {
|
||||
se::StreamExecutor* stream_exec,
|
||||
bool relocatable) {
|
||||
if (rocdl_dir_.empty()) {
|
||||
// Compute rocdl_dir_ just once and cache it in this member.
|
||||
rocdl_dir_ = GetROCDLDir(module->config());
|
||||
}
|
||||
|
||||
if (relocatable) {
|
||||
return Unimplemented("relocatable target binary is not implemented");
|
||||
}
|
||||
|
||||
std::vector<uint8> hsaco;
|
||||
{
|
||||
XLA_SCOPED_LOGGING_TIMER(
|
||||
|
@ -41,7 +41,8 @@ class AMDGPUCompiler : public GpuCompiler {
|
||||
|
||||
StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
|
||||
const HloModule* hlo_module, llvm::Module* llvm_module,
|
||||
GpuVersion gpu_version, se::StreamExecutor* stream_exec) override;
|
||||
GpuVersion gpu_version, se::StreamExecutor* stream_exec,
|
||||
bool relocatable) override;
|
||||
|
||||
private:
|
||||
// The parent directory of ROCm-Device-Libs IR libraries.
|
||||
|
@ -24,11 +24,15 @@ limitations under the License.
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/numbers.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "llvm/AsmParser/Parser.h"
|
||||
#include "llvm/Bitcode/BitcodeReader.h"
|
||||
#include "llvm/Bitcode/BitcodeWriter.h"
|
||||
#include "llvm/IR/DiagnosticInfo.h"
|
||||
#include "llvm/IR/DiagnosticPrinter.h"
|
||||
#include "llvm/IR/LLVMContext.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/Verifier.h"
|
||||
#include "llvm/Transforms/Utils/SplitModule.h"
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/InitAllDialects.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/xla/protobuf_util.h"
|
||||
@ -114,11 +118,13 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/blocking_counter.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/regexp.h"
|
||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||
#include "tensorflow/core/platform/subprocess.h"
|
||||
#include "tensorflow/core/platform/threadpool.h"
|
||||
#include "tensorflow/core/platform/tracing.h"
|
||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||
#include "tensorflow/core/util/env_var.h"
|
||||
@ -470,14 +476,14 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment(
|
||||
|
||||
StatusOr<std::unique_ptr<HloModule>> GpuCompiler::RunHloPasses(
|
||||
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
|
||||
se::DeviceMemoryAllocator* device_allocator) {
|
||||
const CompileOptions& options) {
|
||||
// We dump the post-optimization HLO in RunBackend so no need to dump it here.
|
||||
XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunHloPasses");
|
||||
tensorflow::profiler::TraceMe activity(
|
||||
[&] { return absl::StrCat("HLO Transforms:", module->name()); },
|
||||
tensorflow::profiler::TraceMeLevel::kInfo);
|
||||
TF_RETURN_IF_ERROR(
|
||||
OptimizeHloModule(module.get(), stream_exec, device_allocator));
|
||||
OptimizeHloModule(module.get(), stream_exec, options.device_allocator));
|
||||
|
||||
TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get()));
|
||||
|
||||
@ -494,10 +500,10 @@ StatusOr<
|
||||
std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
|
||||
GpuCompiler::RunHloPassesAndBufferAssignement(
|
||||
std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* executor,
|
||||
se::DeviceMemoryAllocator* device_allocator, bool optimize) {
|
||||
bool optimize, const CompileOptions& options) {
|
||||
if (optimize) {
|
||||
TF_ASSIGN_OR_RETURN(hlo_module, RunHloPasses(std::move(hlo_module),
|
||||
executor, device_allocator));
|
||||
TF_ASSIGN_OR_RETURN(hlo_module,
|
||||
RunHloPasses(std::move(hlo_module), executor, options));
|
||||
}
|
||||
|
||||
std::unique_ptr<StreamAssignment> stream_assignment =
|
||||
@ -641,24 +647,133 @@ static Status CompileModuleToLlvmIrImpl(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
StatusOr<std::pair<std::string, std::vector<uint8>>>
|
||||
GpuCompiler::CompileToTargetBinary(const HloModule& module,
|
||||
std::unique_ptr<llvm::Module> llvm_module,
|
||||
se::StreamExecutor* stream_exec,
|
||||
const CompileOptions& options) {
|
||||
using BackendCompileResult = std::pair<std::string, std::vector<uint8>>;
|
||||
|
||||
const auto compile_single_module =
|
||||
[this, stream_exec, &module](
|
||||
llvm::Module* llvm_module,
|
||||
bool relocatable) -> StatusOr<BackendCompileResult> {
|
||||
{
|
||||
XLA_SCOPED_LOGGING_TIMER(
|
||||
"GpuCompiler::RunBackend - Running LLVM verifier");
|
||||
|
||||
std::string err;
|
||||
llvm::raw_string_ostream err_stream(err);
|
||||
|
||||
// verifyModule() returns true if the module is broken.
|
||||
TF_RET_CHECK(!llvm::verifyModule(*llvm_module, &err_stream))
|
||||
<< "Invalid LLVM IR before optimizations:\n"
|
||||
<< err_stream.str()
|
||||
<< "\nThis probably indicates a bug in the HLO -> LLVM IR "
|
||||
"lowering. "
|
||||
"Rerun with --xla_dump_to to get the IR and looks for files "
|
||||
"with "
|
||||
"name containing: *"
|
||||
<< FilenameFor(module, "", "") << "*";
|
||||
}
|
||||
GpuVersion gpu_version = GetGpuVersion(stream_exec);
|
||||
return CompileTargetBinary(&module, llvm_module, gpu_version, stream_exec,
|
||||
relocatable);
|
||||
};
|
||||
|
||||
tensorflow::thread::ThreadPool* thread_pool = options.thread_pool;
|
||||
if (!thread_pool) {
|
||||
return compile_single_module(llvm_module.get(), /*relocatable=*/false);
|
||||
}
|
||||
|
||||
// Test whether LinkModules is supported.
|
||||
if (this->LinkModules(stream_exec, {}).status().code() ==
|
||||
tensorflow::error::Code::UNIMPLEMENTED) {
|
||||
return compile_single_module(llvm_module.get(), /*relocatable=*/false);
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<llvm::Module>> llvm_modules;
|
||||
int num_functions = 0;
|
||||
for (llvm::Function& func : llvm_module->functions()) {
|
||||
if (!func.isDeclaration() &&
|
||||
func.getLinkage() == llvm::GlobalValue::LinkageTypes::ExternalLinkage) {
|
||||
num_functions++;
|
||||
}
|
||||
}
|
||||
|
||||
llvm::SplitModule(
|
||||
std::move(llvm_module),
|
||||
std::max<unsigned>(
|
||||
1, std::min<unsigned>(thread_pool->NumThreads(), num_functions)),
|
||||
[&](std::unique_ptr<llvm::Module> module) {
|
||||
llvm_modules.push_back(std::move(module));
|
||||
},
|
||||
/*PreserveLocals=*/true);
|
||||
|
||||
std::vector<StatusOr<BackendCompileResult>> compile_results(
|
||||
llvm_modules.size());
|
||||
tensorflow::BlockingCounter counter(llvm_modules.size());
|
||||
for (int i = 0; i < llvm_modules.size(); i++) {
|
||||
thread_pool->Schedule([&compile_results, compile_single_module, i,
|
||||
&llvm_modules, &counter] {
|
||||
llvm::Module* original_module = llvm_modules[i].get();
|
||||
llvm::LLVMContext context;
|
||||
std::string buffer;
|
||||
llvm::raw_string_ostream error(buffer);
|
||||
llvm::DiagnosticPrinterRawOStream printer(error);
|
||||
auto DiagnosticHandler = [](const llvm::DiagnosticInfo& diag_info,
|
||||
void* Context) {
|
||||
auto printer = static_cast<llvm::DiagnosticPrinterRawOStream*>(Context);
|
||||
diag_info.print(*printer);
|
||||
};
|
||||
context.setDiagnosticHandlerCallBack(DiagnosticHandler, &printer);
|
||||
|
||||
std::unique_ptr<llvm::Module> new_llvm_module;
|
||||
{
|
||||
std::string ir;
|
||||
{
|
||||
llvm::raw_string_ostream os(ir);
|
||||
original_module->print(os, nullptr);
|
||||
}
|
||||
llvm::SMDiagnostic err;
|
||||
new_llvm_module = llvm::parseAssemblyString(ir, err, context);
|
||||
}
|
||||
|
||||
compile_results[i] =
|
||||
compile_single_module(new_llvm_module.get(), /*relocatable=*/true);
|
||||
counter.DecrementCount();
|
||||
});
|
||||
}
|
||||
counter.Wait();
|
||||
|
||||
std::string ptx_snippets;
|
||||
std::vector<std::vector<uint8>> submodule_compile_results;
|
||||
for (auto& maybe_result : compile_results) {
|
||||
TF_ASSIGN_OR_RETURN(auto result, maybe_result);
|
||||
if (result.second.empty()) {
|
||||
continue;
|
||||
}
|
||||
ptx_snippets += result.first;
|
||||
ptx_snippets += "\n";
|
||||
submodule_compile_results.push_back(result.second);
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::vector<uint8> backend_result,
|
||||
this->LinkModules(stream_exec, std::move(submodule_compile_results)));
|
||||
|
||||
return std::make_pair(ptx_snippets, backend_result);
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
|
||||
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
|
||||
se::DeviceMemoryAllocator* device_allocator) {
|
||||
const CompileOptions& options) {
|
||||
XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend");
|
||||
auto slow_compile_alarm = SlowCompilationAlarm();
|
||||
|
||||
TF_RET_CHECK(stream_exec != nullptr);
|
||||
|
||||
llvm::LLVMContext llvm_context;
|
||||
std::string buffer;
|
||||
llvm::raw_string_ostream error(buffer);
|
||||
llvm::DiagnosticPrinterRawOStream printer(error);
|
||||
auto DiagnosticHandler = [](const llvm::DiagnosticInfo& diag_info,
|
||||
void* Context) {
|
||||
auto printer = static_cast<llvm::DiagnosticPrinterRawOStream*>(Context);
|
||||
diag_info.print(*printer);
|
||||
};
|
||||
llvm_context.setDiagnosticHandlerCallBack(DiagnosticHandler, &printer);
|
||||
|
||||
GpuDeviceInfo gpu_device_info;
|
||||
gpu_device_info.threads_per_block_limit =
|
||||
@ -724,34 +839,16 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
|
||||
|
||||
llvm_ir::DumpIrIfEnabled(*module, *llvm_module, /*optimized=*/false);
|
||||
|
||||
{
|
||||
XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - Running LLVM verifier");
|
||||
|
||||
std::string err;
|
||||
llvm::raw_string_ostream err_stream(err);
|
||||
|
||||
// verifyModule() returns true if the module is broken.
|
||||
TF_RET_CHECK(!llvm::verifyModule(*llvm_module, &err_stream))
|
||||
<< "Invalid LLVM IR before optimizations:\n"
|
||||
<< err_stream.str()
|
||||
<< "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. "
|
||||
"Rerun with --xla_dump_to to get the IR and looks for files with "
|
||||
"name containing: *"
|
||||
<< FilenameFor(*module, "", "") << "*";
|
||||
}
|
||||
|
||||
GpuVersion gpu_version = GetGpuVersion(stream_exec);
|
||||
|
||||
using BackendCompileResult = std::pair<std::string, std::vector<uint8>>;
|
||||
TF_ASSIGN_OR_RETURN(BackendCompileResult backend_result,
|
||||
CompileTargetBinary(module.get(), llvm_module.get(),
|
||||
gpu_version, stream_exec));
|
||||
|
||||
CompileToTargetBinary(*module, std::move(llvm_module),
|
||||
stream_exec, options));
|
||||
if (DumpingEnabledForHloModule(*module)) {
|
||||
DumpToFileInDirOrStdout(*module, "", "thunk_schedule",
|
||||
thunk_schedule->ToString());
|
||||
}
|
||||
|
||||
GpuVersion gpu_version = GetGpuVersion(stream_exec);
|
||||
auto* gpu_executable = new GpuExecutable(
|
||||
backend_result.first, backend_result.second, gpu_version,
|
||||
std::move(thunk_schedule), std::move(module),
|
||||
|
@ -53,14 +53,13 @@ class GpuCompiler : public LLVMCompiler {
|
||||
|
||||
StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
|
||||
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
|
||||
se::DeviceMemoryAllocator* device_allocator) override;
|
||||
const CompileOptions& options) override;
|
||||
|
||||
StatusOr<
|
||||
std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
|
||||
RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> hlo_module,
|
||||
se::StreamExecutor* executor,
|
||||
se::DeviceMemoryAllocator* device_allocator,
|
||||
bool optimize) override;
|
||||
se::StreamExecutor* executor, bool optimize,
|
||||
const CompileOptions& options) override;
|
||||
|
||||
Status OptimizeHloModule(HloModule* hlo_module,
|
||||
se::StreamExecutor* stream_exec,
|
||||
@ -84,19 +83,23 @@ class GpuCompiler : public LLVMCompiler {
|
||||
|
||||
virtual StatusOr<std::pair<std::string, std::vector<uint8>>>
|
||||
CompileTargetBinary(const HloModule* hlo_module, llvm::Module* llvm_module,
|
||||
GpuVersion gpu_version,
|
||||
se::StreamExecutor* stream_exec) = 0;
|
||||
GpuVersion gpu_version, se::StreamExecutor* stream_exec,
|
||||
bool relocatable) = 0;
|
||||
|
||||
Status PrepareHloModuleForIrEmitting(HloModule* hlo_module);
|
||||
|
||||
StatusOr<std::unique_ptr<Executable>> RunBackend(
|
||||
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
|
||||
se::DeviceMemoryAllocator* device_allocator) override;
|
||||
const CompileOptions& options) override;
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||
CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
|
||||
AotCompilationOptions const& options) override;
|
||||
|
||||
StatusOr<std::pair<std::string, std::vector<uint8>>> CompileToTargetBinary(
|
||||
const HloModule& module, std::unique_ptr<llvm::Module> llvm_module,
|
||||
se::StreamExecutor* stream_exec, const CompileOptions& options);
|
||||
|
||||
se::Platform::Id PlatformId() const override { return platform_id_; }
|
||||
|
||||
HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override {
|
||||
@ -116,6 +119,12 @@ class GpuCompiler : public LLVMCompiler {
|
||||
}
|
||||
|
||||
private:
|
||||
virtual StatusOr<std::vector<uint8>> LinkModules(
|
||||
se::StreamExecutor* stream_exec,
|
||||
std::vector<std::vector<uint8>> modules) {
|
||||
return Unimplemented("LinkModules is not implemented.");
|
||||
}
|
||||
|
||||
se::Platform::Id platform_id_;
|
||||
|
||||
// The triple that represents our target.
|
||||
|
@ -51,6 +51,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||
#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
|
||||
#include "tensorflow/stream_executor/gpu/asm_compiler.h"
|
||||
#include "tensorflow/stream_executor/gpu/gpu_driver.h"
|
||||
|
||||
namespace xla {
|
||||
namespace gpu {
|
||||
@ -299,7 +300,8 @@ StatusOr<std::pair<std::string, std::vector<uint8>>>
|
||||
NVPTXCompiler::CompileTargetBinary(const HloModule* module,
|
||||
llvm::Module* llvm_module,
|
||||
GpuVersion gpu_version,
|
||||
se::StreamExecutor* stream_exec) {
|
||||
se::StreamExecutor* stream_exec,
|
||||
bool relocatable) {
|
||||
std::pair<int, int> compute_capability =
|
||||
absl::get<std::pair<int, int>>(gpu_version);
|
||||
|
||||
@ -338,7 +340,7 @@ NVPTXCompiler::CompileTargetBinary(const HloModule* module,
|
||||
|
||||
std::vector<uint8> cubin = CompileGpuAsmOrGetCachedResult(
|
||||
stream_exec, ptx, compute_capability.first, compute_capability.second,
|
||||
module->config());
|
||||
module->config(), relocatable);
|
||||
|
||||
return std::pair<std::string, std::vector<uint8>>(std::move(ptx),
|
||||
std::move(cubin));
|
||||
@ -346,7 +348,7 @@ NVPTXCompiler::CompileTargetBinary(const HloModule* module,
|
||||
|
||||
std::vector<uint8> NVPTXCompiler::CompileGpuAsmOrGetCachedResult(
|
||||
se::StreamExecutor* stream_exec, const string& ptx, int cc_major,
|
||||
int cc_minor, const HloModuleConfig& hlo_module_config) {
|
||||
int cc_minor, const HloModuleConfig& hlo_module_config, bool relocatable) {
|
||||
XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::CompileGpuAsmOrGetCachedResult");
|
||||
tensorflow::profiler::TraceMe activity(
|
||||
"PTX->CUBIN", tensorflow::profiler::TraceMeLevel::kInfo);
|
||||
@ -361,7 +363,7 @@ std::vector<uint8> NVPTXCompiler::CompileGpuAsmOrGetCachedResult(
|
||||
tensorflow::mutex_lock lock(mutex_);
|
||||
std::tie(iter, inserted) = compilation_cache_.emplace(
|
||||
std::piecewise_construct,
|
||||
std::forward_as_tuple(ptx, cc_major, cc_minor),
|
||||
std::forward_as_tuple(ptx, cc_major, cc_minor, relocatable),
|
||||
std::forward_as_tuple());
|
||||
cache_ptx = &iter->first.ptx;
|
||||
cache_value = &iter->second;
|
||||
@ -375,9 +377,13 @@ std::vector<uint8> NVPTXCompiler::CompileGpuAsmOrGetCachedResult(
|
||||
if (inserted) {
|
||||
CHECK(!cache_value->compilation_done);
|
||||
if (!ptx.empty()) {
|
||||
StatusOr<std::vector<uint8>> maybe_cubin =
|
||||
se::CompileGpuAsm(stream_exec->device_ordinal(), cache_ptx->c_str(),
|
||||
PtxOptsFromConfig(hlo_module_config));
|
||||
auto ptxas_config = PtxOptsFromConfig(hlo_module_config);
|
||||
if (relocatable) {
|
||||
ptxas_config.extra_flags.push_back("-c");
|
||||
}
|
||||
StatusOr<std::vector<uint8>> maybe_cubin = se::CompileGpuAsm(
|
||||
stream_exec->device_ordinal(), cache_ptx->c_str(), ptxas_config);
|
||||
|
||||
if (maybe_cubin.ok()) {
|
||||
cache_value->cubin_data = std::move(maybe_cubin).ValueOrDie();
|
||||
VLOG(2) << "Compiled PTX size:" << ptx.size()
|
||||
@ -445,5 +451,17 @@ std::vector<uint8> NVPTXCompiler::CompileGpuAsmOrGetCachedResult(
|
||||
return cache_value->cubin_data;
|
||||
}
|
||||
|
||||
StatusOr<std::vector<uint8>> NVPTXCompiler::LinkModules(
|
||||
se::StreamExecutor* stream_exec, std::vector<std::vector<uint8>> modules) {
|
||||
std::vector<stream_executor::CubinOrPTXImage> images;
|
||||
images.reserve(modules.size());
|
||||
for (auto& module : modules) {
|
||||
images.push_back({"", std::move(module)});
|
||||
}
|
||||
return LinkGpuAsm(static_cast<se::gpu::GpuContext*>(
|
||||
stream_exec->implementation()->GpuContextHack()),
|
||||
images);
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
@ -52,9 +52,14 @@ class NVPTXCompiler : public GpuCompiler {
|
||||
|
||||
StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
|
||||
const HloModule* hlo_module, llvm::Module* llvm_module,
|
||||
GpuVersion gpu_version, se::StreamExecutor* stream_exec) override;
|
||||
GpuVersion gpu_version, se::StreamExecutor* stream_exec,
|
||||
bool relocatable) override;
|
||||
|
||||
private:
|
||||
StatusOr<std::vector<uint8>> LinkModules(
|
||||
se::StreamExecutor* stream_exec,
|
||||
std::vector<std::vector<uint8>> modules) override;
|
||||
|
||||
tensorflow::mutex mutex_;
|
||||
|
||||
// When compiling an HLO module, we need to find a path to the nvvm libdevice
|
||||
@ -71,7 +76,7 @@ class NVPTXCompiler : public GpuCompiler {
|
||||
// compiled cubin. If compilation was unsuccessful, returns an empty vector.
|
||||
std::vector<uint8> CompileGpuAsmOrGetCachedResult(
|
||||
se::StreamExecutor* stream_exec, const string& ptx, int cc_major,
|
||||
int cc_minor, const HloModuleConfig& hlo_module_config);
|
||||
int cc_minor, const HloModuleConfig& hlo_module_config, bool relocatable);
|
||||
|
||||
// The compilation_cache_ map is a cache from {ptx string, cc_major, cc_minor}
|
||||
// -> cubin so we don't recompile the same ptx twice. This is important for
|
||||
@ -86,24 +91,32 @@ class NVPTXCompiler : public GpuCompiler {
|
||||
// If compiling the ptx fails, we return an empty cubin, cross our fingers,
|
||||
// and leave compilation up to the driver.
|
||||
struct CompilationCacheKey {
|
||||
CompilationCacheKey(std::string ptx, int cc_major, int cc_minor)
|
||||
: ptx(std::move(ptx)), cc_major(cc_major), cc_minor(cc_minor) {}
|
||||
CompilationCacheKey(std::string ptx, int cc_major, int cc_minor,
|
||||
bool relocatable)
|
||||
: ptx(std::move(ptx)),
|
||||
cc_major(cc_major),
|
||||
cc_minor(cc_minor),
|
||||
relocatable(relocatable) {}
|
||||
string ptx;
|
||||
int cc_major;
|
||||
int cc_minor;
|
||||
bool relocatable;
|
||||
};
|
||||
struct CompilationCacheHash {
|
||||
size_t operator()(const CompilationCacheKey& key) const {
|
||||
return tensorflow::Hash64Combine(
|
||||
tensorflow::Hash64Combine(tensorflow::Hash64(key.ptx), key.cc_major),
|
||||
key.cc_minor);
|
||||
tensorflow::Hash64Combine(
|
||||
tensorflow::Hash64Combine(tensorflow::Hash64(key.ptx),
|
||||
key.cc_major),
|
||||
key.cc_minor),
|
||||
key.relocatable);
|
||||
}
|
||||
};
|
||||
struct CompilationCacheEq {
|
||||
size_t operator()(const CompilationCacheKey& a,
|
||||
const CompilationCacheKey& b) const {
|
||||
return a.cc_major == b.cc_major && a.cc_minor == b.cc_minor &&
|
||||
a.ptx == b.ptx;
|
||||
a.ptx == b.ptx && a.relocatable == b.relocatable;
|
||||
}
|
||||
};
|
||||
struct CompilationCacheValue {
|
||||
|
@ -95,7 +95,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
|
||||
|
||||
StatusOr<std::unique_ptr<HloModule>> InterpreterCompiler::RunHloPasses(
|
||||
std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* /*stream_exec*/,
|
||||
se::DeviceMemoryAllocator* /*device_allocator*/) {
|
||||
const CompileOptions& /*options*/) {
|
||||
VLOG(1) << "Run hlo passes on graph " << hlo_module->name();
|
||||
TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get()));
|
||||
return std::move(hlo_module);
|
||||
@ -103,7 +103,7 @@ StatusOr<std::unique_ptr<HloModule>> InterpreterCompiler::RunHloPasses(
|
||||
|
||||
StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
|
||||
std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
|
||||
se::DeviceMemoryAllocator* /*device_allocator*/) {
|
||||
const CompileOptions& /*options*/) {
|
||||
TF_RET_CHECK(stream_exec != nullptr);
|
||||
|
||||
VLOG(1) << "Run backend " << hlo_module->name();
|
||||
@ -128,7 +128,7 @@ StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
|
||||
StatusOr<std::vector<std::unique_ptr<Executable>>> InterpreterCompiler::Compile(
|
||||
std::unique_ptr<HloModuleGroup> module_group,
|
||||
std::vector<std::vector<se::StreamExecutor*>> stream_exec,
|
||||
se::DeviceMemoryAllocator* device_allocator) {
|
||||
const CompileOptions& options) {
|
||||
if (module_group->empty()) {
|
||||
return std::vector<std::unique_ptr<Executable>>();
|
||||
}
|
||||
@ -141,12 +141,10 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> InterpreterCompiler::Compile(
|
||||
"Unexpected number of StreamExecutor's.");
|
||||
}
|
||||
auto hlo_modules = module_group->ConsumeModules();
|
||||
TF_ASSIGN_OR_RETURN(auto module,
|
||||
RunHloPasses(std::move(hlo_modules[0]), stream_exec[0][0],
|
||||
device_allocator));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto executable,
|
||||
RunBackend(std::move(module), stream_exec[0][0], device_allocator));
|
||||
TF_ASSIGN_OR_RETURN(auto module, RunHloPasses(std::move(hlo_modules[0]),
|
||||
stream_exec[0][0], options));
|
||||
TF_ASSIGN_OR_RETURN(auto executable, RunBackend(std::move(module),
|
||||
stream_exec[0][0], options));
|
||||
std::vector<std::unique_ptr<Executable>> ret;
|
||||
ret.push_back(std::move(executable));
|
||||
return std::move(ret);
|
||||
|
@ -45,14 +45,14 @@ class InterpreterCompiler : public Compiler {
|
||||
|
||||
StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
|
||||
std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
|
||||
se::DeviceMemoryAllocator* device_allocator) override;
|
||||
const CompileOptions& options) override;
|
||||
StatusOr<std::unique_ptr<Executable>> RunBackend(
|
||||
std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
|
||||
se::DeviceMemoryAllocator* device_allocator) override;
|
||||
const CompileOptions& options) override;
|
||||
StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
|
||||
std::unique_ptr<HloModuleGroup> module_group,
|
||||
std::vector<std::vector<se::StreamExecutor*>> stream_exec,
|
||||
se::DeviceMemoryAllocator* device_allocator) override;
|
||||
const CompileOptions& options) override;
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||
CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
|
||||
|
@ -24,7 +24,7 @@ namespace xla {
|
||||
StatusOr<std::vector<std::unique_ptr<Executable>>> LLVMCompiler::Compile(
|
||||
std::unique_ptr<HloModuleGroup> module_group,
|
||||
std::vector<std::vector<se::StreamExecutor*>> stream_execs,
|
||||
se::DeviceMemoryAllocator* device_allocator) {
|
||||
const CompileOptions& options) {
|
||||
// Tensorflow tries to enable the following behaviors in all its threads:
|
||||
//
|
||||
// - Denormals are zero (DAZ): roughly, operations treat denormal floats as
|
||||
@ -48,10 +48,10 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> LLVMCompiler::Compile(
|
||||
|
||||
TF_ASSIGN_OR_RETURN(modules[i],
|
||||
RunHloPasses(std::move(modules[i]), stream_execs[i][0],
|
||||
device_allocator));
|
||||
options.device_allocator));
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
|
||||
RunBackend(std::move(modules[i]), stream_execs[i][0],
|
||||
device_allocator));
|
||||
options.device_allocator));
|
||||
result.push_back(std::move(executable));
|
||||
}
|
||||
|
||||
|
@ -66,13 +66,14 @@ class LLVMCompiler : public Compiler {
|
||||
// std::unique_ptr<HloModule> module,
|
||||
// se::StreamExecutor* stream_exec,
|
||||
// se::DeviceMemoryAllocator* device_allocator)
|
||||
using Compiler::Compile;
|
||||
using Compiler::RunBackend;
|
||||
using Compiler::RunHloPasses;
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
|
||||
std::unique_ptr<HloModuleGroup> module_group,
|
||||
std::vector<std::vector<se::StreamExecutor*>> stream_execs,
|
||||
se::DeviceMemoryAllocator* device_allocator) override;
|
||||
const CompileOptions& options) override;
|
||||
|
||||
protected:
|
||||
ModuleHook user_pre_optimization_hook_;
|
||||
|
@ -190,11 +190,12 @@ LocalService::CompileExecutables(
|
||||
// single partition computations are built using `BuildExecutables`, fix it,
|
||||
// and remove this special case (provided the performance if similar).
|
||||
if (build_options.num_partitions() == 1) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<Executable> executable,
|
||||
BuildExecutable(proto, std::move(module_config), execute_backend_.get(),
|
||||
executor, build_options.device_allocator(),
|
||||
build_options.run_backend_only()));
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
|
||||
BuildExecutable(proto, std::move(module_config),
|
||||
execute_backend_.get(), executor,
|
||||
{build_options.device_allocator(),
|
||||
build_options.compile_thread_pool()},
|
||||
build_options.run_backend_only()));
|
||||
std::vector<std::unique_ptr<Executable>> executables;
|
||||
executables.push_back(std::move(executable));
|
||||
return executables;
|
||||
@ -206,10 +207,12 @@ LocalService::CompileExecutables(
|
||||
std::vector<se::StreamExecutor*> executors(build_options.num_partitions(),
|
||||
executor);
|
||||
|
||||
return BuildExecutables({&proto}, std::move(module_configs),
|
||||
execute_backend_.get(), {executors},
|
||||
build_options.device_allocator(),
|
||||
build_options.run_backend_only());
|
||||
return BuildExecutables(
|
||||
/*module_protos=*/{&proto}, std::move(module_configs),
|
||||
execute_backend_.get(), {executors},
|
||||
Compiler::CompileOptions{build_options.device_allocator(),
|
||||
build_options.compile_thread_pool()},
|
||||
build_options.run_backend_only());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -28,28 +28,24 @@ bool IsUnimplemented(StatusOr<T>& result) {
|
||||
|
||||
StatusOr<std::unique_ptr<HloModule>> FailoverCompiler::RunHloPasses(
|
||||
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
|
||||
se::DeviceMemoryAllocator* device_allocator) {
|
||||
auto result =
|
||||
primary_->RunHloPasses(module->Clone(), stream_exec, device_allocator);
|
||||
const CompileOptions& options) {
|
||||
auto result = primary_->RunHloPasses(module->Clone(), stream_exec, options);
|
||||
if (IsUnimplemented(result)) {
|
||||
VLOG(2) << "RunHloPasses resulted in " << result.status()
|
||||
<< ", falling back to secondary backend";
|
||||
return secondary_->RunHloPasses(std::move(module), stream_exec,
|
||||
device_allocator);
|
||||
return secondary_->RunHloPasses(std::move(module), stream_exec, options);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<Executable>> FailoverCompiler::RunBackend(
|
||||
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
|
||||
se::DeviceMemoryAllocator* device_allocator) {
|
||||
auto result =
|
||||
primary_->RunBackend(module->Clone(), stream_exec, device_allocator);
|
||||
const CompileOptions& options) {
|
||||
auto result = primary_->RunBackend(module->Clone(), stream_exec, options);
|
||||
if (IsUnimplemented(result)) {
|
||||
VLOG(2) << "RunBackend resulted in " << result.status()
|
||||
<< ", falling back to secondary backend";
|
||||
return secondary_->RunBackend(std::move(module), stream_exec,
|
||||
device_allocator);
|
||||
return secondary_->RunBackend(std::move(module), stream_exec, options);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@ -57,7 +53,7 @@ StatusOr<std::unique_ptr<Executable>> FailoverCompiler::RunBackend(
|
||||
StatusOr<std::vector<std::unique_ptr<Executable>>> FailoverCompiler::Compile(
|
||||
std::unique_ptr<HloModuleGroup> module_group,
|
||||
std::vector<std::vector<se::StreamExecutor*>> stream_execs,
|
||||
se::DeviceMemoryAllocator* device_allocator) {
|
||||
const CompileOptions& options) {
|
||||
std::vector<std::unique_ptr<Executable>> result;
|
||||
std::vector<std::unique_ptr<HloModule>> modules =
|
||||
module_group->ConsumeModules();
|
||||
@ -67,17 +63,15 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> FailoverCompiler::Compile(
|
||||
return Unimplemented(
|
||||
"Model partitioning not implemented for the failover compiler!");
|
||||
}
|
||||
auto executable = [stream_execs, device_allocator, i,
|
||||
auto executable = [stream_execs, &options, i,
|
||||
this](std::unique_ptr<HloModule> module)
|
||||
-> StatusOr<std::unique_ptr<Executable>> {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto processed_module,
|
||||
primary_->RunHloPasses(std::move(module), stream_execs[i][0],
|
||||
device_allocator));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto result,
|
||||
primary_->RunBackend(std::move(processed_module), stream_execs[i][0],
|
||||
device_allocator));
|
||||
TF_ASSIGN_OR_RETURN(auto processed_module,
|
||||
primary_->RunHloPasses(std::move(module),
|
||||
stream_execs[i][0], options));
|
||||
TF_ASSIGN_OR_RETURN(auto result,
|
||||
primary_->RunBackend(std::move(processed_module),
|
||||
stream_execs[i][0], options));
|
||||
return result;
|
||||
}(modules[i]->Clone());
|
||||
|
||||
@ -85,13 +79,11 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> FailoverCompiler::Compile(
|
||||
VLOG(2) << "Compile resulted in " << executable.status()
|
||||
<< ", falling back to secondary backend";
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
modules[i],
|
||||
secondary_->RunHloPasses(std::move(modules[i]), stream_execs[i][0],
|
||||
device_allocator));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
executable,
|
||||
secondary_->RunBackend(std::move(modules[i]), stream_execs[i][0],
|
||||
device_allocator));
|
||||
modules[i], secondary_->RunHloPasses(std::move(modules[i]),
|
||||
stream_execs[i][0], options));
|
||||
TF_ASSIGN_OR_RETURN(executable,
|
||||
secondary_->RunBackend(std::move(modules[i]),
|
||||
stream_execs[i][0], options));
|
||||
}
|
||||
|
||||
if (!executable.ok()) {
|
||||
|
@ -51,16 +51,16 @@ class FailoverCompiler final : public Compiler {
|
||||
|
||||
StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
|
||||
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
|
||||
se::DeviceMemoryAllocator* device_allocator) override;
|
||||
const CompileOptions& options) override;
|
||||
|
||||
StatusOr<std::unique_ptr<Executable>> RunBackend(
|
||||
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
|
||||
se::DeviceMemoryAllocator* device_allocator) override;
|
||||
const CompileOptions& options) override;
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
|
||||
std::unique_ptr<HloModuleGroup> module_group,
|
||||
std::vector<std::vector<se::StreamExecutor*>> stream_execs,
|
||||
se::DeviceMemoryAllocator* device_allocator) override;
|
||||
const CompileOptions& options) override;
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||
CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
|
||||
|
@ -87,16 +87,16 @@ class MlirCompilerImpl : public MlirCompiler {
 public:
  StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
      std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
      se::DeviceMemoryAllocator* device_allocator) override;
      const CompileOptions& options) override;

  StatusOr<std::unique_ptr<Executable>> RunBackend(
      std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
      se::DeviceMemoryAllocator* device_allocator) override;
      const CompileOptions& options) override;

  StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
      std::unique_ptr<HloModuleGroup> module_group,
      std::vector<std::vector<se::StreamExecutor*>> stream_execs,
      se::DeviceMemoryAllocator* device_allocator) override;
      const CompileOptions& options) override;

  StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
  CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
@ -155,12 +155,12 @@ std::string GetLibdeviceDir(const HloModuleConfig& hlo_module_config) {

StatusOr<std::unique_ptr<HloModule>> MlirCompilerImpl::RunHloPasses(
    std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
    se::DeviceMemoryAllocator* device_allocator) {
    const CompileOptions& options) {
  // Until we find a reason to do something different, run the same passes
  // that the normal GPU backend runs.
  gpu::NVPTXCompiler xla_compiler;
  TF_RETURN_IF_ERROR(xla_compiler.OptimizeHloModule(module.get(), stream_exec,
                                                    device_allocator));
                                                    options.device_allocator));
  TF_RETURN_IF_ERROR(xla_compiler.PrepareHloModuleForIrEmitting(module.get()));

  return std::move(module);
@ -454,7 +454,7 @@ StatusOr<std::unique_ptr<gpu::KernelThunk>> TransformKernelToXlaThunk(

StatusOr<std::unique_ptr<Executable>> MlirCompilerImpl::RunBackend(
    std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
    se::DeviceMemoryAllocator* device_allocator) {
    const CompileOptions& options) {
  // Determine the HLO schedule, which is an ordering of HLO instructions. This
  // is used by buffer assignment to enable buffer reuse, and the same ordering
  // must also be used to determine the thunk launch schedule.
@ -595,7 +595,7 @@ StatusOr<std::unique_ptr<Executable>> MlirCompilerImpl::RunBackend(
StatusOr<std::vector<std::unique_ptr<Executable>>> MlirCompilerImpl::Compile(
    std::unique_ptr<HloModuleGroup> module_group,
    std::vector<std::vector<se::StreamExecutor*>> stream_execs,
    se::DeviceMemoryAllocator* device_allocator) {
    const CompileOptions& options) {
  return Unimplemented("Not yet implemented in MLIR compiler");
}
@ -357,7 +357,7 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
|
||||
const std::vector<const HloModuleProto*>& module_protos,
|
||||
std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
|
||||
Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
|
||||
se::DeviceMemoryAllocator* device_allocator, bool run_backend_only) {
|
||||
const Compiler::CompileOptions& options, bool run_backend_only) {
|
||||
VLOG(1) << StrFormat("BuildExecutable on service %p", this);
|
||||
|
||||
// Dump computation proto state if flag is set.
|
||||
@ -387,17 +387,15 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
|
||||
|
||||
std::vector<std::unique_ptr<Executable>> executables;
|
||||
if (!run_backend_only) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
executables,
|
||||
backend->compiler()->Compile(std::move(module_group),
|
||||
std::move(executors), device_allocator));
|
||||
TF_ASSIGN_OR_RETURN(executables, backend->compiler()->Compile(
|
||||
std::move(module_group),
|
||||
std::move(executors), options));
|
||||
} else {
|
||||
auto modules = module_group->ConsumeModules();
|
||||
for (std::unique_ptr<HloModule>& module : modules) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<Executable> executable,
|
||||
backend->compiler()->RunBackend(std::move(module), executors[0][0],
|
||||
device_allocator));
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
|
||||
backend->compiler()->RunBackend(
|
||||
std::move(module), executors[0][0], options));
|
||||
executables.push_back(std::move(executable));
|
||||
}
|
||||
}
|
||||
@ -710,7 +708,7 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
|
||||
TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<Executable>> executables,
|
||||
BuildExecutables(module_protos, std::move(module_configs),
|
||||
execute_backend_.get(), all_executors,
|
||||
/*device_allocator=*/nullptr));
|
||||
{/*device_allocator=*/nullptr}));
|
||||
std::vector<Executable*> executable_ptrs;
|
||||
executable_ptrs.reserve(executables.size());
|
||||
for (const auto& executable : executables) {
|
||||
@ -810,7 +808,7 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
|
||||
StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
|
||||
const HloModuleProto& module_proto,
|
||||
std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
|
||||
se::StreamExecutor* executor, se::DeviceMemoryAllocator* device_allocator,
|
||||
se::StreamExecutor* executor, const Compiler::CompileOptions& options,
|
||||
bool run_backend_only) {
|
||||
VLOG(1) << StrFormat(
|
||||
"BuildExecutable on service %p with serialized module proto: %s", this,
|
||||
@ -822,14 +820,13 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
|
||||
DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);
|
||||
|
||||
if (!run_backend_only) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
module, backend->compiler()->RunHloPasses(std::move(module), executor,
|
||||
device_allocator));
|
||||
TF_ASSIGN_OR_RETURN(module, backend->compiler()->RunHloPasses(
|
||||
std::move(module), executor, options));
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
|
||||
backend->compiler()->RunBackend(
|
||||
std::move(module), executor, device_allocator));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<Executable> executable,
|
||||
backend->compiler()->RunBackend(std::move(module), executor, options));
|
||||
|
||||
const auto& debug_opts = module_config->debug_options();
|
||||
if (DumpingEnabledForHloModule(module_proto.name(), debug_opts) &&
|
||||
@ -875,7 +872,7 @@ Status Service::Compile(const CompileRequest* arg, CompileResponse* result) {
|
||||
BuildExecutable(arg->computation(), std::move(module_config),
|
||||
execute_backend_.get(),
|
||||
execute_backend_->default_stream_executor(),
|
||||
/*device_allocator=*/nullptr));
|
||||
{/*device_allocator=*/nullptr}));
|
||||
|
||||
*result->mutable_handle() = compilation_cache_.Insert(std::move(executable));
|
||||
|
||||
|
@ -235,8 +235,7 @@ class Service : public ServiceInterface {
|
||||
StatusOr<std::unique_ptr<Executable>> BuildExecutable(
|
||||
const HloModuleProto& module_proto,
|
||||
std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
|
||||
se::StreamExecutor* executor,
|
||||
se::DeviceMemoryAllocator* device_allocator = nullptr,
|
||||
se::StreamExecutor* executor, const Compiler::CompileOptions& options,
|
||||
bool run_backend_only = false);
|
||||
|
||||
// Same as BuildExecutable() above, but builds a list of Executables for the
|
||||
@ -245,8 +244,7 @@ class Service : public ServiceInterface {
|
||||
const std::vector<const HloModuleProto*>& module_protos,
|
||||
std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
|
||||
Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
|
||||
se::DeviceMemoryAllocator* device_allocator,
|
||||
bool run_backend_only = false);
|
||||
const Compiler::CompileOptions& options, bool run_backend_only = false);
|
||||
|
||||
// Runs the given executable with the given arguments and register the result
|
||||
// in the allocation tracker. The handle of the result from the tracker is
|
||||
|
@ -102,27 +102,8 @@ bool IsShardingMoreSpecific(const HloSharding& lhs, const HloSharding& rhs) {
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a sharding where each tuple element is chosen as the more specific
|
||||
// one of the corresponding elements in a and b. Requires a and b to have the
|
||||
// same tuple nesting.
|
||||
HloSharding MergeForMoreSpecificSharding(const HloSharding& a,
|
||||
const HloSharding& b) {
|
||||
if (a.IsTuple()) {
|
||||
HloSharding result = a;
|
||||
CHECK(b.IsTuple());
|
||||
CHECK_EQ(a.tuple_elements().size(), b.tuple_elements().size());
|
||||
for (int64 i = 0; i < result.tuple_elements().size(); ++i) {
|
||||
result.tuple_elements()[i] = MergeForMoreSpecificSharding(
|
||||
a.tuple_elements()[i], b.tuple_elements()[i]);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
return IsShardingMoreSpecific(a, b) ? a : b;
|
||||
}
|
||||
|
||||
// Tries to refine `to_merge` by combining with `old`. Returns if the final
|
||||
// `to_merge` is more specific than `old`. May combine partial sharding in
|
||||
// addition to MergeForMoreSpecificSharding().
|
||||
// `to_merge` is more specific than `old`.
|
||||
bool MergeSharding(const HloSharding& old, HloSharding* to_merge,
|
||||
bool may_combine_partial_sharding) {
|
||||
if (old.IsTuple()) {
|
||||
@ -1093,8 +1074,8 @@ bool InferShardingFromOperands(HloInstruction* instruction,
|
||||
}
|
||||
auto sharding = instruction->operand(0)->sharding();
|
||||
if (instruction->has_sharding()) {
|
||||
sharding =
|
||||
MergeForMoreSpecificSharding(sharding, instruction->sharding());
|
||||
MergeSharding(instruction->sharding(), &sharding,
|
||||
may_combine_partial_sharding);
|
||||
}
|
||||
return MaybeImproveInstructionSharding(std::move(sharding), instruction,
|
||||
may_combine_partial_sharding);
|
||||
@ -1320,6 +1301,12 @@ absl::optional<HloSharding> GetShardingFromUser(
|
||||
return hlo_sharding_util::ReshapeSharding(
|
||||
user.shape(), instruction.shape(), user.sharding());
|
||||
}
|
||||
case HloOpcode::kPad: {
|
||||
if (&instruction != user.operand(0)) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
return user.sharding();
|
||||
}
|
||||
case HloOpcode::kSlice: {
|
||||
return user.sharding();
|
||||
}
|
||||
@ -1673,8 +1660,10 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
|
||||
// If instruction is a while, or the root or a parameter of a while body,
|
||||
// then propagate its sharding to the while instruction, to its body root,
|
||||
// and to its condition parameter.
|
||||
std::function<void(HloInstruction*)> maybe_computation_propagation =
|
||||
[&](HloInstruction* instruction) {
|
||||
std::function<void(HloInstruction*, absl::flat_hash_set<HloInstruction*>*)>
|
||||
maybe_computation_propagation = [&](HloInstruction* instruction,
|
||||
absl::flat_hash_set<HloInstruction*>*
|
||||
changed) {
|
||||
auto propagate_to_instruction = [&](HloInstruction* search_inst) {
|
||||
auto related_instructions = get_related_instructions(search_inst);
|
||||
if (absl::c_count(related_instructions, instruction)) {
|
||||
@ -1683,7 +1672,8 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
|
||||
inst->sharding() != instruction->sharding()) {
|
||||
VLOG(2) << "Add computation sharding: " << inst->name();
|
||||
inst->set_sharding(instruction->sharding());
|
||||
maybe_computation_propagation(inst);
|
||||
changed->insert(inst);
|
||||
maybe_computation_propagation(inst, changed);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1785,6 +1775,14 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
|
||||
for (const HloInstruction* instruction : instructions) {
|
||||
already_sharded_counter += (instruction->has_sharding() ? 1 : 0);
|
||||
}
|
||||
auto clear_cache = [&](HloInstruction* hlo) {
|
||||
for (auto operand : hlo->operands()) {
|
||||
already_inferred_from_users.erase(operand);
|
||||
}
|
||||
for (auto user : hlo->users()) {
|
||||
already_inferred_from_operands.erase(user);
|
||||
}
|
||||
};
|
||||
// First iterate the HLO graph in post order taking shardings from
|
||||
// operands.
|
||||
for (HloInstruction* instruction : instructions) {
|
||||
@ -1799,12 +1797,11 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
|
||||
any_changed = true;
|
||||
VLOG(2) << "Add sharding (forward-pass): "
|
||||
<< instruction->ToString();
|
||||
maybe_computation_propagation(instruction);
|
||||
for (auto operand : instruction->operands()) {
|
||||
already_inferred_from_users.erase(operand);
|
||||
}
|
||||
for (auto user : instruction->users()) {
|
||||
already_inferred_from_operands.erase(user);
|
||||
absl::flat_hash_set<HloInstruction*> changed_in_comp_prop;
|
||||
maybe_computation_propagation(instruction, &changed_in_comp_prop);
|
||||
clear_cache(instruction);
|
||||
for (auto hlo : changed_in_comp_prop) {
|
||||
clear_cache(hlo);
|
||||
}
|
||||
changed_last_iter = true;
|
||||
}
|
||||
@ -1823,12 +1820,11 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
|
||||
++inferred_from_user_counter;
|
||||
any_changed = true;
|
||||
VLOG(2) << "Add sharding (backward-pass): " << (*it)->ToString();
|
||||
maybe_computation_propagation(*it);
|
||||
for (auto operand : (*it)->operands()) {
|
||||
already_inferred_from_users.erase(operand);
|
||||
}
|
||||
for (auto user : (*it)->users()) {
|
||||
already_inferred_from_operands.erase(user);
|
||||
absl::flat_hash_set<HloInstruction*> changed_in_comp_prop;
|
||||
maybe_computation_propagation(*it, &changed_in_comp_prop);
|
||||
clear_cache(*it);
|
||||
for (auto hlo : changed_in_comp_prop) {
|
||||
clear_cache(hlo);
|
||||
}
|
||||
changed_last_iter = true;
|
||||
}
|
||||
|
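The sharding-propagation hunks above thread a `changed` set through `maybe_computation_propagation` and funnel cache invalidation through a single `clear_cache` lambda: whenever an instruction's sharding changes, its operands and users are dropped from the "already inferred" caches so the fixed-point loop revisits them. A toy sketch of that pattern follows; the `Node` type and the trivial "copy sharding from a sharded operand" rule are assumptions for illustration, not XLA's real data structures or inference logic:

#include <cstdio>
#include <unordered_set>
#include <vector>

// Toy node standing in for HloInstruction (illustration only).
struct Node {
  int id;
  int sharding = -1;  // -1 means "no sharding yet".
  std::vector<Node*> operands;
  std::vector<Node*> users;
};

// Mirrors the clear_cache lambda above: once a node's sharding changes, its
// neighbours must be revisited, so drop them from the caches.
void ClearCache(Node* n, std::unordered_set<Node*>* inferred_from_operands,
                std::unordered_set<Node*>* inferred_from_users) {
  for (Node* operand : n->operands) inferred_from_users->erase(operand);
  for (Node* user : n->users) inferred_from_operands->erase(user);
}

int main() {
  Node a{0}, b{1}, c{2};
  a.users = {&b};
  b.operands = {&a};
  b.users = {&c};
  c.operands = {&b};
  a.sharding = 7;  // Seed a sharding on the first node.

  std::unordered_set<Node*> inferred_from_operands, inferred_from_users;
  bool changed_last_iter = true;
  while (changed_last_iter) {
    changed_last_iter = false;
    for (Node* n : {&a, &b, &c}) {
      if (inferred_from_operands.count(n)) continue;
      inferred_from_operands.insert(n);
      // "Infer from operands": copy sharding from any sharded operand.
      for (Node* operand : n->operands) {
        if (operand->sharding != -1 && n->sharding == -1) {
          n->sharding = operand->sharding;
          // Changed nodes drive cache invalidation, like changed_in_comp_prop
          // in the real pass.
          ClearCache(n, &inferred_from_operands, &inferred_from_users);
          changed_last_iter = true;
        }
      }
    }
  }
  std::printf("a=%d b=%d c=%d\n", a.sharding, b.sharding, c.sharding);
}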
@ -514,6 +514,26 @@ ENTRY %pad {
|
||||
op::Sharding("{devices=[2,2]0,1,2,3}"));
|
||||
}
|
||||
|
||||
TEST_F(ShardingPropagationTest, PadBackwardPass) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule module
|
||||
ENTRY %pad {
|
||||
%input = f32[11,17]{1,0} parameter(0)
|
||||
%copy = f32[11,17]{1,0} copy(%input)
|
||||
%pad_value = f32[] parameter(1)
|
||||
%pad = f32[27,51]{1,0} pad(%copy, %pad_value), padding=2_4_1x1_1_2,
|
||||
sharding={devices=[2,2]0,1,2,3}
|
||||
ROOT %result = f32[27,51]{1,0} copy(%pad)
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||
ShardingPropagation().Run(module.get()));
|
||||
EXPECT_TRUE(changed);
|
||||
EXPECT_THAT(FindInstruction(module.get(), "copy"),
|
||||
op::Sharding("{devices=[2,2]0,1,2,3}"));
|
||||
}
|
||||
|
||||
TEST_F(ShardingPropagationTest, PartialReplicatedPadForwardPass) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule module
|
||||
@ -856,40 +876,41 @@ TEST_F(ShardingPropagationTest, While) {
|
||||
HloModule module
|
||||
|
||||
%cond {
|
||||
%vars.cond = (u32[], f32[10]{0}) parameter(0)
|
||||
%count.cond = u32[] get-tuple-element((u32[], f32[10]{0}) %vars.cond), index=0
|
||||
%vars.cond = (u32[], f32[10,10]) parameter(0)
|
||||
%count.cond = u32[] get-tuple-element((u32[], f32[10,10]) %vars.cond), index=0
|
||||
%limit = u32[] constant(10)
|
||||
ROOT %lt = pred[] compare(u32[] %count.cond, u32[] %limit), direction=LT
|
||||
}
|
||||
|
||||
%body {
|
||||
%vars = (u32[], f32[10]{0}) parameter(0)
|
||||
%vars = (u32[], f32[10,10]) parameter(0)
|
||||
%count = u32[] get-tuple-element(%vars), index=0
|
||||
%acc = f32[10]{0} get-tuple-element((u32[], f32[10]{0}) %vars), index=1
|
||||
%acc = f32[10,10] get-tuple-element((u32[], f32[10,10]) %vars), index=1
|
||||
|
||||
%one = u32[] constant(1)
|
||||
%count.1 = u32[] add(u32[] %count, u32[] %one), sharding={replicated}
|
||||
%acc.1 = f32[10]{0} add(f32[10]{0} %acc, f32[10]{0} %acc)
|
||||
ROOT %tuple = (u32[], f32[10]{0}) tuple(u32[] %count.1, f32[10]{0} %acc.1)
|
||||
%acc.1 = f32[10,10] add(f32[10,10] %acc, f32[10,10] %acc)
|
||||
ROOT %tuple = (u32[], f32[10,10]) tuple(u32[] %count.1, f32[10,10] %acc.1)
|
||||
}
|
||||
|
||||
ENTRY %entry {
|
||||
%p0 = f32[10]{0} parameter(0)
|
||||
%p0.copy = f32[10]{0} copy(f32[10]{0} %p0)
|
||||
%p1 = f32[10]{0} parameter(1)
|
||||
%p0 = f32[10,10] parameter(0)
|
||||
%p0.copy = f32[10,10] copy(f32[10,10] %p0)
|
||||
%p1 = f32[10,10] parameter(1)
|
||||
%zero = u32[] constant(0)
|
||||
%init = (u32[], f32[10]{0}) tuple(u32[] %zero, f32[10]{0} %p0.copy)
|
||||
%while = (u32[], f32[10]{0}) while((u32[], f32[10]{0}) %init),
|
||||
%init = (u32[], f32[10,10]) tuple(u32[] %zero, f32[10,10] %p0.copy)
|
||||
%while = (u32[], f32[10,10]) while((u32[], f32[10,10]) %init),
|
||||
body=%body, condition=%cond
|
||||
%res = f32[10]{0} get-tuple-element((u32[], f32[10]{0}) %while), index=1
|
||||
%prev = f32[10]{0} get-tuple-element((u32[], f32[10]{0}) %init), index=1
|
||||
%res.1 = f32[10]{0} multiply(f32[10]{0} %res, %prev)
|
||||
ROOT %res_tuple = (f32[10]{0}) tuple(f32[10]{0} %res.1)
|
||||
%res = f32[10,10] get-tuple-element((u32[], f32[10,10]) %while), index=1
|
||||
%prev = f32[10,10] get-tuple-element((u32[], f32[10,10]) %init), index=1
|
||||
%res.1 = f32[10,10] multiply(f32[10,10] %res, %prev)
|
||||
ROOT %res_tuple = (f32[10,10]) tuple(f32[10,10] %res.1)
|
||||
})";
|
||||
|
||||
auto while_is_sharded = [this](HloModule* module,
|
||||
const HloSharding& sharding) {
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardingPropagation().Run(module));
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||
ShardingPropagation(/*is_spmd=*/true).Run(module));
|
||||
EXPECT_TRUE(changed);
|
||||
auto while_instr = FindInstruction(module, "while");
|
||||
EXPECT_NE(nullptr, while_instr);
|
||||
@ -911,7 +932,7 @@ ENTRY %entry {
|
||||
auto body_root = FindInstruction(module.get(), "tuple");
|
||||
EXPECT_NE(nullptr, body_root);
|
||||
auto sharding =
|
||||
ParseSharding("{{replicated}, {devices=[2]0,1}}").ConsumeValueOrDie();
|
||||
ParseSharding("{{replicated}, {devices=[2,1]0,1}}").ConsumeValueOrDie();
|
||||
body_root->set_sharding(sharding);
|
||||
while_is_sharded(module.get(), sharding);
|
||||
}
|
||||
@ -921,11 +942,30 @@ ENTRY %entry {
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
auto acc_1 = FindInstruction(module.get(), "acc.1");
|
||||
EXPECT_NE(nullptr, acc_1);
|
||||
acc_1->set_sharding(ParseSharding("{devices=[2]0,1}").ConsumeValueOrDie());
|
||||
acc_1->set_sharding(
|
||||
ParseSharding("{devices=[2,1]0,1}").ConsumeValueOrDie());
|
||||
|
||||
while_is_sharded(
|
||||
module.get(),
|
||||
ParseSharding("{{replicated}, {devices=[2]0,1}}").ConsumeValueOrDie());
|
||||
while_is_sharded(module.get(),
|
||||
ParseSharding("{{replicated}, {devices=[2,1]0,1}}")
|
||||
.ConsumeValueOrDie());
|
||||
}
|
||||
{
|
||||
// Merge partial sharding from operand and body.
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
auto acc_1 = FindInstruction(module.get(), "acc.1");
|
||||
EXPECT_NE(nullptr, acc_1);
|
||||
acc_1->set_sharding(
|
||||
ParseSharding("{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}")
|
||||
.ConsumeValueOrDie());
|
||||
auto p0 = FindInstruction(module.get(), "p0");
|
||||
p0->set_sharding(
|
||||
ParseSharding("{devices=[1,2,2]0,2,1,3 last_tile_dim_replicate}")
|
||||
.ConsumeValueOrDie());
|
||||
|
||||
while_is_sharded(module.get(),
|
||||
ParseSharding("{{replicated}, {devices=[2,2]0,1,2,3}}")
|
||||
.ConsumeValueOrDie());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -182,6 +182,10 @@ class ConvolutionVisitor {
|
||||
return permute_dims[id];
|
||||
}
|
||||
|
||||
int64 ReverseDimLookUp(absl::Span<const int64> permute_dims, int64 id) {
|
||||
return std::distance(permute_dims.begin(), absl::c_find(permute_dims, id));
|
||||
}
|
||||
|
||||
HloInstruction* DoesConvolutionFeedReduceWindowOrSelectAndScatter(
|
||||
HloInstruction* instr, int64 depth);
|
||||
|
||||
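`ReverseDimLookUp`, added in the hunk above, is the inverse of the existing `DimLookUp`: one indexes into the permutation, the other finds where a dimension landed, and later hunks in this diff switch the window-rewriting loop to the reverse lookup. A standalone sketch (std::vector and std::find instead of absl::Span and absl::c_find, purely so it compiles on its own):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Same idea as DimLookUp/ReverseDimLookUp above, over std::vector.
int64_t DimLookUp(const std::vector<int64_t>& permute_dims, int64_t id) {
  return permute_dims[id];
}

int64_t ReverseDimLookUp(const std::vector<int64_t>& permute_dims, int64_t id) {
  // Position of `id` inside the permutation, i.e. the inverse mapping.
  return std::distance(permute_dims.begin(),
                       std::find(permute_dims.begin(), permute_dims.end(), id));
}

int main() {
  // A permutation produced by a transpose: new dim i holds old dim perm[i].
  const std::vector<int64_t> perm = {2, 0, 3, 1};
  for (int64_t i = 0; i < 4; ++i) {
    // The reverse lookup undoes the forward lookup for every dimension.
    assert(ReverseDimLookUp(perm, DimLookUp(perm, i)) == i);
  }
  return 0;
}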
@ -215,9 +219,10 @@ class ConvolutionVisitor {
|
||||
// Limit on batch size to apply this technique on.
|
||||
int64 limit_on_batch_size_;
|
||||
|
||||
// We choose the new batch size to be a constant so that space-to-batch
|
||||
// propagation through several convolutional layers is consistent.
|
||||
static constexpr int64 kNewBatchSize = 8;
|
||||
// We choose the new batch size to be kNumSplits times that of the old batch
|
||||
// so that space-to-batch propagation through several convolutional layers is
|
||||
// consistent.
|
||||
static constexpr int64 kNumSplits = 8;
|
||||
|
||||
// Depth for searching reduce window
|
||||
static constexpr int64 kReduceWindowSearchDepth = 10;
|
||||
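The `kNumSplits` comment above fixes the split factor (new batch = kNumSplits × old batch) instead of targeting a fixed new batch size. A small worked example of the sizing arithmetic that later hunks apply (`pad_size = spatial_split_size * kNumSplits - spatial_size`, halo check against `CeilOfRatio(spatial_size, kNumSplits)`); the concrete numbers and the starting guess for the split size are invented for illustration, since the real pass derives its guess from the convolution output size and base dilation:

#include <cstdint>
#include <cstdio>

int64_t CeilOfRatio(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  const int64_t kNumSplits = 8;      // fixed split factor from the class above
  const int64_t old_batch = 1;
  const int64_t spatial_size = 255;  // input spatial dimension (made up)
  const int64_t stride = 2;
  const int64_t halo_size = 3;

  // The new batch is always kNumSplits times the old batch, so stacked
  // convolutions agree on the layout.
  const int64_t new_batch = old_batch * kNumSplits;

  // Halo padding can only borrow from the immediately following split.
  const bool splittable = halo_size <= CeilOfRatio(spatial_size, kNumSplits);

  // Grow the per-split spatial size in stride steps until the splits cover the
  // original spatial dimension, then pad out the remainder.
  int64_t spatial_split_size = CeilOfRatio(spatial_size, kNumSplits);
  while (spatial_split_size * kNumSplits - spatial_size < 0) {
    spatial_split_size += stride;
  }
  const int64_t pad_size = spatial_split_size * kNumSplits - spatial_size;

  std::printf("new_batch=%ld splittable=%d split_size=%ld pad=%ld\n",
              static_cast<long>(new_batch), splittable ? 1 : 0,
              static_cast<long>(spatial_split_size),
              static_cast<long>(pad_size));
}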
@ -301,17 +306,12 @@ bool ConvolutionVisitor::IsConvSuitableForSpaceToBatch(
|
||||
if (old_batch_size > limit_on_batch_size_) {
|
||||
return false;
|
||||
}
|
||||
// We currently only cater to evenly divisible cases.
|
||||
if (kNewBatchSize % old_batch_size != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
VLOG(1) << "spatial size " << c.spatial_size;
|
||||
|
||||
const int64 num_splits = kNewBatchSize / old_batch_size;
|
||||
// If the ratio is not within the 2X range, we can't Halo Pad from the next
|
||||
// split.
|
||||
if (c.halo_size > CeilOfRatio(c.spatial_size, num_splits)) {
|
||||
if (c.halo_size > CeilOfRatio(c.spatial_size, kNumSplits)) {
|
||||
return false;
|
||||
}
|
||||
VLOG(1) << "Legal space-to-batch convolution " << convolution->ToString();
|
||||
@ -323,6 +323,24 @@ StatusOr<HloInstruction*> ConvolutionVisitor::HaloDuplicateWithSlice(
|
||||
int64 activations_batch_dim, int64 old_batch_size, int64 low_padding,
|
||||
int64 high_padding, int64 halo_size, int64 original_split_dim_size,
|
||||
HloInstruction* pad_val) {
|
||||
const int64 original_batch_size =
|
||||
activations->shape().dimensions(activations_batch_dim) / kNumSplits;
|
||||
|
||||
if (original_batch_size > 1) {
|
||||
std::vector<int64> new_dimensions(activations->shape().dimensions().begin(),
|
||||
activations->shape().dimensions().end());
|
||||
new_dimensions[activations_batch_dim] = kNumSplits;
|
||||
new_dimensions.insert(new_dimensions.begin() + activations_batch_dim,
|
||||
original_batch_size);
|
||||
|
||||
// Reshape the output of the new conv into the old convolutions shape.
|
||||
TF_ASSIGN_OR_RETURN(activations,
|
||||
MakeReshapeHlo(new_dimensions, activations));
|
||||
|
||||
spatial_dimension_to_split++;
|
||||
activations_batch_dim++;
|
||||
}
|
||||
|
||||
const int64 rank = activations->shape().rank();
|
||||
const int64 spatial_split_size =
|
||||
activations->shape().dimensions(spatial_dimension_to_split);
|
||||
@ -415,6 +433,21 @@ StatusOr<HloInstruction*> ConvolutionVisitor::HaloDuplicateWithSlice(
|
||||
TF_ASSIGN_OR_RETURN(activations, MakeConcatHlo({activations, halo_region},
|
||||
spatial_dimension_to_split));
|
||||
}
|
||||
|
||||
if (original_batch_size > 1) {
|
||||
std::vector<int64> new_dimensions(activations->shape().dimensions().begin(),
|
||||
activations->shape().dimensions().end());
|
||||
new_dimensions[activations_batch_dim] = original_batch_size * kNumSplits;
|
||||
new_dimensions.erase(new_dimensions.begin() + activations_batch_dim - 1);
|
||||
|
||||
// Reshape the output of the new conv into the old convolutions shape.
|
||||
TF_ASSIGN_OR_RETURN(activations,
|
||||
MakeReshapeHlo(new_dimensions, activations));
|
||||
|
||||
spatial_dimension_to_split++;
|
||||
activations_batch_dim++;
|
||||
}
|
||||
|
||||
VLOG(1) << "HaloDuplicated activations " << activations->ToString();
|
||||
return activations;
|
||||
}
|
||||
@ -424,17 +457,20 @@ ConvolutionVisitor::BringSpaceNextToBatch(
|
||||
HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers,
|
||||
int64& spatial_dimension_to_split, int64& activations_batch_dim,
|
||||
bool is_backprop) {
|
||||
std::vector<int64> transpose_dims;
|
||||
ConvolutionDimensionNumbers new_dim_numbers = dim_numbers;
|
||||
if (spatial_dimension_to_split != activations_batch_dim + 1) {
|
||||
std::vector<int64> transpose_dims(activations->shape().rank());
|
||||
if (spatial_dimension_to_split == activations_batch_dim + 1) {
|
||||
absl::c_iota(transpose_dims, 0);
|
||||
} else {
|
||||
ConvolutionDimensionNumbers new_dim_numbers = dim_numbers;
|
||||
int64 pushed_counter = 0;
|
||||
int64 new_batch_dim, new_spatial_dim;
|
||||
int64 dim_counter = 0;
|
||||
for (int i = 0; i < activations->shape().rank(); ++i) {
|
||||
if (i == activations_batch_dim) {
|
||||
continue;
|
||||
}
|
||||
if (i == spatial_dimension_to_split) {
|
||||
transpose_dims.push_back(activations_batch_dim);
|
||||
transpose_dims[dim_counter++] = activations_batch_dim;
|
||||
new_batch_dim = pushed_counter;
|
||||
pushed_counter++;
|
||||
new_spatial_dim = pushed_counter;
|
||||
@ -452,7 +488,7 @@ ConvolutionVisitor::BringSpaceNextToBatch(
|
||||
}
|
||||
}
|
||||
}
|
||||
transpose_dims.push_back(i);
|
||||
transpose_dims[dim_counter++] = i;
|
||||
pushed_counter++;
|
||||
}
|
||||
|
||||
@ -460,14 +496,14 @@ ConvolutionVisitor::BringSpaceNextToBatch(
|
||||
spatial_dimension_to_split = new_spatial_dim;
|
||||
TF_ASSIGN_OR_RETURN(activations,
|
||||
MakeTransposeHlo(activations, transpose_dims));
|
||||
}
|
||||
|
||||
if (is_backprop) {
|
||||
new_dim_numbers.set_input_feature_dimension(activations_batch_dim);
|
||||
} else {
|
||||
new_dim_numbers.set_input_batch_dimension(activations_batch_dim);
|
||||
if (is_backprop) {
|
||||
new_dim_numbers.set_input_feature_dimension(activations_batch_dim);
|
||||
} else {
|
||||
new_dim_numbers.set_input_batch_dimension(activations_batch_dim);
|
||||
}
|
||||
dim_numbers = new_dim_numbers;
|
||||
}
|
||||
dim_numbers = new_dim_numbers;
|
||||
|
||||
return SpaceNextToBatchDetails{activations, transpose_dims};
|
||||
}
|
||||
@ -586,12 +622,23 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
|
||||
VLOG(1) << "Checking if conv is supported for propagation "
|
||||
<< consumer->ToString();
|
||||
if (IsConvSuitableForSpaceToBatch(consumer)) {
|
||||
for (int64 i = 0; i < consumer->operand_count(); ++i) {
|
||||
auto old_producer = consumer->mutable_operand(i);
|
||||
if (i == 0 && !old_to_new_instrs_.contains(old_producer)) {
|
||||
return false;
|
||||
}
|
||||
if (!old_to_new_instrs_.contains(consumer->mutable_operand(0))) {
|
||||
return false;
|
||||
}
|
||||
auto dim_map_val_op_0 = instr_to_dim_map_[consumer->mutable_operand(0)];
|
||||
// Make sure that the space dimension is the same across the producer
|
||||
// and consumer.
|
||||
if (consumer->convolution_dimension_numbers().input_spatial_dimensions(
|
||||
get_chosen_spatial_dim(consumer)) != dim_map_val_op_0.second) {
|
||||
return false;
|
||||
}
|
||||
// Make sure that the batch dimension is the same across the producer
|
||||
// and consumer.
|
||||
if (consumer->convolution_dimension_numbers().input_batch_dimension() !=
|
||||
dim_map_val_op_0.first) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -611,13 +658,35 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
|
||||
VLOG(2) << "Checking for backprop filter conv operands "
|
||||
<< consumer->operand_count();
|
||||
|
||||
if (!old_to_new_instrs_.contains(consumer->mutable_operand(1))) {
|
||||
auto activations = consumer->mutable_operand(0);
|
||||
auto kernel = consumer->mutable_operand(1);
|
||||
|
||||
if (!old_to_new_instrs_.contains(kernel)) {
|
||||
VLOG(2) << "Backprop filter conv not ready for propagation because of "
|
||||
"kernel is not space-to-batched";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!old_to_new_instrs_.contains(consumer->mutable_operand(0))) {
|
||||
if (!old_to_new_instrs_.contains(activations)) {
|
||||
const int64 lhs_batch = activations->shape().dimensions(
|
||||
consumer->convolution_dimension_numbers().input_feature_dimension());
|
||||
auto dim_map_val_op_1 = instr_to_dim_map_[consumer->mutable_operand(1)];
|
||||
const int64 old_batch_dim = dim_map_val_op_1.first;
|
||||
auto second_operand = old_to_new_instrs_[kernel];
|
||||
auto permute_dims_second_operand =
|
||||
instr_to_dim_permute_map_[second_operand];
|
||||
const int64 new_batch_dim =
|
||||
DimLookUp(permute_dims_second_operand, old_batch_dim);
|
||||
const int64 rhs_batch = second_operand->shape().dimensions(new_batch_dim);
|
||||
|
||||
// Because we want to convert activations into a space-to-batched version
|
||||
// only for backprop filter convolutions, we want to make sure that the
|
||||
// batch dimensions (feature dimensions, technically) are the same size.
|
||||
// Since RHS is already space-to-batched, we need to account for it too.
|
||||
if (rhs_batch != kNumSplits * lhs_batch) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// If activations have not been propagated through, we can do
|
||||
// space-to-batch on them provided kernel has been propagated.
|
||||
VLOG(2) << "Backprop filter conv ready for propagation: kernel ready, "
|
||||
@ -625,10 +694,10 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
|
||||
return true;
|
||||
}
|
||||
|
||||
auto first_operand = old_to_new_instrs_[consumer->mutable_operand(0)];
|
||||
auto dim_map_val_op_0 = instr_to_dim_map_[consumer->mutable_operand(0)];
|
||||
auto second_operand = old_to_new_instrs_[consumer->mutable_operand(1)];
|
||||
auto dim_map_val_op_1 = instr_to_dim_map_[consumer->mutable_operand(1)];
|
||||
auto first_operand = old_to_new_instrs_[activations];
|
||||
auto dim_map_val_op_0 = instr_to_dim_map_[activations];
|
||||
auto second_operand = old_to_new_instrs_[kernel];
|
||||
auto dim_map_val_op_1 = instr_to_dim_map_[kernel];
|
||||
|
||||
auto permute_dims_first_operand = instr_to_dim_permute_map_[first_operand];
|
||||
auto permute_dims_second_operand =
|
||||
@ -1119,7 +1188,7 @@ StatusOr<bool> ConvolutionVisitor::Propagate(HloInstruction* consumer,
|
||||
|
||||
Window new_win;
|
||||
for (int64 i = 0; i < consumer->window().dimensions().size(); ++i) {
|
||||
auto dim = DimLookUp(permute_dims, i);
|
||||
auto dim = ReverseDimLookUp(permute_dims, i);
|
||||
new_win.add_dimensions();
|
||||
new_win.mutable_dimensions(i)->set_stride(
|
||||
consumer->window().dimensions(dim).stride());
|
||||
@ -1339,7 +1408,9 @@ StatusOr<HloInstruction*> ConvolutionVisitor::SelectValidPortion(
|
||||
const int64 new_space_size = new_shape.dimensions(new_space_dim);
|
||||
const int64 old_batch_size = old_shape.dimensions(old_batch_dim);
|
||||
const int64 old_space_size = old_shape.dimensions(old_space_dim);
|
||||
CHECK_EQ(new_batch_size % old_batch_size, 0);
|
||||
CHECK_EQ(new_batch_size % old_batch_size, 0)
|
||||
<< " New batch size " << new_batch_size << " old batch size "
|
||||
<< old_batch_size;
|
||||
const int64 num_splits = new_batch_size / old_batch_size;
|
||||
// Build a constant PRED to decide which elements in the split dimension
|
||||
// are from halo.
|
||||
@ -1394,8 +1465,10 @@ StatusOr<HloInstruction*> ConvolutionVisitor::BatchToSpace(
|
||||
CHECK(old_to_new_instrs_.contains(old_instr));
|
||||
auto new_instr = old_to_new_instrs_[old_instr];
|
||||
VLOG(2) << "old_batch_dim " << old_batch_dim << " old_space_dim "
|
||||
<< old_space_dim << " new_instr " << new_instr->ToString()
|
||||
<< " permute dims " << instr_to_dim_permute_map_.count(new_instr);
|
||||
<< old_space_dim << " old_instr " << old_instr->ToString()
|
||||
<< "\n new_instr " << new_instr->ToString() << " permute dims "
|
||||
<< instr_to_dim_permute_map_.count(new_instr) << " old_batch_size "
|
||||
<< old_batch_size;
|
||||
CHECK(instr_to_dim_permute_map_.contains(new_instr));
|
||||
auto permute_dims = instr_to_dim_permute_map_[new_instr];
|
||||
const int64 batch_dim = DimLookUp(permute_dims, old_batch_dim);
|
||||
@ -1565,6 +1638,7 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
|
||||
c.spatial_dimension_to_split, activations_batch_dim));
|
||||
activations_new = retval.instr;
|
||||
std::vector<int64> trans_dims = retval.transpose_dims;
|
||||
CHECK(!trans_dims.empty());
|
||||
auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::Zero(activations_new->shape().element_type())));
|
||||
|
||||
@ -1578,8 +1652,7 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
|
||||
|
||||
VLOG(1) << "spatial size " << c.spatial_size;
|
||||
|
||||
const int64 num_splits = kNewBatchSize / old_batch_size;
|
||||
|
||||
const int64 num_splits = kNumSplits;
|
||||
const int64 output_offsets = convolution->shape().dimensions(
|
||||
permuted_conv_dims_numbers.output_spatial_dimensions(
|
||||
get_chosen_spatial_dim(convolution)));
|
||||
@ -1614,6 +1687,8 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
|
||||
activations_new->shape().dimensions().end());
|
||||
const int64 reshaped_space_size =
|
||||
new_space_size * new_batch_size / old_batch_size;
|
||||
VLOG(3) << "Increasing the spatial size while propagating new_batch_size "
|
||||
<< new_batch_size << " old_batch_size " << old_batch_size;
|
||||
new_dimensions[c.spatial_dimension_to_split] = reshaped_space_size;
|
||||
new_dimensions[activations_batch_dim] = old_batch_size;
|
||||
|
||||
@ -1621,10 +1696,12 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
|
||||
TF_ASSIGN_OR_RETURN(HloInstruction * reshaped_activations,
|
||||
MakeReshapeHlo(new_dimensions, activations_new));
|
||||
|
||||
VLOG(3) << "First reshape done";
|
||||
PaddingConfig padding_config =
|
||||
MakeNoPaddingConfig(reshaped_activations->shape().dimensions_size());
|
||||
padding_config.mutable_dimensions(c.spatial_dimension_to_split)
|
||||
->set_edge_padding_high(spatial_split_size * new_batch_size -
|
||||
->set_edge_padding_high(spatial_split_size * new_batch_size /
|
||||
old_batch_size -
|
||||
reshaped_space_size);
|
||||
padding_config.mutable_dimensions(c.spatial_dimension_to_split)
|
||||
->set_edge_padding_low(0);
|
||||
@ -1647,6 +1724,8 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
|
||||
reshaped_activations,
|
||||
MakeReshapeHlo(reshape_back_dims, reshaped_activations));
|
||||
|
||||
VLOG(3) << "Second reshape done";
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
activations_new,
|
||||
HaloDuplicateWithSlice(
|
||||
@ -1664,6 +1743,7 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
|
||||
// additional space available, and adjust the required slice size (and
|
||||
// thereby the halo size).
|
||||
if (spatial_split_size < new_space_size) {
|
||||
VLOG(3) << "Decreasing the spatial size while propagating";
|
||||
const int64 additional_space_present = spatial_split_size % c.stride;
|
||||
spatial_split_size = new_space_size;
|
||||
slice_size =
|
||||
@ -1758,6 +1838,7 @@ ConvolutionVisitor::SplitSpace(HloInstruction* activations,
|
||||
|
||||
activations = retval.instr;
|
||||
std::vector<int64> transpose_dims = retval.transpose_dims;
|
||||
CHECK(!transpose_dims.empty());
|
||||
// Because we are splitting the spatial dimension, if convolution needed
|
||||
// padding in the spatial dimension, we materialize it.
|
||||
if (high_padding || low_padding) {
|
||||
@ -1774,7 +1855,9 @@ ConvolutionVisitor::SplitSpace(HloInstruction* activations,
|
||||
MakePadHlo(activations, padding, padding_config));
|
||||
}
|
||||
VLOG(1) << "Initial padded activations shape "
|
||||
<< activations->shape().ToString();
|
||||
<< activations->shape().ToString() << " old_batch_size "
|
||||
<< old_batch_size << " activations_batch_dim "
|
||||
<< activations_batch_dim;
|
||||
|
||||
// Now we reorganize the activations. E.g. if the shape [B, SPACE] was [1, 16]
|
||||
// and 4 splits were needed, we first create [4, 4]. Next, to deal with halo
|
||||
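The reorganization comment above gives the [1, 16] with 4 splits example. A rough sketch of the shape bookkeeping it describes, followed by the halo concatenation that `HaloDuplicateWithSlice` earlier in this diff performs; the halo width here is an arbitrary illustrative value:

#include <cstdio>
#include <vector>

int main() {
  // The [1, 16] -> [4, 4] example from the comment above.
  const long batch = 1, space = 16, num_splits = 4, halo = 2;

  // Step 1: fold the spatial dimension into the batch.
  std::vector<long> split_shape = {batch * num_splits, space / num_splits};

  // Step 2: each split is then concatenated with `halo` elements sliced from
  // the neighbouring split, so the per-split spatial size grows by the halo.
  std::vector<long> haloed_shape = {split_shape[0], split_shape[1] + halo};

  std::printf("[%ld, %ld] -> [%ld, %ld] -> [%ld, %ld]\n", batch, space,
              split_shape[0], split_shape[1], haloed_shape[0],
              haloed_shape[1]);
}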
@ -1829,7 +1912,10 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
|
||||
CHECK(old_to_new_instrs_.contains(kernel_old));
|
||||
auto kernel_new = old_to_new_instrs_[kernel_old];
|
||||
|
||||
auto permute_dims_kernel = instr_to_dim_permute_map_[kernel_new];
|
||||
|
||||
HloInstruction* activations_new = nullptr;
|
||||
bool activations_locally_space_to_batched = false;
|
||||
// If activations were not space-to-batched, we space-to-batch them below.
|
||||
if (!old_to_new_instrs_.contains(activations_old)) {
|
||||
VLOG(1) << "Space-to-batching activations to enable space-to-depth";
|
||||
@ -1838,28 +1924,34 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
|
||||
instr_to_dim_map_[activations_old] =
|
||||
std::make_pair(prev_feature_dim, prev_batch_dim);
|
||||
|
||||
int64 activations_batch_dim = original_conv_dims.input_feature_dimension();
|
||||
const int64 old_batch_size =
|
||||
activations_old->shape().dimensions(activations_batch_dim);
|
||||
const int64 num_splits = kNewBatchSize / old_batch_size;
|
||||
const int64 new_kernel_space_dim =
|
||||
DimLookUp(permute_dims_kernel, kernel_space_dim);
|
||||
|
||||
const int64 new_kernel_split_dim_size =
|
||||
kernel_new->shape().dimensions(kernel_space_dim);
|
||||
kernel_new->shape().dimensions(new_kernel_space_dim);
|
||||
const int64 needed_spatial_size = rhs_dilation * new_kernel_split_dim_size;
|
||||
const int64 pad_size =
|
||||
needed_spatial_size * num_splits - old_split_dim_size;
|
||||
needed_spatial_size * kNumSplits - old_split_dim_size;
|
||||
ConvolutionDimensionNumbers tmp_dim_numbers;
|
||||
tmp_dim_numbers = original_conv_dims;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto retval,
|
||||
SplitSpace(activations_old, tmp_dim_numbers, old_space_dim,
|
||||
activations_batch_dim,
|
||||
old_batch_dim,
|
||||
/*high_padding=*/pad_size, /*low_padding=*/0,
|
||||
needed_spatial_size, num_splits, /*is_backprop=*/true));
|
||||
needed_spatial_size, kNumSplits, /*is_backprop=*/true));
|
||||
|
||||
old_to_new_instrs_[activations_old] = retval.first;
|
||||
instr_to_dim_permute_map_[retval.first] = retval.second;
|
||||
|
||||
VLOG(3) << "Edited conv dims " << original_conv_dims.DebugString();
|
||||
std::vector<int64> reversed_transpose_dims(retval.second.size());
|
||||
for (int64 i = 0; i < retval.second.size(); ++i) {
|
||||
reversed_transpose_dims[i] = ReverseDimLookUp(retval.second, i);
|
||||
}
|
||||
instr_to_dim_permute_map_[retval.first] = reversed_transpose_dims;
|
||||
|
||||
VLOG(3) << "New Activations " << retval.first->ToString();
|
||||
|
||||
activations_locally_space_to_batched = true;
|
||||
}
|
||||
|
||||
CHECK(old_to_new_instrs_.contains(activations_old));
|
||||
@ -1884,7 +1976,7 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
|
||||
i, DimLookUp(permute_dims,
|
||||
original_conv_dims.input_spatial_dimensions(i)));
|
||||
permuted_conv_dims_numbers.set_kernel_spatial_dimensions(
|
||||
i, DimLookUp(permute_dims,
|
||||
i, DimLookUp(permute_dims_kernel,
|
||||
original_conv_dims.kernel_spatial_dimensions(i)));
|
||||
}
|
||||
|
||||
@ -1905,10 +1997,11 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
|
||||
previous_spatial_dim_count, previous_chosen_spatial_dim_in_output);
|
||||
|
||||
const int64 kernel_input_feature_dim = DimLookUp(
|
||||
permute_dims, original_conv_dims.kernel_input_feature_dimension());
|
||||
permute_dims_kernel, original_conv_dims.kernel_input_feature_dimension());
|
||||
|
||||
const int64 kernel_output_feature_dim = DimLookUp(
|
||||
permute_dims, original_conv_dims.kernel_output_feature_dimension());
|
||||
const int64 kernel_output_feature_dim =
|
||||
DimLookUp(permute_dims_kernel,
|
||||
original_conv_dims.kernel_output_feature_dimension());
|
||||
|
||||
permuted_conv_dims_numbers.set_kernel_input_feature_dimension(
|
||||
kernel_input_feature_dim);
|
||||
@ -1931,7 +2024,8 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
|
||||
|
||||
VLOG(1) << "Propagating on conv activations_batch_dim "
|
||||
<< activations_batch_dim << " spatial_dimension_to_split "
|
||||
<< spatial_dimension_to_split << " old_batch_size " << old_batch_size;
|
||||
<< spatial_dimension_to_split << " old_batch_size " << old_batch_size
|
||||
<< " new_split_dim_size " << new_split_dim_size;
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto retval,
|
||||
@ -1939,6 +2033,7 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
|
||||
spatial_dimension_to_split, activations_batch_dim,
|
||||
/*is_backprop=*/true));
|
||||
std::vector<int64> transpose_dims = retval.transpose_dims;
|
||||
CHECK(!transpose_dims.empty());
|
||||
activations_new = retval.instr;
|
||||
|
||||
VLOG(1) << "Activations_new post BringSpaceNextToBatch "
|
||||
@ -1949,13 +2044,15 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
|
||||
auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::Zero(activations_new->shape().element_type())));
|
||||
|
||||
// Select activations correctly by masking additional space.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
activations_new,
|
||||
SelectValidPortion(activations_new, activations_old, select_val,
|
||||
activations_batch_dim, spatial_dimension_to_split,
|
||||
old_batch_dim, old_space_dim));
|
||||
|
||||
if (!activations_locally_space_to_batched) {
|
||||
// Select activations correctly by masking additional space.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
activations_new,
|
||||
SelectValidPortion(activations_new, activations_old, select_val,
|
||||
activations_batch_dim, spatial_dimension_to_split,
|
||||
old_batch_dim, old_space_dim));
|
||||
}
|
||||
VLOG(3) << "Selecting the valid kernel area";
|
||||
// Select kernel correctly by masking additional space.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
kernel_new,
|
||||
@ -2238,7 +2335,6 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
|
||||
|
||||
VLOG(1) << "spatial size " << c.spatial_size;
|
||||
|
||||
const int64 num_splits = kNewBatchSize / old_batch_size;
|
||||
auto original_conv = convolution;
|
||||
|
||||
const int64 output_spatial_dim = dim_numbers.output_spatial_dimensions(
|
||||
@ -2246,13 +2342,13 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
|
||||
const int64 output_offsets =
|
||||
convolution->shape().dimensions(output_spatial_dim);
|
||||
const int64 output_offsets_per_split =
|
||||
CeilOfRatio(output_offsets, num_splits);
|
||||
CeilOfRatio(output_offsets, kNumSplits);
|
||||
|
||||
int64 spatial_split_size =
|
||||
CeilOfRatio(output_offsets_per_split, c.base_dilation_factor) * c.stride;
|
||||
// Keep increasing the split size so that overall size isn't smaller than the
|
||||
// original spatial dimension.
|
||||
while (spatial_split_size * num_splits - c.spatial_size < 0) {
|
||||
while (spatial_split_size * kNumSplits - c.spatial_size < 0) {
|
||||
spatial_split_size += c.stride;
|
||||
}
|
||||
|
||||
@ -2276,12 +2372,12 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
|
||||
const int64 slice_size = spatial_split_size + c.halo_size;
|
||||
|
||||
// Pad spatial dim.
|
||||
const int64 pad_size = spatial_split_size * num_splits - c.spatial_size;
|
||||
const int64 pad_size = spatial_split_size * kNumSplits - c.spatial_size;
|
||||
|
||||
VLOG(1) << "spatial_split_size " << spatial_split_size << " stride "
|
||||
<< c.stride << " slice_size " << slice_size;
|
||||
VLOG(1) << "spatial_dimension_to_split " << c.spatial_dimension_to_split
|
||||
<< " num_splits " << num_splits << " kernel_spatial_dim_size "
|
||||
<< " num_splits " << kNumSplits << " kernel_spatial_dim_size "
|
||||
<< c.kernel_spatial_dim_size;
|
||||
int64 spatial_dimension_to_split = c.spatial_dimension_to_split;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
@ -2292,7 +2388,7 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
|
||||
/*low_padding=*/c.base_dilation_factor == 1
|
||||
? c.inherent_low_padding
|
||||
: 0,
|
||||
spatial_split_size, num_splits));
|
||||
spatial_split_size, kNumSplits));
|
||||
HloInstruction* batch_increased_reshape = retval.first;
|
||||
convolution->SetupDerivedInstruction(batch_increased_reshape);
|
||||
|
||||
|
@ -317,5 +317,27 @@ TEST_F(DynamismInferenceTest, GatherWithSharedConstantParent) {
  }
}

TEST_F(DynamismInferenceTest, InferThroughPad) {
  for (ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    XlaBuilder b(TestName());
    // Test the analysis on a pad.
    auto operand1 = ConstantR1<int32>(&b, {1, 2});
    auto parameter = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {}), "p0");
    PaddingConfig padding_config;
    padding_config.add_dimensions()->set_edge_padding_high(1);
    // After the pad the value is [constant, constant, parameter].
    auto pad = Pad(operand1, parameter, padding_config);
    ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
    // The elements coming from the constant operand are static; the padded
    // element coming from the parameter is dynamic.
    EXPECT_FALSE(
        ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get<bool>({0}));
    EXPECT_FALSE(
        ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get<bool>({1}));
    EXPECT_TRUE(
        ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get<bool>({2}));
  }
}

}  // namespace
}  // namespace xla
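The new `InferThroughPad` test encodes a simple rule: output elements copied from the operand keep the operand's dynamism, while padded-in elements take the dynamism of the pad value. A toy model of just that rule for a 1-D pad (hypothetical helper, not XLA's actual dynamism analysis; interior padding is ignored):

#include <cstdio>
#include <vector>

// Sketch only: marks which elements of a 1-D pad result are dynamic.
std::vector<bool> PadDynamism(int operand_size, bool operand_dynamic,
                              int edge_padding_low, int edge_padding_high,
                              bool pad_value_dynamic) {
  std::vector<bool> result;
  for (int i = 0; i < edge_padding_low; ++i) result.push_back(pad_value_dynamic);
  for (int i = 0; i < operand_size; ++i) result.push_back(operand_dynamic);
  for (int i = 0; i < edge_padding_high; ++i) result.push_back(pad_value_dynamic);
  return result;
}

int main() {
  // Same configuration as InferThroughPad: constant {1, 2} padded on the high
  // side with a parameter value -> {static, static, dynamic}.
  const std::vector<bool> dynamism =
      PadDynamism(/*operand_size=*/2, /*operand_dynamic=*/false,
                  /*edge_padding_low=*/0, /*edge_padding_high=*/1,
                  /*pad_value_dynamic=*/true);
  for (bool d : dynamism) std::printf("%s ", d ? "dynamic" : "static");
  std::printf("\n");
}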
@ -57,7 +57,8 @@ class GpuDummyCompiler : public GpuCompiler {

  StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
      const HloModule* hlo_module, llvm::Module* llvm_module,
      GpuVersion gpu_version, se::StreamExecutor* stream_exec) {
      GpuVersion gpu_version, se::StreamExecutor* stream_exec,
      bool relocatable) {
    if (user_post_optimization_hook_) {
      user_post_optimization_hook_(*llvm_module);
    }
@ -20,9 +20,8 @@ op {
      "A `Tensor` of type T. An alias of `x`. The content "
      "of `y` is undefined if there are duplicates in `i`."
  }
  summary: <<END
Adds v into specified rows of x.

  summary: "Adds v into specified rows of x."
  description: <<END
Computes y = x; y[i, :] += v; return y.
END
}

@ -1,8 +1,8 @@
op {
  graph_op_name: "TopKUnique"
  summary: "Returns the TopK unique values in the array in sorted order. The"
  summary: "Returns the TopK unique values in the array in sorted order."
  description: <<END
running time is proportional to the product of K and the input
The running time is proportional to the product of K and the input
size. Sorting the whole array is more efficient for sufficiently large
values of K. The median-of-medians algorithm is probably faster, but
difficult to implement efficiently in XLA. If there are fewer than K

@ -1,10 +1,11 @@
op {
  graph_op_name: "TopKWithUnique"
  summary: "Returns the TopK values in the array in sorted order. This is a combination"
  summary: "Returns the TopK values in the array in sorted order."
  description: <<END
of MakeUnique and TopKUnique. The returned top-K will have its lower bits
replaced by iota, thus it will be close to the original value but not exactly
the same. The running time is proportional to the product of K and the input
size. NaNs are never returned. Subnormal numbers are flushed to zero.
This is a combination of MakeUnique and TopKUnique. The returned top-K will
have its lower bits replaced by iota, thus it will be close to the original
value but not exactly the same. The running time is proportional to the product
of K and the input size. NaNs are never returned. Subnormal numbers are flushed
to zero.
END
}
@ -705,10 +705,10 @@ Status EagerContext::RegisterExistingFunctionsOnRemoteWorkers(
  return Status::OK();
}

Status EagerContext::AddFunctionDefWithDebugInfo(
    const FunctionDef& fdef, const Graph* graph_with_debug_info) {
Status EagerContext::AddFunctionDefWithStackTraces(
    const FunctionDef& fdef, const StackTracesMap& stack_traces) {
  return AddFunctionDef(fdef, FunctionDefLibrary(),
                        /* add_to_local_only=*/false, graph_with_debug_info);
                        /* add_to_local_only=*/false, stack_traces);
}

Status EagerContext::AddFunctionDef(const FunctionDef& fdef) {
@ -719,7 +719,7 @@ Status EagerContext::AddFunctionDef(const FunctionDef& fdef) {
Status EagerContext::AddFunctionDef(const FunctionDef& fdef,
                                    const FunctionDefLibrary& library,
                                    const bool add_to_local_only,
                                    const Graph* graph_with_debug_info) {
                                    const StackTracesMap& stack_traces) {
  bool is_first_ref = false;
  {
    mutex_lock l(cache_mu_);
@ -753,8 +753,7 @@ Status EagerContext::AddFunctionDef(const FunctionDef& fdef,
    is_first_ref = registered_function->RefCountIsOne();
  }
  if (is_first_ref) {
    TF_RETURN_IF_ERROR(
        func_lib_def_.AddFunctionDef(fdef, graph_with_debug_info));
    TF_RETURN_IF_ERROR(func_lib_def_.AddFunctionDef(fdef, stack_traces));
    TF_RETURN_IF_ERROR(func_lib_def_.AddLibrary(library));
    if (!add_to_local_only) {
      return MaybeRegisterFunctionRemotely(fdef);