commit 9e475fecd9
@@ -11,10 +11,6 @@
* C-API functions `TF_StringDecode`, `TF_StringEncode`, and
  `TF_StringEncodedSize` are no longer relevant and have been removed; see
  `core/platform/ctstring.h` for string access/modification in C.
* In the batching library, the parameter
  `SharedBatchScheduler::QueueOptions::max_batch_size` has been renamed to the
  more accurate `input_batch_size_limit`, reflecting a recent feature that
  allows splitting large batches (see the sketch below).
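A minimal sketch of where the renamed field is set, assuming the scheduler lives in `tensorflow::serving` as in the batching_util headers; `MyTask` and the limit value 64 are placeholders, and only the renamed `input_batch_size_limit` field comes from this change.

#include <cstddef>

#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"

// Minimal task type; any tensorflow::serving::BatchTask subclass would do.
struct MyTask : tensorflow::serving::BatchTask {
  size_t size() const override { return 1; }
};

using Scheduler = tensorflow::serving::SharedBatchScheduler<MyTask>;

void ConfigureQueue(Scheduler::QueueOptions* options) {
  // Previously `options->max_batch_size`; same meaning, but the new name
  // makes clear that it limits the size of an incoming (input) batch, which
  // may later be split into smaller execution batches.
  options->input_batch_size_limit = 64;
}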

## Known Caveats

@@ -171,6 +171,87 @@ cc_library(
    ],
)

cc_library(
    name = "gradients",
    srcs = [
        "gradients.cc",
        "gradients_internal.h",
    ],
    hdrs = [
        "gradients.h",
    ],
    visibility = [
        "//tensorflow:internal",
    ],
    deps = [
        ":abstract_context",
        ":abstract_operation",
        ":abstract_tensor_handle",
        ":c_api_unified_internal",
        ":tape",
        "//tensorflow/core/common_runtime/eager:attr_builder",
        "//tensorflow/core/lib/llvm_rtti",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "gradients_internal",
    srcs = [
        "gradients.cc",
    ],
    hdrs = [
        "gradients.h",
        "gradients_internal.h",
    ],
    visibility = [
        "//tensorflow:internal",
    ],
    deps = [
        ":abstract_context",
        ":abstract_operation",
        ":abstract_tensor_handle",
        ":c_api_unified_internal",
        ":tape",
        "//tensorflow/core/common_runtime/eager:attr_builder",
        "//tensorflow/core/lib/llvm_rtti",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/strings",
    ],
)

tf_cuda_cc_test(
    name = "gradients_test",
    size = "small",
    srcs = [
        "gradients_test.cc",
    ],
    args = ["--heap_check=local"],
    extra_copts = tfe_xla_copts(),
    linkstatic = tf_kernel_tests_linkstatic(),
    tags = tf_cuda_tests_tags() + ["nomac"],
    deps = [
        ":abstract_tensor_handle",
        ":c_api_experimental",
        ":c_api_test_util",
        ":c_api_unified_internal",
        ":gradients_internal",
        "//tensorflow/c:c_api",
        "//tensorflow/c:c_test_util",
        "//tensorflow/c:tf_status_helper",
        "//tensorflow/cc/profiler",
        "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/lib/llvm_rtti",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
    ],
)

cc_library(
    name = "abstract_tensor_handle",
    hdrs = ["abstract_tensor_handle.h"],
@@ -747,6 +828,7 @@ filegroup(
        "c_api_unified_experimental_eager.cc",
        "c_api_unified_experimental_graph.cc",
        "c_api_unified_experimental_internal.h",
        "gradients.cc",  # Uses RTTI.
        "*test*",
        "*dlpack*",
    ],

@@ -725,13 +725,7 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
  if (opts->use_tfrt) {
#ifdef PLATFORM_GOOGLE
    tfrt::SmallVector<std::string, 4> op_handler_chains;
    tfrt::SmallVector<tensorflow::DeviceAttributes, 4> device_attributes;
    status->status = tfrt::ListOpHandlerChains(
        opts->session_options.options, &op_handler_chains, &device_attributes);
    if (!status->status.ok()) return nullptr;
    return tensorflow::wrap(new tfrt::ContextInterface(
        op_handler_chains, device_attributes, opts->async));
    return tensorflow::wrap(new tfrt::ContextInterface(opts->async));
#else
    status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
    return nullptr;

tensorflow/c/eager/gradients.cc (new file, 400 lines)
@@ -0,0 +1,400 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/eager/gradients.h"
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/eager/gradients_internal.h"
|
||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
|
||||
Status GradientRegistry::Register(const string& op_name,
|
||||
GradientFunctionFactory factory) {
|
||||
auto iter = registry_.find(op_name);
|
||||
if (iter != registry_.end()) {
|
||||
const string error_msg = "Gradient already exists for op: " + op_name + ".";
|
||||
return errors::AlreadyExists(error_msg);
|
||||
}
|
||||
registry_.insert({op_name, factory});
|
||||
return Status::OK();
|
||||
}
|
||||
Status GradientRegistry::Lookup(
|
||||
const ForwardOperation& op,
|
||||
std::unique_ptr<GradientFunction>* grad_fn) const {
|
||||
auto iter = registry_.find(op.op_name);
|
||||
if (iter == registry_.end()) {
|
||||
const string error_msg = "No gradient defined for op: " + op.op_name + ".";
|
||||
return errors::NotFound(error_msg);
|
||||
}
|
||||
grad_fn->reset(iter->second(op));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int64 ToId(AbstractTensorHandle* t) {
|
||||
return static_cast<int64>(reinterpret_cast<uintptr_t>(t));
|
||||
}
|
||||
|
||||
TapeTensor::TapeTensor(AbstractTensorHandle* handle, AbstractContext* ctx)
|
||||
: handle_(handle), ctx_(ctx) {
|
||||
// TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely
|
||||
// on the client to keep this tensor live for the duration of the gradient
|
||||
// computation.
|
||||
// handle_->Ref();
|
||||
}
|
||||
TapeTensor::TapeTensor(const TapeTensor& other) {
|
||||
handle_ = other.handle_;
|
||||
// TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely
|
||||
// on the client to keep this tensor live for the duration of the gradient
|
||||
// computation.
|
||||
// handle_->Ref();
|
||||
ctx_ = other.ctx_;
|
||||
}
|
||||
TapeTensor::~TapeTensor() {
|
||||
// TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely
|
||||
// on the client to keep this tensor live for the duration of the gradient
|
||||
// computation.
|
||||
// handle_->Unref();
|
||||
}
|
||||
|
||||
tensorflow::int64 TapeTensor::GetID() const { return ToId(handle_); }
|
||||
|
||||
tensorflow::DataType TapeTensor::GetDType() const {
|
||||
return handle_->DataType();
|
||||
}
|
||||
|
||||
AbstractTensorHandle* TapeTensor::OnesLike() const {
|
||||
AbstractOperationPtr op(ctx_->CreateOperation());
|
||||
Status s = op->Reset("OnesLike", /*raw_device_name=*/nullptr);
|
||||
if (!s.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
if (isa<tracing::TracingOperation>(op.get())) {
|
||||
s = dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
|
||||
absl::StrCat("OnesLike", ToId(handle_)).c_str());
|
||||
if (!s.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
s = op->AddInput(handle_);
|
||||
if (!s.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
int num_outputs = 1;
|
||||
// TODO(srbs): Figure out who is in charge of releasing this.
|
||||
std::vector<AbstractTensorHandle*> outputs(num_outputs);
|
||||
s = op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs);
|
||||
if (!s.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return outputs[0];
|
||||
}
|
||||
AbstractTensorHandle* TapeTensor::ZerosLike() const {
|
||||
AbstractOperationPtr op(ctx_->CreateOperation());
|
||||
// TODO(srbs): Consider adding a TF_RETURN_NULLPTR_IF_ERROR.
|
||||
Status s = op->Reset("ZerosLike", /*raw_device_name=*/nullptr);
|
||||
if (!s.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
if (isa<tracing::TracingOperation>(op.get())) {
|
||||
s = dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
|
||||
absl::StrCat("ZerosLike", ToId(handle_)).c_str());
|
||||
if (!s.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
s = op->AddInput(handle_);
|
||||
if (!s.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
int num_outputs = 1;
|
||||
// TODO(srbs): Figure out who is in charge of releasing this.
|
||||
std::vector<AbstractTensorHandle*> outputs(num_outputs);
|
||||
s = op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs);
|
||||
if (!s.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return outputs[0];
|
||||
}
|
||||
|
||||
// Returns the number of elements in the gradient tensor.
|
||||
int64 TapeVSpace::NumElements(AbstractTensorHandle* tensor) const {
|
||||
// TODO(srbs): It seems like this is used only for performance optimization
|
||||
// and not for correctness. The only downside of keeping this 1 seems to be
|
||||
// that the gradient accumulation is unbounded and we will never
|
||||
// aggressively aggregate accumulated gradients to recover memory.
|
||||
// Revisit and fix.
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Consumes references to the tensors in the gradient_tensors list and returns
|
||||
// a tensor with the result.
|
||||
AbstractTensorHandle* TapeVSpace::AggregateGradients(
|
||||
gtl::ArraySlice<AbstractTensorHandle*> gradient_tensors) const {
|
||||
if (gradient_tensors.size() == 1) {
|
||||
return gradient_tensors[0];
|
||||
}
|
||||
|
||||
AbstractOperationPtr op(ctx_->CreateOperation());
|
||||
Status s = op->Reset("AddN", /*raw_device_name=*/nullptr);
|
||||
if (!s.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
s = op->AddInputList(gradient_tensors);
|
||||
if (!s.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int num_outputs = 1;
|
||||
std::vector<AbstractTensorHandle*> outputs(num_outputs);
|
||||
s = op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs);
|
||||
if (!s.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return outputs[0];
|
||||
}
|
||||
|
||||
// Calls the passed-in backward function.
|
||||
Status TapeVSpace::CallBackwardFunction(
|
||||
GradientFunction* backward_function,
|
||||
const std::vector<int64>& unneeded_gradients,
|
||||
gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
|
||||
std::vector<AbstractTensorHandle*>* result) const {
|
||||
if (backward_function == nullptr) return Status::OK();
|
||||
return backward_function->Compute(output_gradients, result);
|
||||
}
|
||||
|
||||
// Looks up the ID of a Gradient.
|
||||
int64 TapeVSpace::TensorId(AbstractTensorHandle* tensor) const {
|
||||
return ToId(tensor);
|
||||
}
|
||||
|
||||
// Converts a Gradient to a TapeTensor.
|
||||
TapeTensor TapeVSpace::TapeTensorFromGradient(AbstractTensorHandle* g) const {
|
||||
return TapeTensor(g, ctx_);
|
||||
}
|
||||
|
||||
void TapeVSpace::MarkAsResult(AbstractTensorHandle* gradient) const {}
|
||||
|
||||
void TapeVSpace::DeleteGradient(AbstractTensorHandle* gradient) const {
|
||||
gradient->Release();
|
||||
}
|
||||
|
||||
// Helper functions which delegate to `AbstractOperation`, update
|
||||
// the state of the ForwardOperation and call the tape as appropriate.
|
||||
// These APIs are mainly to facilitate testing and are subject to change.
|
||||
namespace internal {
|
||||
Status Reset(AbstractOperation* op_, const char* op,
|
||||
const char* raw_device_name, ForwardOperation* forward_op_) {
|
||||
forward_op_->op_name = op;
|
||||
return op_->Reset(op, raw_device_name);
|
||||
}
|
||||
Status AddInput(AbstractOperation* op_, AbstractTensorHandle* input,
|
||||
ForwardOperation* forward_op_) {
|
||||
TF_RETURN_IF_ERROR(op_->AddInput(input));
|
||||
forward_op_->inputs.push_back(input);
|
||||
return Status::OK();
|
||||
}
|
||||
Status AddInputList(AbstractOperation* op_,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
ForwardOperation* forward_op_) {
|
||||
TF_RETURN_IF_ERROR(op_->AddInputList(inputs));
|
||||
for (auto input : inputs) {
|
||||
forward_op_->inputs.push_back(input);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SetAttrString(AbstractOperation* op_, const char* attr_name,
|
||||
const char* data, size_t length,
|
||||
ForwardOperation* forward_op_) {
|
||||
forward_op_->attrs.Set(attr_name, StringPiece(data, length));
|
||||
return op_->SetAttrString(attr_name, data, length);
|
||||
}
|
||||
Status SetAttrInt(AbstractOperation* op_, const char* attr_name, int64_t value,
|
||||
ForwardOperation* forward_op_) {
|
||||
forward_op_->attrs.Set(attr_name, static_cast<int64>(value));
|
||||
return op_->SetAttrInt(attr_name, value);
|
||||
}
|
||||
Status SetAttrFloat(AbstractOperation* op_, const char* attr_name, float value,
|
||||
ForwardOperation* forward_op_) {
|
||||
forward_op_->attrs.Set(attr_name, value);
|
||||
return op_->SetAttrFloat(attr_name, value);
|
||||
}
|
||||
Status SetAttrBool(AbstractOperation* op_, const char* attr_name, bool value,
|
||||
ForwardOperation* forward_op_) {
|
||||
forward_op_->attrs.Set(attr_name, value);
|
||||
return op_->SetAttrBool(attr_name, value);
|
||||
}
|
||||
Status SetAttrType(AbstractOperation* op_, const char* attr_name,
|
||||
DataType value, ForwardOperation* forward_op_) {
|
||||
forward_op_->attrs.Set(attr_name, value);
|
||||
return op_->SetAttrType(attr_name, value);
|
||||
}
|
||||
Status SetAttrShape(AbstractOperation* op_, const char* attr_name,
|
||||
const int64_t* dims, const int num_dims,
|
||||
ForwardOperation* forward_op_) {
|
||||
if (num_dims > TensorShape::MaxDimensions()) {
|
||||
return errors::InvalidArgument("Value specified for `", attr_name, "` has ",
|
||||
num_dims,
|
||||
" dimensions which is over the limit of ",
|
||||
TensorShape::MaxDimensions(), ".");
|
||||
}
|
||||
TensorShapeProto proto;
|
||||
if (num_dims < 0) {
|
||||
proto.set_unknown_rank(true);
|
||||
} else {
|
||||
for (int d = 0; d < num_dims; ++d) {
|
||||
proto.add_dim()->set_size(dims[d]);
|
||||
}
|
||||
}
|
||||
|
||||
forward_op_->attrs.Set(attr_name, proto);
|
||||
return op_->SetAttrShape(attr_name, dims, num_dims);
|
||||
}
|
||||
Status SetAttrFunction(AbstractOperation* op_, const char* attr_name,
|
||||
const AbstractOperation* value,
|
||||
ForwardOperation* forward_op_) {
|
||||
return tensorflow::errors::Unimplemented(
|
||||
"SetAttrFunction has not been implemented yet.");
|
||||
}
|
||||
Status SetAttrFunctionName(AbstractOperation* op_, const char* attr_name,
|
||||
const char* value, size_t length,
|
||||
ForwardOperation* forward_op_) {
|
||||
return tensorflow::errors::Unimplemented(
|
||||
"SetAttrFunctionName has not been implemented "
|
||||
"yet.");
|
||||
}
|
||||
Status SetAttrTensor(AbstractOperation* op_, const char* attr_name,
|
||||
AbstractTensorInterface* tensor,
|
||||
ForwardOperation* forward_op_) {
|
||||
return tensorflow::errors::Unimplemented(
|
||||
"SetAttrTensor has not been implemented yet.");
|
||||
}
|
||||
Status SetAttrStringList(AbstractOperation* op_, const char* attr_name,
|
||||
const void* const* values, const size_t* lengths,
|
||||
int num_values, ForwardOperation* forward_op_) {
|
||||
std::vector<StringPiece> v(num_values);
|
||||
for (int i = 0; i < num_values; ++i) {
|
||||
v[i] = StringPiece(static_cast<const char*>(values[i]), lengths[i]);
|
||||
}
|
||||
forward_op_->attrs.Set(attr_name, v);
|
||||
return op_->SetAttrStringList(attr_name, values, lengths, num_values);
|
||||
}
|
||||
Status SetAttrFloatList(AbstractOperation* op_, const char* attr_name,
|
||||
const float* values, int num_values,
|
||||
ForwardOperation* forward_op_) {
|
||||
forward_op_->attrs.Set(attr_name,
|
||||
gtl::ArraySlice<const float>(values, num_values));
|
||||
return op_->SetAttrFloatList(attr_name, values, num_values);
|
||||
}
|
||||
Status SetAttrIntList(AbstractOperation* op_, const char* attr_name,
|
||||
const int64_t* values, int num_values,
|
||||
ForwardOperation* forward_op_) {
|
||||
forward_op_->attrs.Set(
|
||||
attr_name, gtl::ArraySlice<const int64>(
|
||||
reinterpret_cast<const int64*>(values), num_values));
|
||||
return op_->SetAttrIntList(attr_name, values, num_values);
|
||||
}
|
||||
Status SetAttrTypeList(AbstractOperation* op_, const char* attr_name,
|
||||
const DataType* values, int num_values,
|
||||
ForwardOperation* forward_op_) {
|
||||
forward_op_->attrs.Set(attr_name,
|
||||
gtl::ArraySlice<const DataType>(values, num_values));
|
||||
return op_->SetAttrTypeList(attr_name, values, num_values);
|
||||
}
|
||||
Status SetAttrBoolList(AbstractOperation* op_, const char* attr_name,
|
||||
const unsigned char* values, int num_values,
|
||||
ForwardOperation* forward_op_) {
|
||||
std::unique_ptr<bool[]> b(new bool[num_values]);
|
||||
for (int i = 0; i < num_values; ++i) {
|
||||
b[i] = values[i];
|
||||
}
|
||||
forward_op_->attrs.Set(attr_name,
|
||||
gtl::ArraySlice<const bool>(b.get(), num_values));
|
||||
return op_->SetAttrBoolList(attr_name, values, num_values);
|
||||
}
|
||||
Status SetAttrShapeList(AbstractOperation* op_, const char* attr_name,
|
||||
const int64_t** dims, const int* num_dims,
|
||||
int num_values, ForwardOperation* forward_op_) {
|
||||
std::unique_ptr<TensorShapeProto[]> proto(new TensorShapeProto[num_values]);
|
||||
for (int i = 0; i < num_values; ++i) {
|
||||
const auto num_dims_i = num_dims[i];
|
||||
|
||||
if (num_dims_i > TensorShape::MaxDimensions()) {
|
||||
return errors::InvalidArgument(
|
||||
strings::StrCat("Value specified for `", attr_name, "` has ",
|
||||
num_dims_i, " dimensions which is over the limit of ",
|
||||
TensorShape::MaxDimensions(), "."));
|
||||
}
|
||||
if (num_dims_i < 0) {
|
||||
proto[i].set_unknown_rank(true);
|
||||
} else {
|
||||
const int64_t* dims_i = dims[i];
|
||||
auto proto_i = &proto[i];
|
||||
for (int d = 0; d < num_dims_i; ++d) {
|
||||
proto_i->add_dim()->set_size(dims_i[d]);
|
||||
}
|
||||
}
|
||||
}
|
||||
forward_op_->attrs.Set(
|
||||
attr_name, gtl::ArraySlice<TensorShapeProto>(proto.get(), num_values));
|
||||
return op_->SetAttrShapeList(attr_name, dims, num_dims, num_values);
|
||||
}
|
||||
Status SetAttrFunctionList(AbstractOperation* op_, const char* attr_name,
|
||||
absl::Span<const AbstractOperation*> values,
|
||||
ForwardOperation* forward_op_) {
|
||||
return tensorflow::errors::Unimplemented(
|
||||
"SetAttrFunctionList has not been "
|
||||
"implemented yet.");
|
||||
}
|
||||
Status Execute(AbstractOperation* op_, AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle*> retvals, int* num_retvals,
|
||||
ForwardOperation* forward_op_, Tape* tape,
|
||||
const GradientRegistry& registry) {
|
||||
TF_RETURN_IF_ERROR(op_->Execute(retvals, num_retvals));
|
||||
std::vector<int64> input_ids(forward_op_->inputs.size());
|
||||
std::vector<tensorflow::DataType> input_dtypes(forward_op_->inputs.size());
|
||||
for (int i = 0; i < forward_op_->inputs.size(); i++) {
|
||||
input_ids[i] = ToId(forward_op_->inputs[i]);
|
||||
input_dtypes[i] = forward_op_->inputs[i]->DataType();
|
||||
}
|
||||
std::vector<TapeTensor> tape_tensors;
|
||||
for (auto t : retvals) {
|
||||
tape_tensors.push_back(TapeTensor(t, ctx));
|
||||
}
|
||||
tape->RecordOperation(
|
||||
op_->Name(), tape_tensors, input_ids, input_dtypes,
|
||||
[registry, forward_op_]() -> GradientFunction* {
|
||||
std::unique_ptr<GradientFunction> grad_fn;
|
||||
Status s = registry.Lookup(*forward_op_, &grad_fn);
|
||||
if (!s.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return grad_fn.release();
|
||||
},
|
||||
[](GradientFunction* ptr) {
|
||||
if (ptr) {
|
||||
delete ptr;
|
||||
}
|
||||
});
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace internal
|
||||
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
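For reference, a minimal sketch (not part of this commit) of how a client of the registry implemented above could register a gradient. The op name "Neg" and the body of `NegGradientFunction` are hypothetical; only the `GradientFunction`/`GradientRegistry` API from gradients.h is assumed.

// Illustrative only; assumes #include "tensorflow/c/eager/gradients.h".
namespace tensorflow {
namespace gradients {

// Hypothetical gradient function for a unary "Neg" op. A real implementation
// would negate the incoming gradient; only the plumbing is shown here.
class NegGradientFunction : public GradientFunction {
 public:
  Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs,
                 std::vector<AbstractTensorHandle*>* grad_outputs) override {
    grad_outputs->resize(1);
    (*grad_outputs)[0] = grad_inputs[0];  // Placeholder for -grad_inputs[0].
    return Status::OK();
  }
  ~NegGradientFunction() override {}
};

GradientFunction* NegRegisterer(const ForwardOperation& op) {
  return new NegGradientFunction;
}

Status RegisterNegGradient(GradientRegistry* registry) {
  // Register() returns an AlreadyExists error if "Neg" was registered before.
  return registry->Register("Neg", NegRegisterer);
}

}  // namespace gradients
}  // namespace tensorflow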
|
||||
tensorflow/c/eager/gradients.h (new file, 171 lines)
@@ -0,0 +1,171 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EAGER_GRADIENTS_H_
|
||||
#define TENSORFLOW_C_EAGER_GRADIENTS_H_
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "tensorflow/c/eager/abstract_context.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/tape.h"
|
||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
|
||||
// =============== Experimental C++ API for computing gradients ===============
|
||||
|
||||
// Sample gradient function:
|
||||
//
|
||||
// class AddGradientFunction : public GradientFunction {
|
||||
// public:
|
||||
// Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs,
|
||||
// std::vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
// grad_outputs->resize(2);
|
||||
// (*grad_outputs)[0] = grad_inputs[0];
|
||||
// (*grad_outputs)[1] = grad_inputs[0];
|
||||
// return Status::OK();
|
||||
// }
|
||||
// ~AddGradientFunction() override {}
|
||||
// };
|
||||
//
|
||||
// GradientFunction* AddRegisterer(const ForwardOperation& op) {
|
||||
// // More complex gradient functions can use inputs/attrs etc. from the
|
||||
// // forward `op`.
|
||||
// return new AddGradientFunction;
|
||||
// }
|
||||
//
|
||||
// Status RegisterGradients(GradientRegistry* registry) {
|
||||
// return registry->Register("Add", AddRegisterer);
|
||||
// }
|
||||
class GradientFunction {
|
||||
public:
|
||||
// TODO(srbs): Figure out how to support CompositeTensors, e.g. IndexedSlices,
// in `grad_inputs`.
|
||||
virtual Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs,
|
||||
std::vector<AbstractTensorHandle*>* grad_outputs) = 0;
|
||||
virtual ~GradientFunction() {}
|
||||
};
|
||||
|
||||
// Metadata from the forward operation that is made available to the
|
||||
// gradient registerer to instantiate a GradientFunction.
|
||||
struct ForwardOperation {
|
||||
public:
|
||||
string op_name;
|
||||
std::vector<AbstractTensorHandle*> inputs;
|
||||
std::vector<AbstractTensorHandle*> outputs;
|
||||
AttrBuilder attrs;
|
||||
AbstractContext* ctx;
|
||||
};
|
||||
|
||||
using GradientFunctionFactory =
|
||||
std::function<GradientFunction*(const ForwardOperation& op)>;
|
||||
|
||||
// Map from op name to a `GradientFunctionFactory`.
|
||||
class GradientRegistry {
|
||||
public:
|
||||
Status Register(const string& op, GradientFunctionFactory factory);
|
||||
Status Lookup(const ForwardOperation& op,
|
||||
std::unique_ptr<GradientFunction>* grad_fn) const;
|
||||
|
||||
private:
|
||||
absl::flat_hash_map<string, GradientFunctionFactory> registry_;
|
||||
};
|
||||
|
||||
// Returns a unique id for the tensor which is used by the tape to build
|
||||
// the gradient graph. See documentation of `TapeTensor` for more details.
|
||||
int64 ToId(AbstractTensorHandle* t);
|
||||
|
||||
// Wrapper for a tensor output of an operation executing under a tape.
|
||||
//
|
||||
// `GetID` returns a unique id for the wrapped tensor which is used to maintain
|
||||
// a map (`tensorflow::eager::TensorTape`) from the wrapped tensor to the id of
|
||||
// the op that produced it (or -1 if this tensor was watched using
|
||||
// `GradientTape::Watch`.) The op_id is simply a unique index assigned to each
|
||||
// op executed under the tape. A separate map (`tensorflow::eager::OpTape`)
|
||||
// maintains the map from `op_id` to a `OpTapeEntry` which stores the `op_type`,
|
||||
// inputs and outputs and the gradient function. These data structures combined
|
||||
// allow us to trace the data dependencies between operations and hence compute
|
||||
// gradients.
|
||||
//
|
||||
// This also implements `ZerosLike` and `OnesLike` to create the default
|
||||
// incoming gradients for tensors which do not already have an incoming
|
||||
// gradient.
|
||||
class TapeTensor {
|
||||
public:
|
||||
TapeTensor(AbstractTensorHandle* handle, AbstractContext* ctx);
|
||||
TapeTensor(const TapeTensor& other);
|
||||
~TapeTensor();
|
||||
|
||||
tensorflow::int64 GetID() const;
|
||||
tensorflow::DataType GetDType() const;
|
||||
|
||||
AbstractTensorHandle* OnesLike() const;
|
||||
AbstractTensorHandle* ZerosLike() const;
|
||||
|
||||
private:
|
||||
AbstractTensorHandle* handle_;
|
||||
// The context where OnesLike and ZerosLike ops are to be created.
|
||||
AbstractContext* ctx_;
|
||||
};
|
||||
|
||||
// Vector space for actually computing gradients. Implements methods for calling
|
||||
// the backward function with incoming gradients and returning the outgoing
|
||||
// gradient and for performing gradient aggregation.
|
||||
// See `tensorflow::eager::VSpace` for more details.
|
||||
class TapeVSpace
|
||||
: public eager::VSpace<AbstractTensorHandle, GradientFunction, TapeTensor> {
|
||||
public:
|
||||
explicit TapeVSpace(AbstractContext* ctx) : ctx_(ctx) {}
|
||||
~TapeVSpace() override {}
|
||||
|
||||
// Returns the number of elements in the gradient tensor.
|
||||
int64 NumElements(AbstractTensorHandle* tensor) const override;
|
||||
|
||||
// Consumes references to the tensors in the gradient_tensors list and returns
|
||||
// a tensor with the result.
|
||||
AbstractTensorHandle* AggregateGradients(
|
||||
gtl::ArraySlice<AbstractTensorHandle*> gradient_tensors) const override;
|
||||
|
||||
// Calls the passed-in backward function.
|
||||
Status CallBackwardFunction(
|
||||
GradientFunction* backward_function,
|
||||
const std::vector<int64>& unneeded_gradients,
|
||||
gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
|
||||
std::vector<AbstractTensorHandle*>* result) const override;
|
||||
|
||||
// Looks up the ID of a Gradient.
|
||||
int64 TensorId(AbstractTensorHandle* tensor) const override;
|
||||
|
||||
// Converts a Gradient to a TapeTensor.
|
||||
TapeTensor TapeTensorFromGradient(AbstractTensorHandle* g) const override;
|
||||
|
||||
void MarkAsResult(AbstractTensorHandle* gradient) const override;
|
||||
|
||||
void DeleteGradient(AbstractTensorHandle* gradient) const override;
|
||||
|
||||
private:
|
||||
// The context where the aggregation op `Add` is to be created.
|
||||
AbstractContext* ctx_;
|
||||
};
|
||||
|
||||
// A tracing/immediate-execution agnostic tape.
|
||||
using Tape = tensorflow::eager::GradientTape<AbstractTensorHandle,
|
||||
GradientFunction, TapeTensor>;
|
||||
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_GRADIENTS_H_
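To see how these pieces fit together end to end, here is a condensed version of the flow exercised by gradients_test.cc later in this commit: watch the inputs, run the forward op through the internal helpers so it is recorded on the tape, then ask the tape for the gradients. The `Add` helper referenced below is the one defined in gradients_test.cc; everything else uses only the API declared above. This is a sketch of the intended call sequence, not new API surface.

Status AddGradSketch(AbstractContext* ctx,
                     absl::Span<AbstractTensorHandle* const> inputs,
                     absl::Span<AbstractTensorHandle*> outputs,
                     const GradientRegistry& registry) {
  TapeVSpace vspace(ctx);
  Tape tape(/*persistent=*/false);
  tape.Watch(ToId(inputs[0]));  // Watch x.
  tape.Watch(ToId(inputs[1]));  // Watch y.

  // Forward pass: `Add` (defined in gradients_test.cc) executes the op through
  // internal::Execute, which records it on the tape.
  std::vector<AbstractTensorHandle*> add_outputs(1);
  TF_RETURN_IF_ERROR(
      Add(ctx, &tape, inputs, absl::MakeSpan(add_outputs), registry));

  // Backward pass: gradients of x + y with respect to x and y.
  std::unordered_map<tensorflow::int64, TapeTensor> targets_that_are_sources;
  std::vector<AbstractTensorHandle*> out_grads;
  TF_RETURN_IF_ERROR(tape.ComputeGradient(
      vspace, /*target_tensor_ids=*/{ToId(add_outputs[0])},
      /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
      targets_that_are_sources, /*output_gradients=*/{}, &out_grads));

  for (auto add_output : add_outputs) add_output->Release();
  outputs[0] = out_grads[0];
  outputs[1] = out_grads[1];
  return Status::OK();
}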
|
||||
tensorflow/c/eager/gradients_internal.h (new file, 87 lines)
@@ -0,0 +1,87 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EAGER_GRADIENTS_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_GRADIENTS_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/c/eager/gradients.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
namespace internal {
|
||||
|
||||
// Helper functions which delegate to `AbstractOperation`, update
|
||||
// the state of the ForwardOperation and call the tape as appropriate.
|
||||
// These APIs are mainly to facilitate testing and are subject to change.
|
||||
|
||||
// Records the op name in the `ForwardOperation`.
|
||||
Status Reset(AbstractOperation*, const char* op, const char* raw_device_name,
|
||||
ForwardOperation*);
|
||||
|
||||
// Records the inputs in the `ForwardOperation`.
|
||||
Status AddInput(AbstractOperation*, AbstractTensorHandle*, ForwardOperation*);
|
||||
Status AddInputList(AbstractOperation*,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
ForwardOperation*);
|
||||
|
||||
// Sets the attrs in the `ForwardOperation`.
|
||||
Status SetAttrString(AbstractOperation*, const char* attr_name,
|
||||
const char* data, size_t length, ForwardOperation*);
|
||||
Status SetAttrInt(AbstractOperation*, const char* attr_name, int64_t value,
|
||||
ForwardOperation*);
|
||||
Status SetAttrFloat(AbstractOperation*, const char* attr_name, float value,
|
||||
ForwardOperation*);
|
||||
Status SetAttrBool(AbstractOperation*, const char* attr_name, bool value,
|
||||
ForwardOperation*);
|
||||
Status SetAttrType(AbstractOperation*, const char* attr_name, DataType value,
|
||||
ForwardOperation*);
|
||||
Status SetAttrShape(AbstractOperation*, const char* attr_name,
|
||||
const int64_t* dims, const int num_dims, ForwardOperation*);
|
||||
Status SetAttrFunction(AbstractOperation*, const char* attr_name,
|
||||
const AbstractOperation* value, ForwardOperation*);
|
||||
Status SetAttrFunctionName(AbstractOperation*, const char* attr_name,
|
||||
const char* value, size_t length, ForwardOperation*);
|
||||
Status SetAttrTensor(AbstractOperation*, const char* attr_name,
|
||||
AbstractTensorInterface* tensor, ForwardOperation*);
|
||||
Status SetAttrStringList(AbstractOperation*, const char* attr_name,
|
||||
const void* const* values, const size_t* lengths,
|
||||
int num_values, ForwardOperation*);
|
||||
Status SetAttrFloatList(AbstractOperation*, const char* attr_name,
|
||||
const float* values, int num_values, ForwardOperation*);
|
||||
Status SetAttrIntList(AbstractOperation*, const char* attr_name,
|
||||
const int64_t* values, int num_values, ForwardOperation*);
|
||||
Status SetAttrTypeList(AbstractOperation*, const char* attr_name,
|
||||
const DataType* values, int num_values,
|
||||
ForwardOperation*);
|
||||
Status SetAttrBoolList(AbstractOperation*, const char* attr_name,
|
||||
const unsigned char* values, int num_values,
|
||||
ForwardOperation*);
|
||||
Status SetAttrShapeList(AbstractOperation*, const char* attr_name,
|
||||
const int64_t** dims, const int* num_dims,
|
||||
int num_values, ForwardOperation*);
|
||||
Status SetAttrFunctionList(AbstractOperation*, const char* attr_name,
|
||||
absl::Span<const AbstractOperation*> values,
|
||||
ForwardOperation*);
|
||||
|
||||
// Make the call to `Tape::RecordOperation`.
|
||||
Status Execute(AbstractOperation*, AbstractContext*,
|
||||
absl::Span<AbstractTensorHandle*> retvals, int* num_retvals,
|
||||
ForwardOperation*, Tape*, const GradientRegistry&);
|
||||
|
||||
} // namespace internal
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_GRADIENTS_INTERNAL_H_
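As an illustration of how these helpers are meant to be composed (mirroring the `Add` helper in gradients_test.cc later in this commit), a sketch that executes a unary op while recording it on the tape. The op name "Neg" is a stand-in for illustration; this commit does not register a gradient for it.

Status Neg(AbstractContext* ctx, Tape* tape,
           absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs,
           const GradientRegistry& registry) {
  AbstractOperationPtr neg_op(ctx->CreateOperation());
  ForwardOperation forward_op;
  forward_op.ctx = ctx;
  // Each helper forwards to the underlying op and records into `forward_op`.
  TF_RETURN_IF_ERROR(
      Reset(neg_op.get(), "Neg", /*raw_device_name=*/nullptr, &forward_op));
  TF_RETURN_IF_ERROR(AddInput(neg_op.get(), inputs[0], &forward_op));
  int num_retvals = 1;
  // Execute runs the op and calls Tape::RecordOperation with a factory that
  // looks up the gradient function in `registry`.
  return Execute(neg_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
                 registry);
}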
|
||||
tensorflow/c/eager/gradients_test.cc (new file, 328 lines)
@@ -0,0 +1,328 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/eager/gradients.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/eager/gradients_internal.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
namespace internal {
|
||||
namespace {
|
||||
|
||||
class CppGradients
|
||||
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
TF_SetTracingImplementation(std::get<0>(GetParam()));
|
||||
}
|
||||
};
|
||||
|
||||
// Creates an Identity op.
|
||||
Status Identity(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||
AbstractOperationPtr identity_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(
|
||||
identity_op->Reset("Identity", /*raw_device_name=*/nullptr));
|
||||
if (isa<tracing::TracingOperation>(identity_op.get())) {
|
||||
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(identity_op.get())
|
||||
->SetOpName(name));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(identity_op->AddInput(inputs[0]));
|
||||
int num_retvals = 1;
|
||||
TF_RETURN_IF_ERROR(identity_op->Execute(outputs, &num_retvals));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// =================== Register gradients for Add ============================
|
||||
class AddGradientFunction : public GradientFunction {
|
||||
public:
|
||||
explicit AddGradientFunction(AbstractContext* ctx) : ctx_(ctx) {}
|
||||
Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs,
|
||||
std::vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
grad_outputs->resize(2);
|
||||
std::vector<AbstractTensorHandle*> identity_outputs(1);
|
||||
TF_RETURN_IF_ERROR(Identity(ctx_, {grad_inputs[0]},
|
||||
absl::MakeSpan(identity_outputs), "Id0"));
|
||||
(*grad_outputs)[0] = identity_outputs[0];
|
||||
TF_RETURN_IF_ERROR(Identity(ctx_, {grad_inputs[0]},
|
||||
absl::MakeSpan(identity_outputs), "Id1"));
|
||||
(*grad_outputs)[1] = identity_outputs[0];
|
||||
return Status::OK();
|
||||
}
|
||||
~AddGradientFunction() override {}
|
||||
|
||||
private:
|
||||
AbstractContext* ctx_;
|
||||
};
|
||||
|
||||
GradientFunction* AddRegisterer(const ForwardOperation& op) {
|
||||
return new AddGradientFunction(op.ctx);
|
||||
}
|
||||
|
||||
Status RegisterGradients(GradientRegistry* registry) {
|
||||
return registry->Register("Add", AddRegisterer);
|
||||
}
|
||||
|
||||
// =================== End gradient registrations ============================
|
||||
|
||||
// Computes `inputs[0] + inputs[1]` and records it on the tape.
|
||||
Status Add(AbstractContext* ctx, Tape* tape,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
AbstractOperationPtr add_op(ctx->CreateOperation());
|
||||
ForwardOperation forward_op;
|
||||
forward_op.ctx = ctx;
|
||||
TF_RETURN_IF_ERROR(
|
||||
Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op));
|
||||
if (isa<tracing::TracingOperation>(add_op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<tracing::TracingOperation>(add_op.get())->SetOpName("my_add"));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op));
|
||||
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op));
|
||||
int num_retvals = 1;
|
||||
return Execute(add_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
||||
registry);
|
||||
}
|
||||
|
||||
// Computes
|
||||
// y = inputs[0] + inputs[1]
|
||||
// return grad(y, {inputs[0], inputs[1]})
|
||||
Status AddGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
tape->Watch(ToId(inputs[1])); // Watch y.
|
||||
std::vector<AbstractTensorHandle*> add_outputs(1);
|
||||
TF_RETURN_IF_ERROR(Add(ctx, tape, inputs, absl::MakeSpan(add_outputs),
|
||||
registry)); // Compute x+y.
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
|
||||
std::vector<AbstractTensorHandle*> out_grads;
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(
|
||||
vspace, /*target_tensor_ids=*/{ToId(add_outputs[0])},
|
||||
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
|
||||
source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads));
|
||||
for (auto add_output : add_outputs) {
|
||||
add_output->Release();
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
outputs[1] = out_grads[1];
|
||||
delete tape;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
AbstractContext* BuildFunction(const char* fn_name) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name, status.get());
|
||||
return unwrap(graph_ctx);
|
||||
}
|
||||
|
||||
Status CreateParamsForInputs(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
std::vector<AbstractTensorHandle*>* params) {
|
||||
tracing::TracingTensorHandle* handle = nullptr;
|
||||
for (auto input : inputs) {
|
||||
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
|
||||
input->DataType(), &handle));
|
||||
params->emplace_back(handle);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
using Model = std::function<Status(
|
||||
AbstractContext*, absl::Span<AbstractTensorHandle* const>,
|
||||
absl::Span<AbstractTensorHandle*>, const GradientRegistry&)>;
|
||||
|
||||
// Runs `model` maybe wrapped in a function.
|
||||
Status RunModel(Model model, AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, bool use_function,
|
||||
const GradientRegistry& registry) {
|
||||
if (use_function) {
|
||||
const char* fn_name = "test_fn";
|
||||
std::unique_ptr<AbstractFunction> scoped_func;
|
||||
{
|
||||
AbstractContextPtr func_ctx(BuildFunction(fn_name));
|
||||
std::vector<AbstractTensorHandle*> func_inputs;
|
||||
func_inputs.reserve(inputs.size());
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs));
|
||||
OutputList output_list;
|
||||
output_list.expected_num_outputs = outputs.size();
|
||||
output_list.outputs.resize(outputs.size());
|
||||
TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs),
|
||||
absl::MakeSpan(output_list.outputs), registry));
|
||||
for (auto func_input : func_inputs) {
|
||||
func_input->Release();
|
||||
}
|
||||
AbstractFunction* func = nullptr;
|
||||
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get())
|
||||
->Finalize(&output_list, &func));
|
||||
scoped_func.reset(func);
|
||||
output_list.outputs[0]->Release();
|
||||
output_list.outputs[1]->Release();
|
||||
TF_RETURN_IF_ERROR(ctx->RegisterFunction(func));
|
||||
}
|
||||
|
||||
AbstractOperationPtr fn_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(fn_op->Reset(fn_name, /*raw_device_name=*/nullptr));
|
||||
for (auto input : inputs) {
|
||||
TF_RETURN_IF_ERROR(fn_op->AddInput(input));
|
||||
}
|
||||
int retvals = outputs.size();
|
||||
TF_RETURN_IF_ERROR(fn_op->Execute(outputs, &retvals));
|
||||
TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name));
|
||||
return Status::OK();
|
||||
} else {
|
||||
return model(ctx, inputs, outputs, registry);
|
||||
}
|
||||
}
|
||||
|
||||
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
|
||||
*ctx = unwrap(TF_NewEagerExecutionContext(opts, status.get()));
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||
TFE_DeleteContextOptions(opts);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TestScalarTensorHandle(AbstractContext* ctx, float value,
|
||||
AbstractTensorHandle** tensor) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_Context* eager_ctx =
|
||||
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, value);
|
||||
*tensor =
|
||||
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status getValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_TensorHandle* result_t =
|
||||
TF_AbstractTensorGetEagerTensor(wrap(t), status.get());
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||
*result_tensor = TFE_TensorHandleResolve(result_t, status.get());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TEST_P(CppGradients, TestAddGrad) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
AbstractContextPtr ctx;
|
||||
{
|
||||
AbstractContext* ctx_raw = nullptr;
|
||||
Status s =
|
||||
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
ctx.reset(ctx_raw);
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr x;
|
||||
{
|
||||
AbstractTensorHandle* x_raw = nullptr;
|
||||
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
x.reset(x_raw);
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr y;
|
||||
{
|
||||
AbstractTensorHandle* y_raw = nullptr;
|
||||
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &y_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
y.reset(y_raw);
|
||||
}
|
||||
|
||||
GradientRegistry registry;
|
||||
Status s = RegisterGradients(&registry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
// Pseudo-code:
|
||||
//
|
||||
// tape.watch(x)
|
||||
// tape.watch(y)
|
||||
// y = x + y
|
||||
// outputs = tape.gradient(y, [x, y])
|
||||
std::vector<AbstractTensorHandle*> outputs(2);
|
||||
s = RunModel(AddGradModel, ctx.get(), {x.get(), y.get()},
|
||||
absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()), registry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
TF_Tensor* result_tensor;
|
||||
s = getValue(outputs[0], &result_tensor);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
|
||||
EXPECT_EQ(*result_value, 1.0);
|
||||
outputs[0]->Release();
|
||||
TF_DeleteTensor(result_tensor);
|
||||
result_tensor = nullptr;
|
||||
|
||||
s = getValue(outputs[1], &result_tensor);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
result_value = static_cast<float*>(TF_TensorData(result_tensor));
|
||||
EXPECT_EQ(*result_value, 1.0);
|
||||
outputs[1]->Release();
|
||||
TF_DeleteTensor(result_tensor);
|
||||
}
|
||||
|
||||
// TODO(b/160888630): Enable this test with mlir after AddInputList is
|
||||
// supported. It is needed for AddN op which is used for gradient aggregation.
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
UnifiedCAPI, CppGradients,
|
||||
::testing::Combine(::testing::Values("graphdef"),
|
||||
/*tfrt*/ ::testing::Values(false),
|
||||
/*executing_eagerly*/ ::testing::Values(true, false)));
|
||||
#else
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
UnifiedCAPI, CppGradients,
|
||||
::testing::Combine(::testing::Values("graphdef"),
|
||||
/*tfrt*/ ::testing::Values(false),
|
||||
/*executing_eagerly*/ ::testing::Values(true, false)));
|
||||
#endif
|
||||
} // namespace
|
||||
} // namespace internal
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
@@ -66,11 +66,26 @@ cc_library(
        ":file_block_cache",
        "//tensorflow/c:env",
        "//tensorflow/c:tf_status",
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/synchronization",
    ],
)

tf_cc_test(
    name = "ram_file_block_cache_test",
    size = "small",
    srcs = ["ram_file_block_cache_test.cc"],
    deps = [
        ":ram_file_block_cache",
        "//tensorflow/c:tf_status_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/platform:blocking_counter",
        "//tensorflow/core/platform/cloud:now_seconds_env",
    ],
)

tf_cc_test(
    name = "gcs_filesystem_test",
    srcs = [

@@ -45,8 +45,6 @@ limitations under the License.
#include <type_traits>
#include <utility>

#include "tensorflow/core/platform/macros.h"

namespace tf_gcs_filesystem {

// A move-only RAII object that calls a stored cleanup functor when

@@ -18,6 +18,7 @@ limitations under the License.
#include <string.h>

#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "google/cloud/storage/client.h"
#include "tensorflow/c/env.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"
@@ -556,6 +557,111 @@ void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst,
|
||||
TF_SetStatusFromGCSStatus(metadata.status(), status);
|
||||
}
|
||||
|
||||
// TODO(vnvo2409): This approach can cause a problem when our path is
|
||||
// `path/to/dir` and there is an object with key `path/to/directory`. Will be
|
||||
// fixed when refactoring.
|
||||
void PathExists(const TF_Filesystem* filesystem, const char* path,
|
||||
TF_Status* status) {
|
||||
std::string bucket, object;
|
||||
ParseGCSPath(path, true, &bucket, &object, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
|
||||
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
|
||||
for (auto&& metadata :
|
||||
gcs_file->gcs_client.ListObjects(bucket, gcs::Prefix(object))) {
|
||||
if (!metadata) {
|
||||
TF_SetStatusFromGCSStatus(metadata.status(), status);
|
||||
return;
|
||||
}
|
||||
// We consider a path to exist if there is at least one object whose key
// contains the path.
|
||||
return TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
return TF_SetStatus(
|
||||
status, TF_NOT_FOUND,
|
||||
absl::StrCat("The path ", path, " does not exist.").c_str());
|
||||
}
|
||||
|
||||
bool IsDirectory(const TF_Filesystem* filesystem, const char* path,
|
||||
TF_Status* status) {
|
||||
std::string bucket, object;
|
||||
ParseGCSPath(path, true, &bucket, &object, status);
|
||||
if (TF_GetCode(status) != TF_OK) return false;
|
||||
|
||||
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
|
||||
if (object.empty()) {
|
||||
auto bucket_metadata = gcs_file->gcs_client.GetBucketMetadata(bucket);
|
||||
TF_SetStatusFromGCSStatus(bucket_metadata.status(), status);
|
||||
if (TF_GetCode(status) == TF_OK)
|
||||
return true;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
|
||||
// We check if there is an object with this key on the GCS server.
|
||||
auto metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object);
|
||||
if (metadata) {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
if (metadata->name().back() == '/')
|
||||
return true;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
|
||||
// If there is no object with this key on the GCS server, we check whether
// there is any object whose key contains that path.
|
||||
MaybeAppendSlash(&object);
|
||||
for (auto&& metadata :
|
||||
gcs_file->gcs_client.ListObjects(bucket, gcs::Prefix(object))) {
|
||||
if (!metadata) {
|
||||
TF_SetStatusFromGCSStatus(metadata.status(), status);
|
||||
return false;
|
||||
}
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
return true;
|
||||
}
|
||||
TF_SetStatus(status, TF_NOT_FOUND,
|
||||
absl::StrCat("The path ", path, " does not exist.").c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
void Stat(const TF_Filesystem* filesystem, const char* path,
|
||||
TF_FileStatistics* stats, TF_Status* status) {
|
||||
std::string bucket, object;
|
||||
ParseGCSPath(path, true, &bucket, &object, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
|
||||
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
|
||||
if (object.empty()) {
|
||||
auto bucket_metadata = gcs_file->gcs_client.GetBucketMetadata(bucket);
|
||||
TF_SetStatusFromGCSStatus(bucket_metadata.status(), status);
|
||||
if (TF_GetCode(status) == TF_OK) {
|
||||
stats->is_directory = true;
|
||||
stats->length = 0;
|
||||
stats->mtime_nsec = 0;
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (IsDirectory(filesystem, path, status)) {
|
||||
stats->is_directory = true;
|
||||
stats->length = 0;
|
||||
stats->mtime_nsec = 0;
|
||||
return TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
if (TF_GetCode(status) == TF_OK) {
|
||||
auto metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object);
|
||||
if (metadata) {
|
||||
stats->is_directory = false;
|
||||
stats->length = metadata.value().size();
|
||||
stats->mtime_nsec = metadata.value()
|
||||
.time_storage_class_updated()
|
||||
.time_since_epoch()
|
||||
.count();
|
||||
}
|
||||
TF_SetStatusFromGCSStatus(metadata.status(), status);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tf_gcs_filesystem
|
||||
|
||||
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
|
||||
|
||||
@@ -51,7 +51,7 @@ class RamFileBlockCache : public FileBlockCache {

  RamFileBlockCache(size_t block_size, size_t max_bytes, uint64_t max_staleness,
                    BlockFetcher block_fetcher,
                    std::function<uint64_t()> timer_seconds)
                    std::function<uint64_t()> timer_seconds = TF_NowSeconds)
      : block_size_(block_size),
        max_bytes_(max_bytes),
        max_staleness_(max_staleness),

@@ -0,0 +1,601 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h"
|
||||
|
||||
#include <cstring>
|
||||
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_status_internal.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/blocking_counter.h"
|
||||
#include "tensorflow/core/platform/cloud/now_seconds_env.h"
|
||||
#include "tensorflow/core/platform/notification.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
Status ReadCache(tf_gcs_filesystem::RamFileBlockCache* cache,
|
||||
const string& filename, size_t offset, size_t n,
|
||||
std::vector<char>* out) {
|
||||
out->clear();
|
||||
out->resize(n, 0);
|
||||
size_t bytes_transferred = 0;
|
||||
TF_Status status;
|
||||
cache->Read(filename, offset, n, out->data(), &bytes_transferred, &status);
|
||||
EXPECT_LE(bytes_transferred, n);
|
||||
out->resize(bytes_transferred, n);
|
||||
return status.status;
|
||||
}
|
||||
|
||||
TEST(RamFileBlockCacheTest, IsCacheEnabled) {
|
||||
auto fetcher = [](const string& filename, size_t offset, size_t n,
|
||||
char* buffer, size_t* bytes_transferred,
|
||||
TF_Status* status) {
|
||||
// Do nothing.
|
||||
return TF_SetStatus(status, TF_OK, "");
|
||||
};
|
||||
tf_gcs_filesystem::RamFileBlockCache cache1(0, 0, 0, fetcher);
|
||||
tf_gcs_filesystem::RamFileBlockCache cache2(16, 0, 0, fetcher);
|
||||
tf_gcs_filesystem::RamFileBlockCache cache3(0, 32, 0, fetcher);
|
||||
tf_gcs_filesystem::RamFileBlockCache cache4(16, 32, 0, fetcher);
|
||||
|
||||
EXPECT_FALSE(cache1.IsCacheEnabled());
|
||||
EXPECT_FALSE(cache2.IsCacheEnabled());
|
||||
EXPECT_FALSE(cache3.IsCacheEnabled());
|
||||
EXPECT_TRUE(cache4.IsCacheEnabled());
|
||||
}
|
||||
|
||||
TEST(RamFileBlockCacheTest, ValidateAndUpdateFileSignature) {
|
||||
int calls = 0;
|
||||
auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
|
||||
char* buffer, size_t* bytes_transferred,
|
||||
TF_Status* status) {
|
||||
calls++;
|
||||
memset(buffer, 'x', n);
|
||||
*bytes_transferred = n;
|
||||
return TF_SetStatus(status, TF_OK, "");
|
||||
};
|
||||
string filename = "file";
|
||||
tf_gcs_filesystem::RamFileBlockCache cache(16, 32, 0, fetcher);
|
||||
std::vector<char> out;
|
||||
|
||||
// First read.
|
||||
EXPECT_TRUE(cache.ValidateAndUpdateFileSignature(filename, 123));
|
||||
TF_EXPECT_OK(ReadCache(&cache, filename, 0, 16, &out));
|
||||
EXPECT_EQ(calls, 1);
|
||||
|
||||
// Second read. Hit cache.
|
||||
EXPECT_TRUE(cache.ValidateAndUpdateFileSignature(filename, 123));
|
||||
TF_EXPECT_OK(ReadCache(&cache, filename, 0, 16, &out));
|
||||
EXPECT_EQ(calls, 1);
|
||||
|
||||
// Third read. File signatures are different.
|
||||
EXPECT_FALSE(cache.ValidateAndUpdateFileSignature(filename, 321));
|
||||
TF_EXPECT_OK(ReadCache(&cache, filename, 0, 16, &out));
|
||||
EXPECT_EQ(calls, 2);
|
||||
}
|
||||
|
||||
TEST(RamFileBlockCacheTest, PassThrough) {
|
||||
const string want_filename = "foo/bar";
|
||||
const size_t want_offset = 42;
|
||||
const size_t want_n = 1024;
|
||||
int calls = 0;
|
||||
auto fetcher = [&calls, want_filename, want_offset, want_n](
|
||||
const string& got_filename, size_t got_offset,
|
||||
size_t got_n, char* buffer, size_t* bytes_transferred,
|
||||
TF_Status* status) {
|
||||
EXPECT_EQ(got_filename, want_filename);
|
||||
EXPECT_EQ(got_offset, want_offset);
|
||||
EXPECT_EQ(got_n, want_n);
|
||||
calls++;
|
||||
memset(buffer, 'x', got_n);
|
||||
*bytes_transferred = got_n;
|
||||
return TF_SetStatus(status, TF_OK, "");
|
||||
};
|
||||
// If block_size, max_bytes, or both are zero, or want_n is larger than
|
||||
// max_bytes, the cache is a pass-through.
|
||||
tf_gcs_filesystem::RamFileBlockCache cache1(1, 0, 0, fetcher);
|
||||
tf_gcs_filesystem::RamFileBlockCache cache2(0, 1, 0, fetcher);
|
||||
tf_gcs_filesystem::RamFileBlockCache cache3(0, 0, 0, fetcher);
|
||||
tf_gcs_filesystem::RamFileBlockCache cache4(1000, 1000, 0, fetcher);
|
||||
std::vector<char> out;
|
||||
TF_EXPECT_OK(ReadCache(&cache1, want_filename, want_offset, want_n, &out));
|
||||
EXPECT_EQ(calls, 1);
|
||||
TF_EXPECT_OK(ReadCache(&cache2, want_filename, want_offset, want_n, &out));
|
||||
EXPECT_EQ(calls, 2);
|
||||
TF_EXPECT_OK(ReadCache(&cache3, want_filename, want_offset, want_n, &out));
|
||||
EXPECT_EQ(calls, 3);
|
||||
TF_EXPECT_OK(ReadCache(&cache4, want_filename, want_offset, want_n, &out));
|
||||
EXPECT_EQ(calls, 4);
|
||||
}
|
||||
|
||||
TEST(RamFileBlockCacheTest, BlockAlignment) {
  // Initialize a 256-byte buffer. This is the file underlying the reads we'll
  // do in this test.
  const size_t size = 256;
  std::vector<char> buf;
  for (int i = 0; i < size; i++) {
    buf.push_back(i);
  }
  // The fetcher just fetches slices of the buffer.
  auto fetcher = [&buf](const string& filename, size_t offset, size_t n,
                        char* buffer, size_t* bytes_transferred,
                        TF_Status* status) {
    if (offset < buf.size()) {
      size_t bytes_to_copy = std::min<size_t>(buf.size() - offset, n);
      memcpy(buffer, buf.data() + offset, bytes_to_copy);
      *bytes_transferred = bytes_to_copy;
    } else {
      *bytes_transferred = 0;
    }
    return TF_SetStatus(status, TF_OK, "");
  };
  for (size_t block_size = 2; block_size <= 4; block_size++) {
    // Make a cache of N-byte block size (1 block) and verify that reads of
    // varying offsets and lengths return correct data.
    tf_gcs_filesystem::RamFileBlockCache cache(block_size, block_size, 0,
                                               fetcher);
    for (size_t offset = 0; offset < 10; offset++) {
      for (size_t n = block_size - 2; n <= block_size + 2; n++) {
        std::vector<char> got;
        TF_EXPECT_OK(ReadCache(&cache, "", offset, n, &got));
        // Verify the size of the read.
        if (offset + n <= size) {
          // Expect a full read.
          EXPECT_EQ(got.size(), n) << "block size = " << block_size
                                   << ", offset = " << offset << ", n = " << n;
        } else {
          // Expect a partial read.
          EXPECT_EQ(got.size(), size - offset)
              << "block size = " << block_size << ", offset = " << offset
              << ", n = " << n;
        }
        // Verify the contents of the read.
        std::vector<char>::const_iterator begin = buf.begin() + offset;
        std::vector<char>::const_iterator end =
            offset + n > buf.size() ? buf.end() : begin + n;
        std::vector<char> want(begin, end);
        EXPECT_EQ(got, want) << "block size = " << block_size
                             << ", offset = " << offset << ", n = " << n;
      }
    }
  }
}

TEST(RamFileBlockCacheTest, CacheHits) {
  const size_t block_size = 16;
  std::set<size_t> calls;
  auto fetcher = [&calls, block_size](const string& filename, size_t offset,
                                      size_t n, char* buffer,
                                      size_t* bytes_transferred,
                                      TF_Status* status) {
    EXPECT_EQ(n, block_size);
    EXPECT_EQ(offset % block_size, 0);
    EXPECT_EQ(calls.find(offset), calls.end()) << "at offset " << offset;
    calls.insert(offset);
    memset(buffer, 'x', n);
    *bytes_transferred = n;
    return TF_SetStatus(status, TF_OK, "");
  };
  const uint32 block_count = 256;
  tf_gcs_filesystem::RamFileBlockCache cache(
      block_size, block_count * block_size, 0, fetcher);
  std::vector<char> out;
  out.resize(block_count, 0);
  // The cache has space for `block_count` blocks. The loop with i = 0 should
  // fill the cache, and the loop with i = 1 should be all cache hits. The
  // fetcher checks that it is called once and only once for each offset (to
  // fetch the corresponding block).
  for (int i = 0; i < 2; i++) {
    for (int j = 0; j < block_count; j++) {
      TF_EXPECT_OK(ReadCache(&cache, "", block_size * j, block_size, &out));
    }
  }
}

TEST(RamFileBlockCacheTest, OutOfRange) {
  // Tests reads of a 24-byte file with block size 16.
  const size_t block_size = 16;
  const size_t file_size = 24;
  bool first_block = false;
  bool second_block = false;
  auto fetcher = [block_size, file_size, &first_block, &second_block](
                     const string& filename, size_t offset, size_t n,
                     char* buffer, size_t* bytes_transferred,
                     TF_Status* status) {
    EXPECT_EQ(n, block_size);
    EXPECT_EQ(offset % block_size, 0);
    size_t bytes_to_copy = 0;
    if (offset == 0) {
      // The first block (16 bytes) of the file.
      memset(buffer, 'x', n);
      bytes_to_copy = n;
      first_block = true;
    } else if (offset == block_size) {
      // The second block (8 bytes) of the file.
      bytes_to_copy = file_size - block_size;
      memset(buffer, 'x', bytes_to_copy);
      second_block = true;
    }
    *bytes_transferred = bytes_to_copy;
    return TF_SetStatus(status, TF_OK, "");
  };
  tf_gcs_filesystem::RamFileBlockCache cache(block_size, block_size, 0,
                                             fetcher);
  std::vector<char> out;
  // Reading the first 16 bytes should be fine.
  TF_EXPECT_OK(ReadCache(&cache, "", 0, block_size, &out));
  EXPECT_TRUE(first_block);
  EXPECT_EQ(out.size(), block_size);
  // Reading at offset file_size + 4 will read the second block (since the read
  // at file_size + 4 = 28 will be aligned to an offset of 16) but will return
  // OutOfRange because the offset is past the end of the 24-byte file.
  Status status = ReadCache(&cache, "", file_size + 4, 4, &out);
  EXPECT_EQ(status.code(), error::OUT_OF_RANGE);
  EXPECT_TRUE(second_block);
  // Reading the second full block will return 8 bytes, from a cache hit.
  second_block = false;
  TF_EXPECT_OK(ReadCache(&cache, "", block_size, block_size, &out));
  EXPECT_FALSE(second_block);
  EXPECT_EQ(out.size(), file_size - block_size);
}

TEST(RamFileBlockCacheTest, Inconsistent) {
  // Tests the detection of interrupted reads leading to partially filled blocks
  // where we expected complete blocks.
  const size_t block_size = 16;
  // This fetcher returns OK but only fills in one byte for any offset.
  auto fetcher = [block_size](const string& filename, size_t offset, size_t n,
                              char* buffer, size_t* bytes_transferred,
                              TF_Status* status) {
    EXPECT_EQ(n, block_size);
    EXPECT_EQ(offset % block_size, 0);
    EXPECT_GE(n, 1);
    memset(buffer, 'x', 1);
    *bytes_transferred = 1;
    return TF_SetStatus(status, TF_OK, "");
  };
  tf_gcs_filesystem::RamFileBlockCache cache(block_size, 2 * block_size, 0,
                                             fetcher);
  std::vector<char> out;
  // Read the second block; this should yield an OK status and a single byte.
  TF_EXPECT_OK(ReadCache(&cache, "", block_size, block_size, &out));
  EXPECT_EQ(out.size(), 1);
  // Now read the first block; this should yield an INTERNAL error because we
  // had already cached a partial block at a later position.
  Status status = ReadCache(&cache, "", 0, block_size, &out);
  EXPECT_EQ(status.code(), error::INTERNAL);
}

TEST(RamFileBlockCacheTest, LRU) {
  const size_t block_size = 16;
  std::list<size_t> calls;
  auto fetcher = [&calls, block_size](const string& filename, size_t offset,
                                      size_t n, char* buffer,
                                      size_t* bytes_transferred,
                                      TF_Status* status) {
    EXPECT_EQ(n, block_size);
    EXPECT_FALSE(calls.empty()) << "at offset = " << offset;
    if (!calls.empty()) {
      EXPECT_EQ(offset, calls.front());
      calls.pop_front();
    }
    memset(buffer, 'x', n);
    *bytes_transferred = n;
    return TF_SetStatus(status, TF_OK, "");
  };
  const uint32 block_count = 2;
  tf_gcs_filesystem::RamFileBlockCache cache(
      block_size, block_count * block_size, 0, fetcher);
  std::vector<char> out;
  // Read blocks from the cache, and verify the LRU behavior based on the
  // fetcher calls that the cache makes.
  calls.push_back(0);
  // Cache miss - drains an element from `calls`.
  TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out));
  // Cache hit - does not drain an element from `calls`.
  TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out));
  calls.push_back(block_size);
  // Cache miss followed by cache hit.
  TF_EXPECT_OK(ReadCache(&cache, "", block_size, 1, &out));
  TF_EXPECT_OK(ReadCache(&cache, "", block_size, 1, &out));
  calls.push_back(2 * block_size);
  // Cache miss followed by cache hit. Causes eviction of LRU element.
  TF_EXPECT_OK(ReadCache(&cache, "", 2 * block_size, 1, &out));
  TF_EXPECT_OK(ReadCache(&cache, "", 2 * block_size, 1, &out));
  // LRU element was at offset 0. Cache miss.
  calls.push_back(0);
  TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out));
  // Element at 2 * block_size is still in cache, and this read should update
  // its position in the LRU list so it doesn't get evicted by the next read.
  TF_EXPECT_OK(ReadCache(&cache, "", 2 * block_size, 1, &out));
  // Element at block_size was evicted. Reading this element will also cause
  // the LRU element (at 0) to be evicted.
  calls.push_back(block_size);
  TF_EXPECT_OK(ReadCache(&cache, "", block_size, 1, &out));
  // Element at 0 was evicted again.
  calls.push_back(0);
  TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out));
}

TEST(RamFileBlockCacheTest, MaxStaleness) {
  int calls = 0;
  auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
                          char* buffer, size_t* bytes_transferred,
                          TF_Status* status) {
    calls++;
    memset(buffer, 'x', n);
    *bytes_transferred = n;
    return TF_SetStatus(status, TF_OK, "");
  };
  std::vector<char> out;
  std::unique_ptr<NowSecondsEnv> env(new NowSecondsEnv);
  // Create a cache with max staleness of 2 seconds, and verify that it works as
  // expected.
  tf_gcs_filesystem::RamFileBlockCache cache1(
      8, 16, 2 /* max staleness */, fetcher,
      [&env]() { return env->NowSeconds(); });
  // Execute the first read to load the block.
  TF_EXPECT_OK(ReadCache(&cache1, "", 0, 1, &out));
  EXPECT_EQ(calls, 1);
  // Now advance the clock one second at a time and redo the read. The call
  // count should advance every 3 seconds (i.e. every time the staleness is
  // greater than 2).
  for (int i = 1; i <= 10; i++) {
    env->SetNowSeconds(i + 1);
    TF_EXPECT_OK(ReadCache(&cache1, "", 0, 1, &out));
    EXPECT_EQ(calls, 1 + i / 3);
  }
  // Now create a cache with max staleness of 0, and verify that it also works
  // as expected.
  calls = 0;
  env->SetNowSeconds(0);
  tf_gcs_filesystem::RamFileBlockCache cache2(
      8, 16, 0 /* max staleness */, fetcher,
      [&env]() { return env->NowSeconds(); });
  // Execute the first read to load the block.
  TF_EXPECT_OK(ReadCache(&cache2, "", 0, 1, &out));
  EXPECT_EQ(calls, 1);
  // Advance the clock by a huge amount and verify that the cached block is
  // used to satisfy the read.
  env->SetNowSeconds(365 * 24 * 60 * 60);  // ~1 year, just for fun.
  TF_EXPECT_OK(ReadCache(&cache2, "", 0, 1, &out));
  EXPECT_EQ(calls, 1);
}
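
// Editorial sketch, not part of the original test file: it spells out why the
// expectation in MaxStaleness is `calls == 1 + i / 3`, assuming the fake
// NowSecondsEnv clock starts at second 1 and a cached block is refetched once
// its age strictly exceeds the max staleness of 2 (so refetches land at
// t = 4, 7 and 10). The test name is hypothetical.
TEST(RamFileBlockCacheTest, MaxStalenessScheduleSketch) {
  int calls = 1;       // the initial fetch, assumed to happen at t = 1
  int last_fetch = 1;  // timestamp of the most recent fetch
  const int max_staleness = 2;
  for (int i = 1; i <= 10; ++i) {
    const int now = i + 1;
    if (now - last_fetch > max_staleness) {
      ++calls;
      last_fetch = now;
    }
    EXPECT_EQ(calls, 1 + i / 3);
  }
}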

TEST(RamFileBlockCacheTest, RemoveFile) {
  int calls = 0;
  auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
                          char* buffer, size_t* bytes_transferred,
                          TF_Status* status) {
    calls++;
    char c = (filename == "a") ? 'a' : (filename == "b") ? 'b' : 'x';
    if (offset > 0) {
      // The first block is lower case and all subsequent blocks are upper case.
      c = toupper(c);
    }
    memset(buffer, c, n);
    *bytes_transferred = n;
    return TF_SetStatus(status, TF_OK, "");
  };
  // This cache has space for 4 blocks; we'll read from two files.
  const size_t n = 3;
  tf_gcs_filesystem::RamFileBlockCache cache(8, 32, 0, fetcher);
  std::vector<char> out;
  std::vector<char> a(n, 'a');
  std::vector<char> b(n, 'b');
  std::vector<char> A(n, 'A');
  std::vector<char> B(n, 'B');
  // Fill the cache.
  TF_EXPECT_OK(ReadCache(&cache, "a", 0, n, &out));
  EXPECT_EQ(out, a);
  EXPECT_EQ(calls, 1);
  TF_EXPECT_OK(ReadCache(&cache, "a", 8, n, &out));
  EXPECT_EQ(out, A);
  EXPECT_EQ(calls, 2);
  TF_EXPECT_OK(ReadCache(&cache, "b", 0, n, &out));
  EXPECT_EQ(out, b);
  EXPECT_EQ(calls, 3);
  TF_EXPECT_OK(ReadCache(&cache, "b", 8, n, &out));
  EXPECT_EQ(out, B);
  EXPECT_EQ(calls, 4);
  // All four blocks should be in the cache now.
  TF_EXPECT_OK(ReadCache(&cache, "a", 0, n, &out));
  EXPECT_EQ(out, a);
  TF_EXPECT_OK(ReadCache(&cache, "a", 8, n, &out));
  EXPECT_EQ(out, A);
  TF_EXPECT_OK(ReadCache(&cache, "b", 0, n, &out));
  EXPECT_EQ(out, b);
  TF_EXPECT_OK(ReadCache(&cache, "b", 8, n, &out));
  EXPECT_EQ(out, B);
  EXPECT_EQ(calls, 4);
  // Remove the blocks from "a".
  cache.RemoveFile("a");
  // Both blocks from "b" should still be there.
  TF_EXPECT_OK(ReadCache(&cache, "b", 0, n, &out));
  EXPECT_EQ(out, b);
  TF_EXPECT_OK(ReadCache(&cache, "b", 8, n, &out));
  EXPECT_EQ(out, B);
  EXPECT_EQ(calls, 4);
  // The blocks from "a" should not be there.
  TF_EXPECT_OK(ReadCache(&cache, "a", 0, n, &out));
  EXPECT_EQ(out, a);
  EXPECT_EQ(calls, 5);
  TF_EXPECT_OK(ReadCache(&cache, "a", 8, n, &out));
  EXPECT_EQ(out, A);
  EXPECT_EQ(calls, 6);
}

TEST(RamFileBlockCacheTest, Prune) {
  int calls = 0;
  auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
                          char* buffer, size_t* bytes_transferred,
                          TF_Status* status) {
    calls++;
    memset(buffer, 'x', n);
    *bytes_transferred = n;
    return TF_SetStatus(status, TF_OK, "");
  };
  std::vector<char> out;
  // Our fake environment is initialized with the current timestamp.
  std::unique_ptr<NowSecondsEnv> env(new NowSecondsEnv);
  uint64 now = Env::Default()->NowSeconds();
  env->SetNowSeconds(now);
  tf_gcs_filesystem::RamFileBlockCache cache(
      8, 32, 1 /* max staleness */, fetcher,
      [&env]() { return env->NowSeconds(); });
  // Read three blocks into the cache, and advance the timestamp by one second
  // with each read. Start with a block of "a" at the current timestamp `now`.
  TF_EXPECT_OK(ReadCache(&cache, "a", 0, 1, &out));
  // Now load a block of a different file "b" at timestamp `now` + 1
  env->SetNowSeconds(now + 1);
  TF_EXPECT_OK(ReadCache(&cache, "b", 0, 1, &out));
  // Now load a different block of file "a" at timestamp `now` + 1. When the
  // first block of "a" expires, this block should also be removed because it
  // also belongs to file "a".
  TF_EXPECT_OK(ReadCache(&cache, "a", 8, 1, &out));
  // Ensure that all blocks are in the cache (i.e. reads are cache hits).
  EXPECT_EQ(cache.CacheSize(), 24);
  EXPECT_EQ(calls, 3);
  TF_EXPECT_OK(ReadCache(&cache, "a", 0, 1, &out));
  TF_EXPECT_OK(ReadCache(&cache, "b", 0, 1, &out));
  TF_EXPECT_OK(ReadCache(&cache, "a", 8, 1, &out));
  EXPECT_EQ(calls, 3);
  // Advance the fake timestamp so that "a" becomes stale via its first block.
  env->SetNowSeconds(now + 2);
  // The pruning thread periodically compares env->NowSeconds() with the oldest
  // block's timestamp to see if it should evict any files. At the current fake
  // timestamp of `now` + 2, file "a" is stale because its first block is stale,
  // but file "b" is not stale yet. Thus, once the pruning thread wakes up (in
  // one second of wall time), it should remove "a" and leave "b" alone.
  uint64 start = Env::Default()->NowSeconds();
  do {
    Env::Default()->SleepForMicroseconds(100000);
  } while (cache.CacheSize() == 24 && Env::Default()->NowSeconds() - start < 3);
  // There should be one block left in the cache, and it should be the first
  // block of "b".
  EXPECT_EQ(cache.CacheSize(), 8);
  TF_EXPECT_OK(ReadCache(&cache, "b", 0, 1, &out));
  EXPECT_EQ(calls, 3);
  // Advance the fake time to `now` + 3, at which point "b" becomes stale.
  env->SetNowSeconds(now + 3);
  // Wait for the pruner to remove "b".
  start = Env::Default()->NowSeconds();
  do {
    Env::Default()->SleepForMicroseconds(100000);
  } while (cache.CacheSize() == 8 && Env::Default()->NowSeconds() - start < 3);
  // The cache should now be empty.
  EXPECT_EQ(cache.CacheSize(), 0);
}

TEST(RamFileBlockCacheTest, ParallelReads) {
  // This fetcher won't respond until either `callers` threads are calling it
  // concurrently (at which point it will respond with success to all callers),
  // or 10 seconds have elapsed (at which point it will respond with an error).
  const int callers = 4;
  BlockingCounter counter(callers);
  auto fetcher = [&counter](const string& filename, size_t offset, size_t n,
                            char* buffer, size_t* bytes_transferred,
                            TF_Status* status) {
    counter.DecrementCount();
    if (!counter.WaitFor(std::chrono::seconds(10))) {
      // This avoids having the test time out, which is harder to debug.
      return TF_SetStatus(status, TF_FAILED_PRECONDITION,
                          "desired concurrency not reached");
    }
    memset(buffer, 'x', n);
    *bytes_transferred = n;
    return TF_SetStatus(status, TF_OK, "");
  };
  const int block_size = 8;
  tf_gcs_filesystem::RamFileBlockCache cache(
      block_size, 2 * callers * block_size, 0, fetcher);
  std::vector<std::unique_ptr<Thread>> threads;
  threads.reserve(callers);
  for (int i = 0; i < callers; i++) {
    threads.emplace_back(
        Env::Default()->StartThread({}, "caller", [block_size, &cache, i]() {
          std::vector<char> out;
          TF_EXPECT_OK(
              ReadCache(&cache, "a", i * block_size, block_size, &out));
          std::vector<char> x(block_size, 'x');
          EXPECT_EQ(out, x);
        }));
  }
  // The `threads` destructor blocks until the threads can be joined, once their
  // respective reads finish (which happens once they are all concurrently being
  // executed, or 10 seconds have passed).
}

TEST(RamFileBlockCacheTest, CoalesceConcurrentReads) {
  // Concurrent reads to the same file blocks should be de-duplicated.
  const size_t block_size = 16;
  int num_requests = 0;
  Notification notification;
  auto fetcher = [&num_requests, &notification, block_size](
                     const string& filename, size_t offset, size_t n,
                     char* buffer, size_t* bytes_transferred,
                     TF_Status* status) {
    EXPECT_EQ(n, block_size);
    EXPECT_EQ(offset, 0);
    num_requests++;
    memset(buffer, 'x', n);
    *bytes_transferred = n;
    notification.Notify();
    // Wait for other thread to issue read.
    Env::Default()->SleepForMicroseconds(100000);  // 0.1 secs
    return TF_SetStatus(status, TF_OK, "");
  };
  tf_gcs_filesystem::RamFileBlockCache cache(block_size, block_size, 0,
                                             fetcher);
  // Fork off thread for parallel read.
  std::unique_ptr<Thread> concurrent(
      Env::Default()->StartThread({}, "concurrent", [block_size, &cache] {
        std::vector<char> out;
        TF_EXPECT_OK(ReadCache(&cache, "", 0, block_size / 2, &out));
        EXPECT_EQ(out.size(), block_size / 2);
      }));
  notification.WaitForNotification();
  std::vector<char> out;
  TF_EXPECT_OK(ReadCache(&cache, "", block_size / 2, block_size / 2, &out));
  EXPECT_EQ(out.size(), block_size / 2);

  EXPECT_EQ(1, num_requests);
}

TEST(RamFileBlockCacheTest, Flush) {
  int calls = 0;
  auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
                          char* buffer, size_t* bytes_transferred,
                          TF_Status* status) {
    calls++;
    memset(buffer, 'x', n);
    *bytes_transferred = n;
    return TF_SetStatus(status, TF_OK, "");
  };
  tf_gcs_filesystem::RamFileBlockCache cache(16, 32, 0, fetcher);
  std::vector<char> out;
  TF_EXPECT_OK(ReadCache(&cache, "", 0, 16, &out));
  TF_EXPECT_OK(ReadCache(&cache, "", 0, 16, &out));
  EXPECT_EQ(calls, 1);
  cache.Flush();
  TF_EXPECT_OK(ReadCache(&cache, "", 0, 16, &out));
  EXPECT_EQ(calls, 2);
}

} // namespace
} // namespace tensorflow
@ -97,6 +97,11 @@ void TF_KernelBuilder_HostMemory(TF_KernelBuilder* kernel_builder,
  kernel_builder->cc_builder->HostMemory(arg_name);
}

void TF_KernelBuilder_Priority(TF_KernelBuilder* kernel_builder,
                               int32_t priority_number) {
  kernel_builder->cc_builder->Priority(priority_number);
}

namespace tensorflow {
namespace {


@ -107,6 +107,10 @@ TF_CAPI_EXPORT extern void TF_KernelBuilder_TypeConstraint(
TF_CAPI_EXPORT extern void TF_KernelBuilder_HostMemory(
    TF_KernelBuilder* kernel_builder, const char* arg_name);

// Specify a priority number for this kernel.
TF_CAPI_EXPORT extern void TF_KernelBuilder_Priority(
    TF_KernelBuilder* kernel_builder, int32_t priority_number);

// Register the given kernel builder with the TensorFlow runtime. If
// registration fails, the given status will be populated.
//

@ -359,6 +359,7 @@ cc_library(
        ":map_lmhlo_to_scalar_op",
        "@com_google_absl//absl/memory",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:Affine",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:LinalgOps",
        "@llvm-project//mlir:Pass",

@ -16,6 +16,7 @@ limitations under the License.
// This file implements logic for lowering HLO/LHLO dialect to Linalg dialect.

#include "absl/memory/memory.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
@ -692,7 +693,8 @@ class ConstConverter : public OpConversionPattern<lmhlo::ConstOp> {
    if (valueAttr.getType().getRank() != 0) return failure();
    auto stdConstOp =
        rewriter.create<mlir::ConstantOp>(loc, valueAttr.getValue({}));
    rewriter.create<mlir::StoreOp>(loc, stdConstOp, constOp.getOperand());
    rewriter.create<mlir::AffineStoreOp>(loc, stdConstOp, constOp.getOperand(),
                                         ValueRange());
    rewriter.eraseOp(constOp);
    return success();
  }
@ -827,7 +829,8 @@ struct LhloLegalizeToLinalg
  void runOnFunction() override {
    OwningRewritePatternList patterns;
    ConversionTarget target(getContext());
    target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();
    target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
                           AffineDialect>();

    auto func = getFunction();
    populateLHLOToLinalgConversionPattern(func.getContext(), &patterns);

@ -329,7 +329,7 @@ func @constant(%value: memref<i32>) {
  return
}
// CHECK: %[[CONSTANT:.*]] = constant 10 : i32
// CHECK: store %[[CONSTANT]], %{{.*}}[] : memref<i32>
// CHECK: affine.store %[[CONSTANT]], %{{.*}}[] : memref<i32>

// -----


@ -152,6 +152,14 @@ StatusOr<QuantizedType> GetQuantizedType(const TensorT& tensor, Builder builder,
  uint32_t flags =
      is_signed ? mlir::quant::QuantizationFlags::FlagValue::Signed : 0;

  // Rejects if quantized tensors have zero scales.
  for (float scale : quant_params.scale) {
    if (scale == 0) {
      return errors::InvalidArgument(
          "Quantized tensors must have non-zero scales");
    }
  }

  // Scale size can't be zero as it is checked before.
  if (quant_params.scale.size() != 1) {
    llvm::SmallVector<double, 4> scales(quant_params.scale.begin(),

@ -29,7 +29,7 @@ limitations under the License.
#include "mlir/Interfaces/LoopLikeInterface.h"  // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/lite/schema/schema_generated.h"

namespace mlir {

@ -1676,12 +1676,7 @@ def TFL_HardSwishOp: TFL_Op<"hard_swish", [
}

def TFL_L2NormalizationOp : TFL_Op<"l2_normalization", [NoSideEffect,
    FixedOutputRangeInterface,
    // central_value = min_value / 2 + (max_value - 1) / 2 + 1
    // zero_point = central_value
    // scale = 1. / (central_value - min_value)
    FixedResultScale<Int8UniformQuantizedType<0, 78125, -7>>,
    FixedResultScale<UInt8UniformQuantizedType<128, 78125, -7>>]> {
    FixedOutputRangeInterface]> {
  let summary = "L2 Normalize Operator";

  let description = [{
@ -1703,29 +1698,12 @@ def TFL_L2NormalizationOp : TFL_Op<"l2_normalization", [NoSideEffect,
    // FixedOutputRangeInterface:
    quant::UniformQuantizedType GetFixedOutputRange(
        bool is_signed, int bit_width) {
      auto result_type = output().getType().cast<ShapedType>();
      if (!result_type.getElementType().isa<FloatType>()) return {};
      Builder builder(result_type.getContext());

      // Only support 8-bits
      if (bit_width != 8) return {};
      IntegerType storage_type = builder.getIntegerType(bit_width);

      double scale = 1.0 / 128;
      int64_t zero_point, storage_min, storage_max;
      if (is_signed) {
        zero_point = 0;
        storage_min = -128;
        storage_max = 127;
      } else {
        zero_point = 128;
        storage_min = 0;
        storage_max = 255;
      }

      return quant::UniformQuantizedType::getChecked(
          is_signed, storage_type, result_type.getElementType(), scale,
          zero_point, storage_min, storage_max, builder.getUnknownLoc());
      auto result_type = output().getType();
      // central_value = min_value / 2 + (max_value - 1) / 2 + 1
      // zero_point = central_value
      // scale = 1. / (central_value - min_value)
      return quant::GetFixedOutputRange(is_signed, bit_width, result_type,
          /*scale=*/1.0 / 128, /*zero_point=*/0);
    }
  }];
}
@ -1834,10 +1812,6 @@ def TFL_LogisticOp: TFL_Op<"logistic", [
    PredOpTrait<"x and y must have same element type",
                TFL_TCresVTEtIsSameAsOp<0, 0>>,
    SameOperandsAndResultShape,
    // zero_point = 0
    // scale = 1. / (max_value + 1)
    FixedResultScale<Int8UniformQuantizedType<-128, 390625, -8>>,
    FixedResultScale<UInt8UniformQuantizedType<0, 390625, -8>>,
    FixedOutputRangeInterface,
    TFL_GpuTargetOp]> {
  let summary = "Logistic operator";
@ -1854,29 +1828,11 @@ def TFL_LogisticOp: TFL_Op<"logistic", [
    // FixedOutputRangeInterface:
    quant::UniformQuantizedType GetFixedOutputRange(
        bool is_signed, int bit_width) {
      auto result_type = y().getType().cast<ShapedType>();
      if (!result_type.getElementType().isa<FloatType>()) return {};
      Builder builder(result_type.getContext());

      // Only support 8-bits
      if (bit_width != 8) return {};
      IntegerType storage_type = builder.getIntegerType(bit_width);

      double scale = 1.0 / 256;
      int64_t zero_point, storage_min, storage_max;
      if (is_signed) {
        zero_point = -128;
        storage_min = -128;
        storage_max = 127;
      } else {
        zero_point = 0;
        storage_min = 0;
        storage_max = 255;
      }

      return quant::UniformQuantizedType::getChecked(
          is_signed, storage_type, result_type.getElementType(), scale,
          zero_point, storage_min, storage_max, builder.getUnknownLoc());
      auto result_type = y().getType();
      // zero_point = 0
      // scale = 1. / (max_value + 1)
      return quant::GetFixedOutputRange(is_signed, bit_width, result_type,
          /*scale=*/1.0 / 256, /*zero_point=*/-128);
    }
  }];
}
@ -1905,10 +1861,7 @@ def TFL_LogSoftmaxOp : TFL_Op<"log_softmax", [
    SameOperandsAndResultShape,
    PredOpTrait<"x and y must have same element type",
                TFL_TCresVTEtIsSameAsOp<0, 0>>,
    // zero_point = max_value
    // scale = -log_softmax_output_min / (max_value + 1)
    FixedResultScale<Int8UniformQuantizedType<127, 625, -4>>,
    FixedResultScale<UInt8UniformQuantizedType<255, 625, -4>>]> {
    FixedOutputRangeInterface]> {
  let summary = "Log softmax operator";

  let description = [{
@ -1922,6 +1875,18 @@ def TFL_LogSoftmaxOp : TFL_Op<"log_softmax", [
  let results = (outs TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$output);

  let hasOptions = 1;

  let extraClassDeclaration = [{
    // FixedOutputRangeInterface:
    quant::UniformQuantizedType GetFixedOutputRange(
        bool is_signed, int bit_width) {
      auto result_type = output().getType();
      // zero_point = max_value
      // scale = -log_softmax_output_min / (max_value + 1)
      return quant::GetFixedOutputRange(is_signed, bit_width, result_type,
          /*scale=*/16.0 / 256, /*zero_point=*/127);
    }
  }];
}

// TODO(ashwinm): Revisit the granularity of the PredOpTraits. We could
@ -2833,10 +2798,7 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [
                TFL_TCresVTEtIsSameAsOp<0, 0>>,
    TFL_OperandHasRankRange<0, 1, 4>,
    SameOperandsAndResultShape,
    // zero_point = 0
    // scale = 1. / (max_value + 1)
    FixedResultScale<Int8UniformQuantizedType<-128, 390625, -8>>,
    FixedResultScale<UInt8UniformQuantizedType<0, 390625, -8>>,
    FixedOutputRangeInterface,
    TFL_GpuTargetOp]> {
  let summary = "Softmax operator";

@ -2854,6 +2816,18 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [
  let results = (outs TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$output);

  let hasOptions = 1;

  let extraClassDeclaration = [{
    // FixedOutputRangeInterface:
    quant::UniformQuantizedType GetFixedOutputRange(
        bool is_signed, int bit_width) {
      auto result_type = output().getType();
      // zero_point = 0
      // scale = 1. / (max_value + 1)
      return quant::GetFixedOutputRange(is_signed, bit_width, result_type,
          /*scale=*/1.0 / 256, /*zero_point=*/-128);
    }
  }];
}

def TFL_SqrtOp: TFL_Op<"sqrt", [
@ -2959,11 +2933,7 @@ def TFL_TanhOp: TFL_Op<"tanh", [
    SameOperandsAndResultShape,
    PredOpTrait<"input and output must have same element type",
                TFL_TCresVTEtIsSameAsOp<0, 0>>,
    // central_value = min_value / 2 + (max_value - 1) / 2 + 1
    // zero_point = central_value
    // scale = 1. / (central_value - min_value)
    FixedResultScale<Int8UniformQuantizedType<0, 78125, -7>>,
    FixedResultScale<UInt8UniformQuantizedType<128, 78125, -7>>,
    FixedOutputRangeInterface,
    TFL_GpuTargetOp]> {
  let summary = "Hyperbolic tangent operator";

@ -2985,6 +2955,19 @@ def TFL_TanhOp: TFL_Op<"tanh", [
      state.addTypes(input.getType());
    }]>
  ];

  let extraClassDeclaration = [{
    // FixedOutputRangeInterface:
    quant::UniformQuantizedType GetFixedOutputRange(
        bool is_signed, int bit_width) {
      auto result_type = output().getType();
      // central_value = min_value / 2 + (max_value - 1) / 2 + 1
      // zero_point = central_value
      // scale = 1. / (central_value - min_value)
      return quant::GetFixedOutputRange(is_signed, bit_width, result_type,
          /*scale=*/1.0 / 128, /*zero_point=*/0);
    }
  }];
}

def TFL_TileOp: TFL_Op<"tile", [

@ -794,16 +794,18 @@ bool QuantizationDriver::PropagateParams() {
  }

  // TODO(fengliuai): make the bit width configurable.
  auto spec = GetQuantSpec(op);
  auto key = std::make_pair(8, is_signed_);
  auto &restricted_outputs = spec->restricted_output_params[key];
  for (int i = 0, e = restricted_outputs.size(); i != e; ++i) {
    // The restrict can be nullptr if the result has been quantized.
    if (auto params = restricted_outputs[i]) {
      changed |= SetResultParams(op, i, params);
  if (auto restricted = llvm::dyn_cast<FixedOutputRangeInterface>(op)) {
    // TODO(fengliuai): different result can have different fixed range.
    auto params = restricted.GetFixedOutputRange(is_signed_, /*bit_width=*/8);
    for (auto i = 0; i < op->getNumResults(); ++i) {
      // The range is null if the result has been quantized.
      if (params) {
        changed |= SetResultParams(op, i, params);
      }
    }
  }

  auto spec = GetQuantSpec(op);
  for (auto &it : spec->biases_params) {
    auto params =
        GetBiasParams(op, it.first, it.second.first, it.second.second);

@ -449,7 +449,7 @@ static bool PreferResultScale(Operation* op) {
// only considers the ops with restricted output params.
static bool IsStatsRedundant(Operation* op,
                             OpQuantSpecGetter op_quant_spec_getter) {
  return !op_quant_spec_getter(op)->restricted_output_params.empty();
  return llvm::isa<FixedOutputRangeInterface>(op);
}

bool RemoveRedundantStatsOps(mlir::FuncOp func,
@ -469,7 +469,7 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,

  // Step 1: forward pass: propagate any value scales which are not produced
  // by `SameOperandsAndResultsScale`. Additionally, remove the value scales
  // which are produced by the `restricted_output_params`.
  // which are produced by the ops with the `FixedOutputRangeInterface`.
  // Note that we don't propagate across the multiple-operands
  // `SameOperandsAndResultsScale` ops like `concatenation`.
  func.walk(
@ -594,5 +594,27 @@ LogicalResult VerifySameScales(Operation* op) {
  }
  return success();
}

quant::UniformQuantizedType GetFixedOutputRange(bool is_signed, int bit_width,
                                                Type tensor_type, double scale,
                                                int64_t zero_point,
                                                int64_t storage_min,
                                                int64_t storage_max) {
  auto result_type = tensor_type.cast<ShapedType>();
  if (!result_type.getElementType().isa<FloatType>()) return {};
  Builder builder(result_type.getContext());

  // Only support 8-bits
  if (bit_width != 8) return {};
  IntegerType storage_type = builder.getIntegerType(bit_width);
  if (!is_signed) {
    zero_point += 128;
    storage_min += 128;
    storage_max += 128;
  }
  return quant::UniformQuantizedType::getChecked(
      is_signed, storage_type, result_type.getElementType(), scale, zero_point,
      storage_min, storage_max, builder.getUnknownLoc());
}
} // namespace quant
} // namespace mlir
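
// Editorial sketch, not part of the change: the fixed scales passed to
// quant::GetFixedOutputRange at the new call sites line up with the
// FixedResultScale<...UniformQuantizedType<zero_point, mantissa, exponent>>
// traits removed from the op definitions, assuming those traits encode
// scale = mantissa * 10^exponent, and the +128 shift above maps the int8 zero
// points to their uint8 counterparts. A standalone arithmetic check:
#include <cassert>
#include <cmath>

int main() {
  // L2Normalization / Tanh: 1/128 == 78125e-7, int8 zero_point 0 -> uint8 128.
  assert(std::fabs(1.0 / 128 - 78125e-7) < 1e-12);
  // Logistic / Softmax: 1/256 == 390625e-8, int8 zero_point -128 -> uint8 0.
  assert(std::fabs(1.0 / 256 - 390625e-8) < 1e-12);
  // LogSoftmax: 16/256 == 625e-4, int8 zero_point 127 -> uint8 255.
  assert(std::fabs(16.0 / 256 - 625e-4) < 1e-12);
  return 0;
}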

@ -395,8 +395,6 @@ struct FoldTrivalRequantizeOp : public OpRewritePattern<RQ> {

    llvm::SmallVector<Type, 4> new_output_types;
    for (auto result : def->getResults()) {
      result.getUsers().begin()->dump();
      op.dump();
      if (result.hasOneUse() && *result.getUsers().begin() == op) {
        new_output_types.push_back(op.qtype());
      } else {
@ -502,6 +500,13 @@ void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed,
bool RemoveRedundantStatsOps(mlir::FuncOp func,
                             OpQuantSpecGetter op_quant_spec_getter);

// Given quantization parameters for int8, compute the quantization parameters
// for uint8 if required, and wrap the result in a UniformQuantizedType.
quant::UniformQuantizedType GetFixedOutputRange(bool is_signed, int bit_width,
                                                Type tensor_type, double scale,
                                                int64_t zero_point,
                                                int64_t storage_min = -128,
                                                int64_t storage_max = 127);
} // namespace quant
} // namespace mlir


@ -1,4 +1,4 @@
// RUN: tf-opt -tfl-prepare-composite-funcs-tf %s -split-input-file | FileCheck %s
// RUN: tf-opt -tfl-prepare-composite-funcs-tf %s -split-input-file -verify-diagnostics | FileCheck %s

module{
func @embedding(%arg0: tensor<*xf32>, %arg1: tensor<*xi32>) -> tensor<*xf32> attributes {tf._implements = "embedding_matmul", tf._reference = "mlir"} {
@ -453,3 +453,31 @@ func @inference_standard_lstm_time_major_cannot_fuse(%arg0: tensor<?x8x8xf32>, %
// CHECK: return [[VAL_11]], [[VAL_10]], [[VAL_11]], [[VAL_11]], [[VAL_12]] : tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
// CHECK: }
}

// -----

module {
  func @nms_padded(%arg0: tensor<100x4xf32>, %arg1: tensor<100xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<i1>, %arg6: tensor<i1>, %arg7: tensor<i1>, %arg8: tensor<i32>) -> (tensor<1x10xi32>, tensor<i32>) attributes {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"} {
    %0 = "tf.Const"() {value = dense<1> : tensor<1x10xi32>} : () -> tensor<1x10xi32>
    %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<i32>
    return %0, %1 : tensor<1x10xi32>, tensor<i32>
  }

// CHECK: func @nms_padded(%[[VAL_119:.*]]: tensor<100x4xf32>, %[[VAL_120:.*]]: tensor<100xf32>, %[[VAL_121:.*]]: tensor<i32>, %[[VAL_122:.*]]: tensor<f32>, %[[VAL_123:.*]]: tensor<f32>, %[[VAL_124:.*]]: tensor<i1>, %[[VAL_125:.*]]: tensor<i1>, %[[VAL_126:.*]]: tensor<i1>, %[[VAL_127:.*]]: tensor<i32>) -> (tensor<1x10xi32>, tensor<i32>) attributes {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"} {
// CHECK: %[[VAL_128:.*]], %[[VAL_129:.*]] = "tfl.non_max_suppression_v4"(%[[VAL_119]], %[[VAL_120]], %[[VAL_121]], %[[VAL_122]], %[[VAL_123]]) : (tensor<100x4xf32>, tensor<100xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<1x10xi32>, tensor<i32>)
// CHECK: return %[[VAL_128]], %[[VAL_129]] : tensor<1x10xi32>, tensor<i32>
// CHECK: }
}

// -----

module {
  // expected-error @+1 {{Invalid number of results from non_max_suppression_padded_v2}}
  func @nms_padded_invalid_num_results(%arg0: tensor<100x4xf32>, %arg1: tensor<100xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<i1>, %arg6: tensor<i1>, %arg7: tensor<i1>, %arg8: tensor<i32>) -> () attributes {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"}

  // expected-error @+1 {{Invalid number of arguments to non_max_suppression_padded_v2}}
  func @nms_padded_invalid_num_args(%arg0: tensor<100x4xf32>, %arg1: tensor<100xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>) -> (tensor<1x10xi32>, tensor<i32>) attributes {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"}

  // expected-error @+1 {{TFLite does not support batched input for non_max_suppression_padded}}
  func @nms_padded_with_batches(%arg0: tensor<2x100x4xf32>, %arg1: tensor<2x100xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<i1>, %arg6: tensor<i1>, %arg7: tensor<i1>, %arg8: tensor<i32>) -> (tensor<2x10xi32>, tensor<i32>) attributes {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"}
}

@ -188,20 +188,16 @@ StatusOr<mlir::OwningModuleRef> ImportSavedModel(
    const std::unordered_set<std::string>& tags,
    absl::Span<std::string> exported_names, mlir::MLIRContext* context) {
  if (saved_model_version == 2) {
    auto module = tensorflow::SavedModelObjectGraphToMlirImport(
    auto module_or = tensorflow::SavedModelObjectGraphToMlirImport(
        input_filename, tags, exported_names, context);
    if (!module)
      return tensorflow::errors::InvalidArgument("fail to open input file");

    return module;
    if (!module_or.status().ok()) return module_or.status();
    return module_or.ConsumeValueOrDie();
  } else if (saved_model_version == 1) {
    auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
    auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport(
        input_filename, tags, exported_names, context);

    if (!module)
      return tensorflow::errors::InvalidArgument("fail to open input file");

    return module;
    if (!module_or.status().ok()) return module_or.status();
    return module_or.ConsumeValueOrDie();
  } else {
    return tensorflow::errors::InvalidArgument(
        "Should be either saved model v1 or v2");

@ -57,6 +57,7 @@ namespace {

constexpr char kTFAPIImplements[] = "tf.api_implements";
constexpr char kTfTextAPIPRefix[] = "tftext:";
constexpr char kTfNMSPadded[] = "non_max_suppression_padded_v2";

// Abstracts the conversion of the embedded lookup composite function.
class ConvertEmbeddedLookupFunc {
@ -94,6 +95,59 @@ class ConvertEmbeddedLookupFunc {
  FuncOp func_;
};

// Abstracts the conversion of the padded NMS composite function.
class ConvertNMSPaddedFunc {
 public:
  explicit ConvertNMSPaddedFunc(FuncOp func) : func_(func) {}

  void RewriteFunc() {
    func_.setAttr(kTFImplements,
                  StringAttr::get(kTfNMSPadded, func_.getContext()));
    Value boxes = func_.getArgument(0);
    Value scores = func_.getArgument(1);
    Value max_output_size = func_.getArgument(2);
    Value iou_threshold = func_.getArgument(3);
    Value score_threshold = func_.getArgument(4);
    auto output_type0 = func_.getType().getResult(0);
    auto output_type1 = func_.getType().getResult(1);

    OpBuilder builder(func_.getBody());
    auto op = builder.create<mlir::TFL::NonMaxSuppressionV4Op>(
        func_.getLoc(), output_type0, output_type1, boxes, scores,
        max_output_size, iou_threshold, score_threshold);

    builder.create<mlir::ReturnOp>(func_.getLoc(), op.getResults());
  }

  LogicalResult VerifySignature() {
    // Verify high-level function signature.
    // Relevant argument characteristics are checked by the TFL op definition.
    if (func_.getNumArguments() < 5) {
      return func_.emitError()
             << "Invalid number of arguments to "
                "non_max_suppression_padded_v2 (need atleast 5): "
             << func_.getNumArguments();
    }
    if (func_.getType().getNumResults() != 2) {
      return func_.emitError() << "Invalid number of results from "
                                  "non_max_suppression_padded_v2 (need 2): "
                               << func_.getType().getNumResults();
    }
    // The TFLite fused op does not support batching yet.
    // TODO(b/158709815): Add support for batches with padded NMS.
    auto boxes_type =
        func_.getArgument(0).getType().dyn_cast<RankedTensorType>();
    if (!boxes_type.hasRank() || boxes_type.getRank() != 2) {
      return func_.emitError() << "TFLite does not support batched input for "
                                  "non_max_suppression_padded";
    }
    return success();
  }

 private:
  FuncOp func_;
};

// This pass uses mechanisms listed in RFC:
// https://github.com/tensorflow/community/pull/113
// It prepares composite functions that are attributed to indicate
@ -139,6 +193,14 @@ void PrepareCompositeFunctionsPass::ConvertTFImplements(FuncOp func,
    if (failed(convert_layer_norm_lstm_cell_simple.RewriteFunc())) {
      return signalPassFailure();
    }
  } else if (attr.getValue() == kTfNMSPadded) {
    func.eraseBody();
    func.addEntryBlock();
    ConvertNMSPaddedFunc convert_nms_padded(func);
    if (failed(convert_nms_padded.VerifySignature())) {
      return signalPassFailure();
    }
    convert_nms_padded.RewriteFunc();
  }
}


@ -188,12 +188,12 @@ Status MlirV1CompatGraphOptimizationPass::Run(

  if (!is_enabled) {
    VLOG(0) << "None of the MLIR optimization passes are enabled "
            << "(registered" << registry_->passes().size() << " passes)";
            << "(registered " << registry_->passes().size() << " passes)";
    return Status::OK();
  }

  VLOG(0) << "Running MLIR Graph Optimization V1 Compat Passes "
          << "(registered" << registry_->passes().size() << " passes)";
          << "(registered " << registry_->passes().size() << " passes)";

  GraphDebugInfo debug_info;
  RegisterDialects();

@ -22,6 +22,7 @@ tf_python_pybind_extension(
        "//tensorflow/python:pybind11_status",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Parser",
        "@llvm-project//mlir:StandardOps",
        "@pybind11",
    ],

@ -15,7 +15,11 @@ limitations under the License.

#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"

#include "llvm/Support/SourceMgr.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Verifier.h"  // from @llvm-project
#include "mlir/Parser.h"  // from @llvm-project
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
@ -29,6 +33,21 @@ PYBIND11_MODULE(mlir_wrapper, m) {
    mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
    mlir::registerDialect<mlir::StandardOpsDialect>();
  });
  m.def("verify", [](std::string input) {
    llvm::SourceMgr SM = llvm::SourceMgr();
    SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(input),
                          llvm::SMLoc());
    mlir::MLIRContext ctx;
    auto module = mlir::parseSourceFile(SM, &ctx);
    if (!module) {
      return false;
    }
    if (failed(mlir::verify(*module))) {
      module->emitError("Invalid MLIR module: failed verification.");
      return false;
    }
    return true;
  });

  init_basic_classes(m);
  init_types(m);

@ -1194,6 +1194,7 @@ cc_library(
        "//tensorflow/core:ops",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/grappler/utils:transitive_fanin",
        "//tensorflow/stream_executor/lib",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",

@ -3540,7 +3540,7 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors.
  }];
}

def TF_FusedBatchNormV3Op : TF_Op<"FusedBatchNormV3", [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> {
class TF_FusedBatchNormOpBase<string Name> : TF_Op<Name, [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> {
  let summary = "Batch normalization.";

  let description = [{
@ -3561,15 +3561,6 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors.
    DefaultValuedAttr<BoolAttr, "true">:$is_training
  );

  let results = (outs
    TensorOf<[BF16, F16, F32]>:$y,
    F32Tensor:$batch_mean,
    F32Tensor:$batch_variance,
    F32Tensor:$reserve_space_1,
    F32Tensor:$reserve_space_2,
    F32Tensor:$reserve_space_3
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
  TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>;

@ -3585,6 +3576,27 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors.
  }];
}

def TF_FusedBatchNormV2Op : TF_FusedBatchNormOpBase<"FusedBatchNormV2"> {
  let results = (outs
    TensorOf<[BF16, F16, F32]>:$y,
    F32Tensor:$batch_mean,
    F32Tensor:$batch_variance,
    F32Tensor:$reserve_space_1,
    F32Tensor:$reserve_space_2
  );
}

def TF_FusedBatchNormV3Op : TF_FusedBatchNormOpBase<"FusedBatchNormV3"> {
  let results = (outs
    TensorOf<[BF16, F16, F32]>:$y,
    F32Tensor:$batch_mean,
    F32Tensor:$batch_variance,
    F32Tensor:$reserve_space_1,
    F32Tensor:$reserve_space_2,
    F32Tensor:$reserve_space_3
  );
}

def TF_GatherOp : TF_Op<"Gather", [NoSideEffect]> {
  let summary = "Gather slices from `params` according to `indices`.";

@ -8266,6 +8278,39 @@ def TF_SelectV2Op : TF_Op<"SelectV2", [NoSideEffect, ResultsBroadcastableShape]>
  ];
}

def TF_SelfAdjointEigV2Op : TF_Op<"SelfAdjointEigV2", [NoSideEffect]> {
  let summary = [{
Computes the eigen decomposition of one or more square self-adjoint matrices.
  }];

  let description = [{
Computes the eigenvalues and (optionally) eigenvectors of each inner matrix in
`input` such that `input[..., :, :] = v[..., :, :] * diag(e[..., :])`. The eigenvalues
are sorted in non-decreasing order.

```python
# a is a tensor.
# e is a tensor of eigenvalues.
# v is a tensor of eigenvectors.
e, v = self_adjoint_eig(a)
e = self_adjoint_eig(a, compute_v=False)
```
  }];

  let arguments = (ins
    TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$input,

    DefaultValuedAttr<BoolAttr, "true">:$compute_v
  );

  let results = (outs
    TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$e,
    TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$v
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF_SeluOp : TF_Op<"Selu", [NoSideEffect, SameOperandsAndResultType]> {
  let summary = [{
Computes scaled exponential linear: `scale * alpha * (exp(features) - 1)`

@ -1977,13 +1977,55 @@ static LogicalResult Verify(FusedBatchNormOp op) {
  return success();
}

LogicalResult FusedBatchNormV3Op::FoldOperandsPermutation(
    ArrayRef<int64_t> permutation) {
//===----------------------------------------------------------------------===//
// FusedBatchNormV2Op / FusedBatchNormV3Op
//===----------------------------------------------------------------------===//

template <class Op>
static LogicalResult InferenceFoldOperandsPermutation(
    ArrayRef<int64_t> permutation, Op *op) {
  // FusedBatchNorm in training mode is a layout sensitive operation, and should
  // have already assigned an optimal data format.
  if (is_training()) return failure();
  if (op->is_training()) return failure();
  return ::mlir::TF::FoldOperandsPermutation(permutation, op);
}

  return ::mlir::TF::FoldOperandsPermutation(permutation, this);
template <class Op>
static StringRef GetOptimalLayout(const RuntimeDevices &devices, Op *op) {
  // In inference mode FusedBatchNorm is not sensitive to data layout.
  if (!op->is_training()) return op->data_format();

  // Keep current data format if no GPUs are available or if explicit placement
  // does not allow to use GPU for this operation.
  if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(op->getOperation()))
    return op->data_format();

  // For f16 data type on devices with Tensor Cores support NHWC data format
  // is up to ~2x faster.
  auto x_ty = op->x().getType().template cast<TensorType>();
  const bool is_f16 = x_ty.getElementType().isF16();
  if (is_f16 && CanUseTensorCores(devices)) return "NHWC";

  // For all other data types prefer NCHW.
  return "NCHW";
}

LogicalResult FusedBatchNormV2Op::FoldOperandsPermutation(
    ArrayRef<int64_t> permutation) {
  return ::mlir::TF::InferenceFoldOperandsPermutation(permutation, this);
}

LogicalResult FusedBatchNormV2Op::UpdateDataFormat(StringRef data_format) {
  return ::mlir::TF::UpdateDataFormat(data_format, this);
}

StringRef FusedBatchNormV2Op::GetOptimalLayout(const RuntimeDevices &devices) {
  return ::mlir::TF::GetOptimalLayout(devices, this);
}

LogicalResult FusedBatchNormV3Op::FoldOperandsPermutation(
    ArrayRef<int64_t> permutation) {
  return ::mlir::TF::InferenceFoldOperandsPermutation(permutation, this);
}

LogicalResult FusedBatchNormV3Op::UpdateDataFormat(StringRef data_format) {
@ -1991,22 +2033,7 @@ LogicalResult FusedBatchNormV3Op::UpdateDataFormat(StringRef data_format) {
}

StringRef FusedBatchNormV3Op::GetOptimalLayout(const RuntimeDevices &devices) {
  // In inference mode FusedBatchNorm is not sensitive to data layout.
  if (!is_training()) return data_format();

  // Keep current data format if no GPUs are available or if explicit placement
  // does not allow to use GPU for this operation.
  if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
    return data_format();

  // For f16 data type on devices with Tensor Cores support NHWC data format
  // is up to ~2x faster.
  auto x_ty = x().getType().cast<TensorType>();
  const bool is_f16 = x_ty.getElementType().isF16();
  if (is_f16 && CanUseTensorCores(devices)) return "NHWC";

  // For all other data types prefer NCHW.
  return "NCHW";
  return ::mlir::TF::GetOptimalLayout(devices, this);
}

//===----------------------------------------------------------------------===//

@ -2156,10 +2183,6 @@ static LogicalResult Verify(IfRegionOp op) {
    return failure();
  if (failed(VerifyRegionResults(op, op.else_branch(), "else")))
    return failure();
  if (op.then_branch().front().getNumArguments() != 0)
    return op.emitOpError() << "then region cannot have any arguments";
  if (op.else_branch().front().getNumArguments() != 0)
    return op.emitOpError() << "else region cannot have any arguments";
  return success();
}


@ -247,7 +247,7 @@ def TF_YieldOp : TF_Op<"Yield",
}

def TF_IfRegionOp : TF_Op<"IfRegion",
      [SingleBlockImplicitTerminator<"YieldOp">]> {
      [SingleBlockImplicitTerminator<"YieldOp">, NoRegionArguments]> {
  let summary = "output = cond ? then_branch output : else_branch output";

  let description = [{
@ -524,8 +524,8 @@ are sparse matrices.
  }];

  let arguments = (ins
    TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$a,
    TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$b,
    TensorOf<[BF16, F32]>:$a,
    TensorOf<[BF16, F32]>:$b,

    DefaultValuedAttr<BoolAttr, "true">:$a_is_sparse,
    DefaultValuedAttr<BoolAttr, "false">:$b_is_sparse,
@ -535,7 +535,7 @@ are sparse matrices.
  );

  let results = (outs
    TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$product
    TensorOf<[F32]>:$product
  );

  TF_DerivedOperandTypeAttr Ta = TF_DerivedOperandTypeAttr<0>;

@ -443,7 +443,7 @@ func @DontRemoveTrivialMul(%arg0: tensor<1x6x8x1xf32>) -> tensor<1x6x8x1xf32> {
|
||||
// CHECK: return %[[RESULT]] : tensor<1x6x8x1xf32>
|
||||
}
|
||||
|
||||
// Do not fold if total result size is large (>128 KB) and more than 2 times
|
||||
// Do not fold if total result size is large (>256 KB) and more than 2 times
|
||||
// the size of operands.
|
||||
|
||||
// LINT.IfChange(folding-policy-test)
|
||||
|
||||
@ -1055,7 +1055,7 @@ func @testIfRegionThenConsumingElse(%arg0: tensor<i1>, %arg1: tensor<2xf32>) ->
|
||||
// The regions for IfRegion themselves cannot have any arguments
|
||||
func @testInvalidIfRegionThenArg(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||
%neg = "tf.Neg"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
// expected-error @+1 {{then region cannot have any arguments}}
|
||||
// expected-error @+1 {{'tf.IfRegion' op region #0 should have no arguments}}
|
||||
%0 = "tf.IfRegion"(%arg0) ({
|
||||
^bb(%arg_bb: tensor<2xf32>):
|
||||
%t = "tf.Abs"(%arg_bb) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
@ -1072,7 +1072,7 @@ func @testInvalidIfRegionThenArg(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> ten
|
||||
|
||||
func @testInvalidIfRegionElseArg(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||
%neg = "tf.Neg"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
// expected-error @+1 {{else region cannot have any arguments}}
|
||||
// expected-error @+1 {{'tf.IfRegion' op region #1 should have no arguments}}
|
||||
%0 = "tf.IfRegion"(%arg0) ({
|
||||
%t = "tf.Abs"(%neg) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
|
||||
|
||||
@ -40,7 +40,7 @@ namespace TF {
|
||||
// LINT.IfChange(folding-policy)
|
||||
static bool ShouldBeFolded(Operation* inst) {
|
||||
constexpr int kSizeFactor = 2;
|
||||
constexpr int64_t kSizeThreshold = (1 << 20); // 128 KB
|
||||
constexpr int64_t kSizeThreshold = (1 << 21); // 256 KB
|
||||
bool has_unknown_shape = false;
|
||||
auto get_size = [&](TypeRange types) {
|
||||
int64_t size = 0;
|
||||
|
||||
@ -40,9 +40,6 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using stream_executor::port::Status;
|
||||
using stream_executor::port::StatusOr;
|
||||
|
||||
static StatusOr<mlir::OwningModuleRef> GraphdefToMlirImport(
|
||||
llvm::StringRef input, absl::string_view debug_info_file,
|
||||
absl::string_view input_arrays, absl::string_view input_dtypes,
|
||||
@ -98,7 +95,7 @@ static StatusOr<mlir::OwningModuleRef> GraphdefToMlirImport(
|
||||
context);
|
||||
}
|
||||
|
||||
mlir::OwningModuleRef GraphdefToMlirTranslateFunction(
|
||||
StatusOr<mlir::OwningModuleRef> GraphdefToMlirTranslateFunction(
|
||||
llvm::StringRef input, absl::string_view debug_info_file,
|
||||
absl::string_view input_arrays, absl::string_view input_dtypes,
|
||||
absl::string_view input_shapes, absl::string_view output_arrays,
|
||||
@ -112,13 +109,11 @@ mlir::OwningModuleRef GraphdefToMlirTranslateFunction(
|
||||
enable_shape_inference, context);
|
||||
if (!module_or.status().ok()) {
|
||||
LOG(ERROR) << "Graph import failed: " << module_or.status();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return module_or.ConsumeValueOrDie();
|
||||
return module_or;
|
||||
}
|
||||
|
||||
mlir::OwningModuleRef SavedModelObjectGraphToMlirImport(
|
||||
StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphToMlirImport(
|
||||
absl::string_view saved_model_dir,
|
||||
const std::unordered_set<std::string>& tags,
|
||||
absl::Span<std::string> exported_names, mlir::MLIRContext* context) {
|
||||
@ -128,18 +123,17 @@ mlir::OwningModuleRef SavedModelObjectGraphToMlirImport(
|
||||
if (!load_status.ok()) {
|
||||
LOG(ERROR) << "Failed to load saved model '" << saved_model_dir
|
||||
<< "': " << load_status;
|
||||
return nullptr;
|
||||
return load_status;
|
||||
}
|
||||
|
||||
auto module_or = ConvertSavedModelToMlir(&bundle, context, exported_names);
|
||||
if (!module_or.status().ok()) {
|
||||
LOG(ERROR) << "SavedModel import failed: " << module_or.status();
|
||||
return nullptr;
|
||||
}
|
||||
return module_or.ConsumeValueOrDie();
|
||||
return module_or;
|
||||
}
|
||||
|
||||
mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
|
||||
StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefsToMlirImport(
|
||||
absl::string_view saved_model_dir,
|
||||
const std::unordered_set<std::string>& tags,
|
||||
absl::Span<std::string> exported_names, mlir::MLIRContext* context,
|
||||
@ -154,19 +148,18 @@ mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
|
||||
if (!load_status.ok()) {
|
||||
LOG(ERROR) << "Failed to load saved model v1 '" << saved_model_dir
|
||||
<< "': " << load_status;
|
||||
return nullptr;
|
||||
return load_status;
|
||||
}
|
||||
|
||||
auto module_or = ConvertSavedModelV1ToMlir(bundle, exported_names, context,
|
||||
upgrade_legacy);
|
||||
if (!module_or.status().ok()) {
|
||||
LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status();
|
||||
return nullptr;
|
||||
}
|
||||
return module_or.ConsumeValueOrDie();
|
||||
return module_or;
|
||||
}
|
||||
|
||||
mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
|
||||
StatusOr<mlir::OwningModuleRef> GraphdefToSplattedMlirTranslateFunction(
|
||||
llvm::StringRef input, absl::string_view debug_info_file,
|
||||
absl::string_view input_arrays, absl::string_view input_dtypes,
|
||||
absl::string_view input_shapes, absl::string_view output_arrays,
|
||||
@ -180,7 +173,7 @@ mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
|
||||
enable_shape_inference, context);
|
||||
if (!module_or.status().ok()) {
|
||||
LOG(ERROR) << "Graph import failed: " << module_or.status();
|
||||
return nullptr;
|
||||
return module_or.status();
|
||||
}
|
||||
auto& module = module_or.ValueOrDie();
|
||||
std::srand(0);
|
||||
@ -215,7 +208,7 @@ mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
|
||||
}
|
||||
}
|
||||
}
|
||||
return module_or.ConsumeValueOrDie();
|
||||
return module_or;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
@ -23,15 +23,20 @@ limitations under the License.
|
||||
#include "absl/types/span.h"
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using stream_executor::port::Status;
|
||||
using stream_executor::port::StatusOr;
|
||||
|
||||
// TODO(antiagainst): Directly manipulating files in library functions is not
|
||||
// a good idea. We should pass in a string/stream here.
|
||||
|
||||
// Converts a TensorFlow GraphDef stored in the file with the given
|
||||
// `input_filename` into a MLIR module. Creates MLIR entities into the
|
||||
// given MLIR `context`.
|
||||
mlir::OwningModuleRef GraphdefToMlirTranslateFunction(
|
||||
StatusOr<mlir::OwningModuleRef> GraphdefToMlirTranslateFunction(
|
||||
llvm::StringRef input, absl::string_view debug_info_file,
|
||||
absl::string_view input_arrays, absl::string_view input_dtypes,
|
||||
absl::string_view input_shapes, absl::string_view output_arrays,
|
||||
@ -42,7 +47,7 @@ mlir::OwningModuleRef GraphdefToMlirTranslateFunction(
|
||||
|
||||
// Similar as the above function, but replaces all constant tensors
|
||||
// with randomly generated splat values.
|
||||
mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
|
||||
StatusOr<mlir::OwningModuleRef> GraphdefToSplattedMlirTranslateFunction(
|
||||
llvm::StringRef input, absl::string_view debug_info_file,
|
||||
absl::string_view input_arrays, absl::string_view input_dtypes,
|
||||
absl::string_view input_shapes, absl::string_view output_arrays,
|
||||
@ -54,7 +59,7 @@ mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
|
||||
// Converts a TensorFlow SavedModel stored in the directory with the given
|
||||
// `saved_model_dir` into a MLIR module. Creates MLIR entities into the
|
||||
// given MLIR `context`.
|
||||
mlir::OwningModuleRef SavedModelObjectGraphToMlirImport(
|
||||
StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphToMlirImport(
|
||||
absl::string_view saved_model_dir,
|
||||
const std::unordered_set<std::string>& tags,
|
||||
absl::Span<std::string> exported_names, mlir::MLIRContext* context);
|
||||
@ -62,7 +67,7 @@ mlir::OwningModuleRef SavedModelObjectGraphToMlirImport(
|
||||
// Converts a TensorFlow V1 SavedModel stored in the directory with the given
|
||||
// `saved_model_dir` into a MLIR module. Creates MLIR entities into the
|
||||
// given MLIR `context`.
|
||||
mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
|
||||
StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefsToMlirImport(
|
||||
absl::string_view saved_model_dir,
|
||||
const std::unordered_set<std::string>& tags,
|
||||
absl::Span<std::string> exported_names, mlir::MLIRContext* context,
|
||||
|
||||
@ -42,11 +42,13 @@ inline absl::string_view StringRefToView(llvm::StringRef ref) {
|
||||
|
||||
static OwningModuleRef GraphdefToMlirTranslateFunction(llvm::StringRef input,
|
||||
MLIRContext* context) {
|
||||
return tensorflow::GraphdefToMlirTranslateFunction(
|
||||
auto module_or = tensorflow::GraphdefToMlirTranslateFunction(
|
||||
input, debug_info_file, input_arrays, input_dtypes, input_shapes,
|
||||
output_arrays, control_output_arrays, prune_unused_nodes,
|
||||
convert_legacy_fed_inputs, graph_as_function, upgrade_legacy,
|
||||
enable_shape_inference, context);
|
||||
if (!module_or.status().ok()) return nullptr;
|
||||
return module_or.ConsumeValueOrDie();
|
||||
}
|
||||
|
||||
static TranslateToMLIRRegistration GraphdefToMlirTranslate(
|
||||
@ -54,11 +56,13 @@ static TranslateToMLIRRegistration GraphdefToMlirTranslate(
|
||||
|
||||
static OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
|
||||
llvm::StringRef input, MLIRContext* context) {
|
||||
return tensorflow::GraphdefToSplattedMlirTranslateFunction(
|
||||
auto module_or = tensorflow::GraphdefToSplattedMlirTranslateFunction(
|
||||
input, debug_info_file, input_arrays, input_dtypes, input_shapes,
|
||||
output_arrays, control_output_arrays, prune_unused_nodes,
|
||||
convert_legacy_fed_inputs, graph_as_function, upgrade_legacy,
|
||||
enable_shape_inference, context);
|
||||
if (!module_or.status().ok()) return nullptr;
|
||||
return module_or.ConsumeValueOrDie();
|
||||
}
|
||||
|
||||
static TranslateToMLIRRegistration GraphdefToSplattedMlirTranslate(
|
||||
|
||||
@ -112,19 +112,19 @@ int main(int argc, char** argv) {
|
||||
if (import_saved_model_object_graph) {
|
||||
mlir::MLIRContext context;
|
||||
|
||||
auto module = tensorflow::SavedModelObjectGraphToMlirImport(
|
||||
auto module_or = tensorflow::SavedModelObjectGraphToMlirImport(
|
||||
input_filename, tags, exported_names, &context);
|
||||
if (!module) return 1;
|
||||
if (!module_or.status().ok()) return 1;
|
||||
|
||||
module->print(output->os());
|
||||
module_or.ConsumeValueOrDie()->print(output->os());
|
||||
} else if (import_saved_model_signature_defs) {
|
||||
mlir::MLIRContext context;
|
||||
|
||||
auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
|
||||
auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport(
|
||||
input_filename, tags, exported_names, &context);
|
||||
if (!module) return 1;
|
||||
if (!module_or.status().ok()) return 1;
|
||||
|
||||
module->print(output->os());
|
||||
module_or.ConsumeValueOrDie()->print(output->os());
|
||||
} else {
|
||||
auto input = mlir::openInputFile(input_filename, &error_message);
|
||||
|
||||
|
||||
@ -129,20 +129,18 @@ StatusOr<mlir::OwningModuleRef> ImportSavedModel(
|
||||
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
|
||||
absl::Span<std::string> exported_names(exported_names_in_vector);
|
||||
if (import_saved_model) {
|
||||
auto module = tensorflow::SavedModelObjectGraphToMlirImport(
|
||||
auto module_or = tensorflow::SavedModelObjectGraphToMlirImport(
|
||||
input_filename, tags, absl::Span<std::string>(exported_names), context);
|
||||
if (!module)
|
||||
return tensorflow::errors::InvalidArgument("fail to open input file");
|
||||
if (!module_or.status().ok()) return module_or.status();
|
||||
TF_RETURN_IF_ERROR(RegisterCustomOps(extra_tf_opdefs));
|
||||
return module;
|
||||
return module_or.ConsumeValueOrDie();
|
||||
} else if (import_saved_model_v1) {
|
||||
auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
|
||||
auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport(
|
||||
input_filename, tags, exported_names, context);
|
||||
|
||||
if (!module)
|
||||
return tensorflow::errors::InvalidArgument("fail to open input file");
|
||||
if (!module_or.status().ok()) return module_or.status();
|
||||
TF_RETURN_IF_ERROR(RegisterCustomOps(extra_tf_opdefs));
|
||||
return module;
|
||||
return module_or.ConsumeValueOrDie();
|
||||
} else {
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
"Should be either saved model v1 or v2");
|
||||
|
||||
@ -26,6 +26,28 @@ func @fusedBatchNorm_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>,
|
||||
return %0#0 : tensor<8x8x8x8xf32>
|
||||
}
|
||||
|
||||
// fusedBatchNormV2 is almost identical to fusedBatchNormV3 (and uses the same
|
||||
// code), so only do a couple of basic checks.
|
||||
|
||||
// CHECK-LABEL: fusedBatchNormV2_noTraining
|
||||
func @fusedBatchNormV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
|
||||
// CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
|
||||
%0:5 = "tf.FusedBatchNormV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
|
||||
return %0#0 : tensor<8x8x8x8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fusedBatchNormV2_training
|
||||
func @fusedBatchNormV2_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
|
||||
// CHECK: %[[RESULT0:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
|
||||
%0:5 = "tf.FusedBatchNormV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
|
||||
// CHECK: "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
|
||||
// CHECK: "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
|
||||
// CHECK: %[[VAR:.*]] = "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
|
||||
// CHECK: mhlo.constant
|
||||
// CHECK: chlo.broadcast_multiply %[[VAR]], {{.*}} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
|
||||
return %0#0 : tensor<8x8x8x8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fusedBatchNormV3_noTraining
|
||||
func @fusedBatchNormV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
|
||||
// CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
|
||||
@ -956,6 +978,41 @@ func @test_sparse_mat_mul(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) -> ten
|
||||
return %0: tensor<3x5xf32>
|
||||
}
|
||||
|
||||
// SparseMatMul where one operand needs to be transposed and the other one not.
|
||||
//
|
||||
// CHECK-LABEL: func @test_sparse_mat_mul_with_transpose
|
||||
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x4xf32>
|
||||
// CHECK-SAME: %[[ARG1:.*]]: tensor<5x4xf32>
|
||||
// CHECK-SAME: -> tensor<3x5xf32>
|
||||
// CHECK: %[[TRANSPOSE:.*]] = "mhlo.transpose"(%[[ARG1]])
|
||||
// CHECK-SAME: permutation = dense<[1, 0]>
|
||||
// CHECK-SAME: -> tensor<4x5xf32>
|
||||
// CHECK: %[[RESULT:.*]] = "mhlo.dot"(%[[ARG0]], %[[TRANSPOSE]])
|
||||
// CHECK-SAME: -> tensor<3x5xf32>
|
||||
// CHECK: return %[[RESULT]]
|
||||
// CHECK: }
|
||||
func @test_sparse_mat_mul_with_transpose(%arg0: tensor<3x4xf32>, %arg1: tensor<5x4xf32>) -> tensor<3x5xf32> {
|
||||
%0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = true} : (tensor<3x4xf32>, tensor<5x4xf32>) -> tensor<3x5xf32>
|
||||
return %0: tensor<3x5xf32>
|
||||
}
|
||||
|
||||
// SparseMatMul where one operand needs to be casted and the other one not.
|
||||
//
|
||||
// CHECK-LABEL: func @test_sparse_mat_mul_with_cast
|
||||
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x4xf32>
|
||||
// CHECK-SAME: %[[ARG1:.*]]: tensor<4x5xbf16>
|
||||
// CHECK-SAME: -> tensor<3x5xf32>
|
||||
// CHECK: %[[CAST:.*]] = "mhlo.convert"(%[[ARG1]])
|
||||
// CHECK-SAME: -> tensor<4x5xf32>
|
||||
// CHECK: %[[RESULT:.*]] = "mhlo.dot"(%[[ARG0]], %[[CAST]])
|
||||
// CHECK-SAME: -> tensor<3x5xf32>
|
||||
// CHECK: return %[[RESULT]]
|
||||
// CHECK: }
|
||||
func @test_sparse_mat_mul_with_cast(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xbf16>) -> tensor<3x5xf32> {
|
||||
%0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = false} : (tensor<3x4xf32>, tensor<4x5xbf16>) -> tensor<3x5xf32>
|
||||
return %0: tensor<3x5xf32>
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MatrixBandPart op legalizations.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@ -1531,23 +1531,23 @@ using ConvertFusedBatchNormGradV3Op =
|
||||
// Converts TensorFlow FusedBatchNormV3Op to either HLO BatchNormTrainingOp or
|
||||
// HLO BatchNormInferenceOp, depending on the value of the 'is_training'
|
||||
// parameter.
|
||||
class ConvertFusedBatchNormV3Op
|
||||
: public OpRewritePattern<TF::FusedBatchNormV3Op> {
|
||||
template <typename FusedBatchNormOpT>
|
||||
class ConvertFusedBatchNormBase : public OpRewritePattern<FusedBatchNormOpT> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
using OpRewritePattern<FusedBatchNormOpT>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(TF::FusedBatchNormV3Op op,
|
||||
LogicalResult matchAndRewrite(FusedBatchNormOpT op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto feature_dim =
|
||||
getFeatureDimensionAttr(rewriter, op.data_formatAttr(), op.x());
|
||||
|
||||
auto input_type_tensor = op.x().getType().cast<TensorType>();
|
||||
auto input_type_tensor = op.x().getType().template cast<TensorType>();
|
||||
auto input_element_type = input_type_tensor.getElementType();
|
||||
|
||||
auto scale_type_tensor = op.scale().getType().cast<TensorType>();
|
||||
auto scale_type_tensor = op.scale().getType().template cast<TensorType>();
|
||||
auto scale_element_type = scale_type_tensor.getElementType();
|
||||
|
||||
auto mean_type_tensor = op.mean().getType().cast<TensorType>();
|
||||
auto mean_type_tensor = op.mean().getType().template cast<TensorType>();
|
||||
auto mean_element_type = mean_type_tensor.getElementType();
|
||||
// In the training case, dimensions of input tensors must be static.
|
||||
if (op.is_training() && (!input_type_tensor.hasStaticShape() ||
|
||||
@ -1561,7 +1561,7 @@ class ConvertFusedBatchNormV3Op
|
||||
Value bn_train_input = rewriter.create<mhlo::ConvertOp>(op.getLoc(), op.x(),
|
||||
scale_element_type);
|
||||
TensorType bn_train_input_type_tensor =
|
||||
bn_train_input.getType().cast<TensorType>();
|
||||
bn_train_input.getType().template cast<TensorType>();
|
||||
|
||||
if (op.is_training()) {
|
||||
// Training case.
|
||||
@ -1643,17 +1643,25 @@ class ConvertFusedBatchNormV3Op
|
||||
/*broadcast_dimensions=*/DenseIntElementsAttr());
|
||||
}
|
||||
|
||||
// TF FusedBatchNormV3 op expects 5 outputs. Outputs 3 and 4 are
|
||||
// currently marked as "reserved spaces 1 and 2". They are used to
|
||||
// pass the per-batch mean and variance to the gradiant. Here we
|
||||
// maintain the same behavior by setting them to the mean and variance
|
||||
// calculated by BatchNormTraining. Output 5 is unused; it doesn't
|
||||
// matter what we pass there.
|
||||
rewriter.replaceOp(op, {y_out, /*batch_mean=*/batch_mean,
|
||||
/*batch_variance=*/corrected_variance,
|
||||
/*reserve_space_1=*/reserve_space_1,
|
||||
/*reserve_space_2=*/batch_variance,
|
||||
/*reserve_space_3=*/op.x()});
|
||||
if (std::is_same<FusedBatchNormOpT, TF::FusedBatchNormV2Op>::value) {
|
||||
// FusedBatchNormV2 expects 4 outputs.
|
||||
// Outputs 3 and 4 are currently marked as "reserved spaces 1 and 2".
|
||||
// They are used to pass the per-batch mean and variance to the
|
||||
// gradiant. Here we maintain the same behavior by setting them to the
|
||||
// mean and variance calculated by BatchNormTraining.
|
||||
rewriter.replaceOp(op, {y_out, /*batch_mean=*/batch_mean,
|
||||
/*batch_variance=*/corrected_variance,
|
||||
/*reserve_space_1=*/reserve_space_1,
|
||||
/*reserve_space_2=*/batch_variance});
|
||||
} else { // TF::FusedBatchNormV3Op
|
||||
// FusedBatchNormV3 expects a 5th output, but the output is unused; it
|
||||
// doesn't matter what we pass there.
|
||||
rewriter.replaceOp(op, {y_out, /*batch_mean=*/batch_mean,
|
||||
/*batch_variance=*/corrected_variance,
|
||||
/*reserve_space_1=*/reserve_space_1,
|
||||
/*reserve_space_2=*/batch_variance,
|
||||
/*reserve_space_3=*/op.x()});
|
||||
}
|
||||
} else { // Inference case.
|
||||
auto bn_train_op = rewriter.create<BatchNormInferenceOp>(
|
||||
op.getLoc(),
|
||||
@ -1670,31 +1678,45 @@ class ConvertFusedBatchNormV3Op
|
||||
// not used for inference. It doesn't matter what values we provide for
|
||||
// the last 5 results as long as they are of the same type. Forward
|
||||
// input mean and variance to output mean, variance, reserved_space_1 and
|
||||
// reserver_space_2. Create a constant tensor to forward to last
|
||||
// reserve_space_3 output.
|
||||
auto reserve_space_3_type = op.getResult(5).getType().cast<TensorType>();
|
||||
int num_elements = reserve_space_3_type.hasStaticShape()
|
||||
? reserve_space_3_type.getNumElements()
|
||||
: 0;
|
||||
auto const_attr_type = RankedTensorType::get(
|
||||
{num_elements}, getElementTypeOrSelf(reserve_space_3_type));
|
||||
|
||||
Value dummy_const = rewriter.create<ConstOp>(
|
||||
op.getLoc(), DenseElementsAttr::get<float>(const_attr_type, 0.0));
|
||||
if (const_attr_type != reserve_space_3_type)
|
||||
dummy_const = rewriter.create<TensorCastOp>(
|
||||
op.getLoc(), reserve_space_3_type, dummy_const);
|
||||
rewriter.replaceOp(op, {/*y=*/y_out,
|
||||
/*batch_mean=*/op.mean(),
|
||||
/*batch_variance=*/op.variance(),
|
||||
/*reserve_space_1=*/op.mean(),
|
||||
/*reserve_space_2=*/op.variance(),
|
||||
/*reserve_space_3=*/dummy_const});
|
||||
// reserved_space_2.
|
||||
if (std::is_same<FusedBatchNormOpT, TF::FusedBatchNormV2Op>::value) {
|
||||
rewriter.replaceOp(op, {/*y=*/y_out,
|
||||
/*batch_mean=*/op.mean(),
|
||||
/*batch_variance=*/op.variance(),
|
||||
/*reserve_space_1=*/op.mean(),
|
||||
/*reserve_space_2=*/op.variance()});
|
||||
} else {
|
||||
// For FusedBatchNormV3Op, also create a constant tensor to forward to
|
||||
// last reserve_space_3 output.
|
||||
auto reserve_space_3_type =
|
||||
op.getResult(5).getType().template cast<TensorType>();
|
||||
int num_elements = reserve_space_3_type.hasStaticShape()
|
||||
? reserve_space_3_type.getNumElements()
|
||||
: 0;
|
||||
auto const_attr_type = RankedTensorType::get(
|
||||
{num_elements}, getElementTypeOrSelf(reserve_space_3_type));
|
||||
Value dummy_const = rewriter.create<ConstOp>(
|
||||
op.getLoc(), DenseElementsAttr::get<float>(const_attr_type, 0.0));
|
||||
if (const_attr_type != reserve_space_3_type)
|
||||
dummy_const = rewriter.create<TensorCastOp>(
|
||||
op.getLoc(), reserve_space_3_type, dummy_const);
|
||||
rewriter.replaceOp(op, {/*y=*/y_out,
|
||||
/*batch_mean=*/op.mean(),
|
||||
/*batch_variance=*/op.variance(),
|
||||
/*reserve_space_1=*/op.mean(),
|
||||
/*reserve_space_2=*/op.variance(),
|
||||
/*reserve_space_3=*/dummy_const});
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
using ConvertFusedBatchNormV2Op =
|
||||
ConvertFusedBatchNormBase<TF::FusedBatchNormV2Op>;
|
||||
using ConvertFusedBatchNormV3Op =
|
||||
ConvertFusedBatchNormBase<TF::FusedBatchNormV3Op>;
|
||||
|
||||
using PaddingArray =
|
||||
std::vector<std::pair<tensorflow::int64, tensorflow::int64>>;
|
||||
|
||||
@ -5358,6 +5380,50 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
|
||||
}
|
||||
};
|
||||
|
||||
// Converts `TF::SparseMatMulOp` to `TF::MatMulOp`, ignoring the sparseness
|
||||
// hints, since we currently don't have an implementation that can use this
|
||||
// information. Adds appropriate casts where necessary to align element types
|
||||
// of operands and result for `TF::MatMulOp`.
|
||||
class ConvertSparseMatMulOp : public OpRewritePattern<TF::SparseMatMulOp> {
|
||||
public:
|
||||
using OpRewritePattern<TF::SparseMatMulOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(TF::SparseMatMulOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Result type must be f32 for applying the pattern (currently this is
|
||||
// required by the op anyway but this might change).
|
||||
if (!op.product().getType().cast<TensorType>().getElementType().isF32()) {
|
||||
return failure();
|
||||
}
|
||||
MLIRContext *context = rewriter.getContext();
|
||||
llvm::SmallVector<Value, 2> operands{op.a(), op.b()};
|
||||
for (Value &operand : operands) {
|
||||
TensorType tensor_type = operand.getType().cast<TensorType>();
|
||||
Type element_type = tensor_type.getElementType();
|
||||
if (element_type.isF32()) continue;
|
||||
// Element type can either be f32 or bf16 for `SparseMatMulOp` so it
|
||||
// must be bf16 here.
|
||||
assert(element_type.isBF16());
|
||||
Type tensor_type_f32;
|
||||
if (tensor_type.hasRank()) {
|
||||
tensor_type_f32 = RankedTensorType::get(tensor_type.getShape(),
|
||||
FloatType::getF32(context));
|
||||
} else {
|
||||
tensor_type_f32 = UnrankedTensorType::get(FloatType::getF32(context));
|
||||
}
|
||||
// Add cast to f32 to conform with element type of result.
|
||||
operand =
|
||||
rewriter.create<TF::CastOp>(op.getLoc(), tensor_type_f32, operand);
|
||||
}
|
||||
Value result = rewriter.create<TF::MatMulOp>(
|
||||
op.getLoc(), op.product().getType(), operands[0], operands[1],
|
||||
op.transpose_a(), op.transpose_b());
|
||||
|
||||
rewriter.replaceOp(op, {result});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Emits debug information which includes the number of ops of each type which
|
||||
// failed to legalize.
|
||||
void EmitLegalizationErrors(Operation *op,
|
||||
@ -5437,19 +5503,21 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
|
||||
ConvertConv2DBackpropInputOp, ConvertConv3DBackpropInputOp,
|
||||
ConvertCumsumOp, ConvertDiagPartOp, ConvertEinsumOp,
|
||||
ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op,
|
||||
ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV3Op,
|
||||
ConvertInfeedDequeueTupleOp, ConvertInplaceUpdateOp, ConvertLinSpaceOp,
|
||||
ConvertMaxOp, ConvertMinOp, ConvertAvgPoolOp, ConvertAvgPool2DGradOp,
|
||||
ConvertAvgPool3DGradOp, ConvertMaxPool2DOp, ConvertMaxPool3DOp,
|
||||
ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp, ConvertMeanOp,
|
||||
ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertQrOp,
|
||||
ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV2Op,
|
||||
ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp,
|
||||
ConvertInplaceUpdateOp, ConvertLinSpaceOp, ConvertMaxOp, ConvertMinOp,
|
||||
ConvertAvgPoolOp, ConvertAvgPool2DGradOp, ConvertAvgPool3DGradOp,
|
||||
ConvertMaxPool2DOp, ConvertMaxPool3DOp, ConvertMaxPool2DGradOp,
|
||||
ConvertMaxPool3DGradOp, ConvertMeanOp, ConvertOneHotOp,
|
||||
ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertQrOp,
|
||||
ConvertDynamicRangeOp, ConvertRangeOp, ConvertSelectV2Op,
|
||||
ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp,
|
||||
ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
|
||||
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
|
||||
ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
|
||||
ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,
|
||||
ConvertUnpackOp, ConvertUnsortedSegmentMaxOp, ConvertUnsortedSegmentMinOp,
|
||||
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSparseMatMulOp,
|
||||
ConvertSplitOp, ConvertSplitVOp, ConvertStridedSliceOp,
|
||||
ConvertStridedSliceGradOp, ConvertSumOp, ConvertTensorScatterUpdateOp,
|
||||
ConvertTileOp, ConvertTopKV2Op, ConvertUnpackOp,
|
||||
ConvertUnsortedSegmentMaxOp, ConvertUnsortedSegmentMinOp,
|
||||
ConvertUnsortedSegmentProdOp, ConvertUnsortedSegmentSumOp,
|
||||
ConvertRandomShuffleOp, ConvertXlaShardingOp,
|
||||
ConvertXlaDynamicUpdateSliceOp>(op->getContext());
|
||||
|
||||
@ -339,18 +339,6 @@ def : Pat<(TF_MatMulOp $a, $b, $transpose_a, $transpose_b),
|
||||
(TF_TransposeOp $b, (TF_ConstOp (Get2DTransposePerm $transpose_b))),
|
||||
/*precision_config=*/(NullArrayAttr))>;
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SparseMatMul op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Ignores the sparseness hints and translates tf.SparseMatMul to tf.MatMul
|
||||
// until we will have an implementation that can use the information.
|
||||
def SparseMatMulToMatMul : Pat<(TF_SparseMatMulOp $a, $b, $a_sparse, $b_sparse,
|
||||
$transpose_a, $transpose_b),
|
||||
(TF_MatMulOp $a, $b, $transpose_a,
|
||||
$transpose_b)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MatrixBandPart op pattern.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@ -704,6 +704,7 @@ cc_library(
|
||||
srcs = ["mlir_bridge_pass.cc"],
|
||||
hdrs = ["mlir_bridge_pass.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/jit:flags",
|
||||
"//tensorflow/compiler/mlir:mlir_graph_optimization_pass",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/core:core_cpu",
|
||||
|
||||
@ -56,7 +56,7 @@ Status MlirBridgeV1CompatPass::Run(const GraphOptimizationPassOptions& options,
|
||||
// Skip function graphs as MlirBridgePass will be used instead.
|
||||
if (options.is_function_graph) return Status::OK();
|
||||
|
||||
if (!options.session_options->config.experimental().enable_mlir_bridge()) {
|
||||
if (!IsEnabled(options.session_options->config)) {
|
||||
VLOG(0) << "Skipping MLIR TPU Bridge V1 Compat, session flag not enabled";
|
||||
mlir_bridge_gauge_v1->GetCell()->Set(false);
|
||||
return Status::OK();
|
||||
|
||||
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -45,7 +46,8 @@ class MlirBridgeV1CompatPass : public MlirV1CompatOptimizationPass {
|
||||
llvm::StringRef name() const override { return "bridge"; }
|
||||
|
||||
bool IsEnabled(const ConfigProto& config_proto) const override {
|
||||
return config_proto.experimental().enable_mlir_bridge();
|
||||
return config_proto.experimental().enable_mlir_bridge() ||
|
||||
tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
|
||||
}
|
||||
|
||||
// This should be used as a thin mapper around mlir::ModulePass::runOnModule
|
||||
|
||||
@ -25,6 +25,8 @@ upper_tabs:
|
||||
path: /xla/operation_semantics
|
||||
- title: Shapes and layout
|
||||
path: /xla/shapes
|
||||
- title: Aliasing
|
||||
path: /xla/aliasing
|
||||
- title: Tiled layout
|
||||
path: /xla/tiled_layout
|
||||
- title: Use AOT compilation
|
||||
|
||||
73
tensorflow/compiler/xla/g3doc/aliasing.md
Normal file
73
tensorflow/compiler/xla/g3doc/aliasing.md
Normal file
@ -0,0 +1,73 @@
|
||||
# Aliasing in XLA
|
||||
|
||||
This document describes the aliasing API for XLA: when building an XLA program,
|
||||
you can specify the desired aliasing between the input and output buffers.
|
||||
|
||||
## Defining aliasing at compile-time
|
||||
|
||||
For example, consider a trivial HLO module which simply adds `1` to its input:
|
||||
|
||||
```
|
||||
HloModule increment
|
||||
|
||||
ENTRY entry {
|
||||
%p = f32[] parameter(0)
|
||||
%c = f32[] constant(1)
|
||||
ROOT %out = f32[] add(%p, %c)
|
||||
}
|
||||
```
|
||||
|
||||
This module will allocate two 4-byte buffers: one for the input `%p`, and one
|
||||
for the output `%out`.
|
||||
|
||||
However, it is often desirable to perform the update in-place (for example, if
|
||||
in the frontend generating the expression the input variable is no longer alive
|
||||
after the computation, as in the increment `p++`).
|
||||
|
||||
To perform such an update efficiently, you can specify the input aliasing:
|
||||
|
||||
```
|
||||
HloModule increment, input_output_alias={ {}: 0 }
|
||||
|
||||
ENTRY entry {
|
||||
%p = f32[] parameter(0)
|
||||
%c = f32[] constant(1)
|
||||
ROOT %out = f32[] add(%p, %c)
|
||||
}
|
||||
```
|
||||
|
||||
The format specifies that the entire output (marked by `{}`) is aliased to the
|
||||
input parameter `0`.
|
||||
|
||||
See the
|
||||
[`XlaBuilder::SetUpAlias`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
|
||||
API to specify the aliasing programmatically.
|
||||
|
||||
## Defining aliasing at run-time
|
||||
|
||||
The aliasing defined in the previous step is specified during the _compilation_.
|
||||
During the execution, you can choose whether actually to donate the buffer using
|
||||
the
|
||||
[`LocalClient::RunAsync`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/local_client.h)
|
||||
API.
|
||||
|
||||
Input buffers to the program are wrapped in
|
||||
[`ExecutionInput`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/executable.h),
|
||||
which in turn contain a tree of `MaybeOwningDeviceMemory`. If memory is
|
||||
specified as _owning_ (ownership of the buffer is passed to the XLA runtime),
|
||||
the buffer is actually donated, and the update is executed in-place, as
|
||||
requested by the compile-time aliasing API.
|
||||
|
||||
If, however, the buffer which is aliased at compile time is _not_ donated at
|
||||
runtime, _copy-protection_ kicks in: an extra output buffer `O` is allocated,
|
||||
and the contents of the input buffer `P` which was meant to be aliased are
|
||||
copied into `O` (so effectively the program can execute as if the buffer `O` was
|
||||
donated at runtime).
|
||||
|
||||
## Frontend interop
|
||||
|
||||
### TF/XLA
|
||||
|
||||
In clusters of TensorFlow program compiled with XLA, all resource variable
|
||||
updates are aliased at compile time (the aliasing at runtime depends on whether
|
||||
anything else holds a reference to the resource variable tensor).
|
||||
@ -24,6 +24,7 @@ tf_proto_library_cc(
|
||||
has_services = 1,
|
||||
cc_api_version = 2,
|
||||
cc_grpc_version = 1,
|
||||
create_service = True,
|
||||
protodeps = [
|
||||
"//tensorflow/compiler/xla:xla_proto",
|
||||
],
|
||||
|
||||
@ -1602,6 +1602,17 @@ Status DynamicDimensionInference::AnalyzeDynamicDimensions() {
|
||||
custom_call_handler_);
|
||||
}
|
||||
|
||||
void DynamicDimensionInference::ReplaceAllDynamicDimensionUsesWith(
|
||||
HloInstruction* replace, HloInstruction* with) {
|
||||
CHECK(Shape::Equal()(replace->shape(), ShapeUtil::MakeScalarShape(S32)));
|
||||
CHECK(Shape::Equal()(with->shape(), ShapeUtil::MakeScalarShape(S32)));
|
||||
for (auto& kv : dynamic_mapping_) {
|
||||
if (kv.second == replace) {
|
||||
kv.second = with;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Status DynamicDimensionInference::ForwardDynamicSize(HloInstruction* inst,
|
||||
HloInstruction* new_inst,
|
||||
const ShapeIndex& index) {
|
||||
|
||||
@ -68,6 +68,11 @@ class DynamicDimensionInference {
|
||||
SetDynamicSize(inst, index, dim, size, DimensionConstraint(1, 1));
|
||||
}
|
||||
|
||||
// For all tensors whose dynamic dimension is `replace`, replace them with
|
||||
// `with`.
|
||||
void ReplaceAllDynamicDimensionUsesWith(HloInstruction* replace,
|
||||
HloInstruction* with);
|
||||
|
||||
friend class DynamicDimensionInferenceVisitor;
|
||||
|
||||
private:
|
||||
|
||||
@ -242,7 +242,6 @@ cc_library(
|
||||
deps = [
|
||||
":backend_configs_cc",
|
||||
":buffer_allocations",
|
||||
":cudnn_batchnorm_runner",
|
||||
":elemental_ir_emitter",
|
||||
":gpu_constants",
|
||||
":gpu_conv_runner",
|
||||
@ -267,6 +266,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:elemental_ir_emitter",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_casting_utils",
|
||||
"//tensorflow/compiler/xla/service:hlo_execution_profile",
|
||||
"//tensorflow/compiler/xla/service:name_uniquer",
|
||||
"//tensorflow/compiler/xla/service:pattern_matcher",
|
||||
"//tensorflow/compiler/xla/service:while_loop_analysis",
|
||||
@ -282,7 +282,6 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service/llvm_ir:sort_util",
|
||||
"//tensorflow/compiler/xla/service/llvm_ir:tuple_ops",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/memory",
|
||||
|
||||
@ -31,13 +31,13 @@ limitations under the License.
|
||||
namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
CholeskyThunk::CholeskyThunk(const CholeskyOptions& options,
|
||||
CholeskyThunk::CholeskyThunk(ThunkInfo thunk_info,
|
||||
const CholeskyOptions& options,
|
||||
BufferAllocation::Slice a_buffer,
|
||||
BufferAllocation::Slice workspace_buffer,
|
||||
BufferAllocation::Slice info_buffer,
|
||||
PrimitiveType type, int64 batch_size, int64 n,
|
||||
const HloInstruction* hlo)
|
||||
: Thunk(Kind::kCholesky, hlo),
|
||||
PrimitiveType type, int64 batch_size, int64 n)
|
||||
: Thunk(Kind::kCholesky, thunk_info),
|
||||
uplo_(options.lower() ? se::blas::UpperLower::kLower
|
||||
: se::blas::UpperLower::kUpper),
|
||||
a_buffer_(a_buffer),
|
||||
@ -45,9 +45,10 @@ CholeskyThunk::CholeskyThunk(const CholeskyOptions& options,
|
||||
info_buffer_(info_buffer),
|
||||
type_(type),
|
||||
batch_size_(batch_size),
|
||||
a_batch_stride_(n * n *
|
||||
ShapeUtil::ByteSizeOfPrimitiveType(
|
||||
hlo->operand(0)->shape().element_type())),
|
||||
a_batch_stride_(
|
||||
n * n *
|
||||
ShapeUtil::ByteSizeOfPrimitiveType(
|
||||
thunk_info.hlo_instruction->operand(0)->shape().element_type())),
|
||||
n_(n) {}
|
||||
|
||||
Status CholeskyThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
|
||||
@ -41,12 +41,11 @@ namespace gpu {
|
||||
class CholeskyThunk : public Thunk {
|
||||
public:
|
||||
static StatusOr<int64> ScratchBufferSize(int64 n);
|
||||
CholeskyThunk(const CholeskyOptions& options,
|
||||
CholeskyThunk(ThunkInfo thunk_info, const CholeskyOptions& options,
|
||||
BufferAllocation::Slice a_buffer,
|
||||
BufferAllocation::Slice workspace_buffer,
|
||||
BufferAllocation::Slice info_buffer,
|
||||
PrimitiveType type,
|
||||
int64 batch_size, int64 n, const HloInstruction* hlo);
|
||||
BufferAllocation::Slice info_buffer, PrimitiveType type,
|
||||
int64 batch_size, int64 n);
|
||||
|
||||
CholeskyThunk(const CholeskyThunk&) = delete;
|
||||
CholeskyThunk& operator=(const CholeskyThunk&) = delete;
|
||||
|
||||
@ -218,14 +218,14 @@ RefcountingHashMap<RendezvousKey, Rendezvous>& GlobalRendezvousMap() {
|
||||
} // anonymous namespace
|
||||
|
||||
CollectivePermuteThunk::CollectivePermuteThunk(
|
||||
const BufferAllocation::Slice& src, const BufferAllocation::Slice& dest,
|
||||
const HloInstruction* instr)
|
||||
: Thunk(kCollectivePermute, instr), src_(src), dest_(dest) {}
|
||||
ThunkInfo thunk_info, const BufferAllocation::Slice& src,
|
||||
const BufferAllocation::Slice& dest)
|
||||
: Thunk(kCollectivePermute, thunk_info), src_(src), dest_(dest) {}
|
||||
|
||||
Status CollectivePermuteThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
auto* instr = Cast<HloCollectivePermuteInstruction>(hlo_instruction());
|
||||
auto op_profiler =
|
||||
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
|
||||
params.profiler->MakeScopedInstructionProfiler(profile_index());
|
||||
|
||||
// Rendezvous with the threads for all other devices that are participating in
|
||||
// this CollectivePermute.
|
||||
|
||||
@ -26,9 +26,9 @@ namespace gpu {
|
||||
// Thunk that implements the collective-permute HLO.
|
||||
class CollectivePermuteThunk : public Thunk {
|
||||
public:
|
||||
CollectivePermuteThunk(const BufferAllocation::Slice& src,
|
||||
const BufferAllocation::Slice& dest,
|
||||
const HloInstruction* instr);
|
||||
CollectivePermuteThunk(ThunkInfo thunk_info,
|
||||
const BufferAllocation::Slice& src,
|
||||
const BufferAllocation::Slice& dest);
|
||||
|
||||
Status ExecuteOnStream(const ExecuteParams& params) override;
|
||||
|
||||
|
||||
@ -24,12 +24,14 @@ namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
ConditionalThunk::ConditionalThunk(
|
||||
ThunkInfo thunk_info,
|
||||
const BufferAllocation::Slice& branch_index_buffer_index,
|
||||
absl::Span<const BufferAllocation::Slice> branch_operand_buffer_indexes,
|
||||
std::vector<ThunkSequence> branch_thunk_sequences,
|
||||
const HloInstruction* hlo)
|
||||
: Thunk(Kind::kConditional, hlo),
|
||||
branch_index_is_bool_(hlo->operand(0)->shape().element_type() == PRED),
|
||||
std::vector<ThunkSequence> branch_thunk_sequences)
|
||||
: Thunk(Kind::kConditional, thunk_info),
|
||||
branch_index_is_bool_(
|
||||
thunk_info.hlo_instruction->operand(0)->shape().element_type() ==
|
||||
PRED),
|
||||
branch_index_buffer_index_(branch_index_buffer_index),
|
||||
branch_operand_buffer_indexes_(branch_operand_buffer_indexes.begin(),
|
||||
branch_operand_buffer_indexes.end()) {
|
||||
@ -39,7 +41,7 @@ ConditionalThunk::ConditionalThunk(
|
||||
branch_thunks_.reserve(branch_thunk_sequences.size());
|
||||
for (auto& branch_thunk_sequence : branch_thunk_sequences) {
|
||||
branch_thunks_.emplace_back(
|
||||
new SequentialThunk(std::move(branch_thunk_sequence), nullptr));
|
||||
new SequentialThunk(ThunkInfo(), std::move(branch_thunk_sequence)));
|
||||
}
|
||||
}
|
||||
|
||||
@ -67,7 +69,7 @@ Status ConditionalThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
auto& profiler = *params.profiler;
|
||||
auto& stream = *params.stream;
|
||||
|
||||
auto op_profiler = profiler.MakeScopedInstructionProfiler(hlo_instruction());
|
||||
auto op_profiler = profiler.MakeScopedInstructionProfiler(profile_index());
|
||||
// Copy the predicate value from device.
|
||||
int32 branch_index = -1;
|
||||
bool pred = false;
|
||||
|
||||
@ -43,10 +43,10 @@ namespace gpu {
|
||||
class ConditionalThunk : public Thunk {
|
||||
public:
|
||||
ConditionalThunk(
|
||||
ThunkInfo thunk_info,
|
||||
const BufferAllocation::Slice& branch_index_buffer_index,
|
||||
absl::Span<const BufferAllocation::Slice> branch_operand_buffer_indexes,
|
||||
std::vector<ThunkSequence> branch_thunk_sequences,
|
||||
const HloInstruction* hlo);
|
||||
std::vector<ThunkSequence> branch_thunk_sequences);
|
||||
|
||||
ConditionalThunk(const ConditionalThunk&) = delete;
|
||||
ConditionalThunk& operator=(const ConditionalThunk&) = delete;
|
||||
|
||||
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
@ -30,16 +31,16 @@ namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
ConvolutionThunk::ConvolutionThunk(
|
||||
const HloCustomCallInstruction* cudnn_call,
|
||||
std::vector<BufferAllocation::Slice> operand_slices,
|
||||
ThunkInfo thunk_info, std::vector<BufferAllocation::Slice> operand_slices,
|
||||
BufferAllocation::Slice result_slice, BufferAllocation::Slice scratch_slice,
|
||||
BufferAllocation::Slice tuple_result_slice)
|
||||
: Thunk(Kind::kConvolution, cudnn_call),
|
||||
cudnn_call_(cudnn_call),
|
||||
: Thunk(Kind::kConvolution, thunk_info),
|
||||
operand_buffers_(std::move(operand_slices)),
|
||||
result_buffer_(result_slice),
|
||||
scratch_buffer_(scratch_slice),
|
||||
tuple_result_buffer_(tuple_result_slice) {}
|
||||
tuple_result_buffer_(tuple_result_slice) {
|
||||
cudnn_call_ = Cast<HloCustomCallInstruction>(hlo_instruction());
|
||||
}
|
||||
|
||||
Status ConvolutionThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
const auto& buffer_allocations = *params.buffer_allocations;
|
||||
@ -56,7 +57,7 @@ Status ConvolutionThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
buffer_allocations.GetDeviceAddress(scratch_buffer_);
|
||||
|
||||
auto op_profiler =
|
||||
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
|
||||
params.profiler->MakeScopedInstructionProfiler(profile_index());
|
||||
TF_RETURN_IF_ERROR(RunGpuConv(cudnn_call_, absl::MakeSpan(operand_se_buffers),
|
||||
result_buffer, scratch, params.stream));
|
||||
|
||||
|
||||
@ -43,7 +43,7 @@ class ConvolutionThunk : public Thunk {
|
||||
// write a tuple (result, scratch_memory) into `tuple_result_buffer`.
|
||||
//
|
||||
// operand_slices should be in the same order as cudnn_call->operands().
|
||||
ConvolutionThunk(const HloCustomCallInstruction* cudnn_call,
|
||||
ConvolutionThunk(ThunkInfo thunk_info,
|
||||
std::vector<BufferAllocation::Slice> operand_slices,
|
||||
BufferAllocation::Slice result_slice,
|
||||
BufferAllocation::Slice scratch_slice,
|
||||
|
||||
@ -22,10 +22,9 @@ namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
HostToDeviceCopyThunk::HostToDeviceCopyThunk(
|
||||
const void* source_address,
|
||||
const BufferAllocation::Slice& destination_buffer, uint64 mem_size,
|
||||
const HloInstruction* hlo_instruction)
|
||||
: Thunk(Kind::kCopy, hlo_instruction),
|
||||
ThunkInfo thunk_info, const void* source_address,
|
||||
const BufferAllocation::Slice& destination_buffer, uint64 mem_size)
|
||||
: Thunk(Kind::kCopy, thunk_info),
|
||||
source_address_(source_address),
|
||||
destination_buffer_(destination_buffer),
|
||||
mem_size_(mem_size) {}
|
||||
@ -34,16 +33,15 @@ Status HostToDeviceCopyThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
se::DeviceMemoryBase destination_data =
|
||||
params.buffer_allocations->GetDeviceAddress(destination_buffer_);
|
||||
auto op_profiler =
|
||||
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
|
||||
params.profiler->MakeScopedInstructionProfiler(profile_index());
|
||||
params.stream->ThenMemcpy(&destination_data, source_address_, mem_size_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk(
|
||||
const BufferAllocation::Slice& source_buffer,
|
||||
const BufferAllocation::Slice& destination_buffer, uint64 mem_size,
|
||||
const HloInstruction* hlo_instruction)
|
||||
: Thunk(Kind::kCopy, hlo_instruction),
|
||||
ThunkInfo thunk_info, const BufferAllocation::Slice& source_buffer,
|
||||
const BufferAllocation::Slice& destination_buffer, uint64 mem_size)
|
||||
: Thunk(Kind::kCopy, thunk_info),
|
||||
source_buffer_(source_buffer),
|
||||
destination_buffer_(destination_buffer),
|
||||
mem_size_(mem_size) {}
|
||||
@ -54,7 +52,7 @@ Status DeviceToDeviceCopyThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
se::DeviceMemoryBase source_data =
|
||||
params.buffer_allocations->GetDeviceAddress(source_buffer_);
|
||||
auto op_profiler =
|
||||
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
|
||||
params.profiler->MakeScopedInstructionProfiler(profile_index());
|
||||
params.stream->ThenMemcpy(&destination_data, source_data, mem_size_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -33,9 +33,9 @@ class HostToDeviceCopyThunk : public Thunk {
|
||||
// Constructs a CopyThunk that copies host data from `source_address` to the
|
||||
// device buffer `destination_buffer`. `mem_size` is the size of the data in
|
||||
// bytes.
|
||||
HostToDeviceCopyThunk(const void* source_address,
|
||||
HostToDeviceCopyThunk(ThunkInfo thunk_info, const void* source_address,
|
||||
const BufferAllocation::Slice& destination_buffer,
|
||||
uint64 mem_size, const HloInstruction* hlo_instruction);
|
||||
uint64 mem_size);
|
||||
|
||||
HostToDeviceCopyThunk(const HostToDeviceCopyThunk&) = delete;
|
||||
HostToDeviceCopyThunk& operator=(const HostToDeviceCopyThunk&) = delete;
|
||||
@ -54,10 +54,10 @@ class DeviceToDeviceCopyThunk : public Thunk {
|
||||
// Constructs a CopyThunk that copies host data from `source_buffer` to the
|
||||
// device buffer `destination_buffer`. `mem_size` is the size of the data in
|
||||
// bytes.
|
||||
DeviceToDeviceCopyThunk(const BufferAllocation::Slice& source_buffer,
|
||||
DeviceToDeviceCopyThunk(ThunkInfo thunk_info,
|
||||
const BufferAllocation::Slice& source_buffer,
|
||||
const BufferAllocation::Slice& destination_buffer,
|
||||
uint64 mem_size,
|
||||
const HloInstruction* hlo_instruction);
|
||||
uint64 mem_size);
|
||||
|
||||
DeviceToDeviceCopyThunk(const DeviceToDeviceCopyThunk&) = delete;
|
||||
DeviceToDeviceCopyThunk& operator=(const DeviceToDeviceCopyThunk&) = delete;
|
||||
|
||||
@ -92,12 +92,12 @@ void CheckInputOutputPrimitivetypeAreValid(const HloInstruction* hlo) {
|
||||
} // namespace
|
||||
|
||||
CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk(
|
||||
const BufferAllocation::Slice& operand,
|
||||
ThunkInfo thunk_info, const BufferAllocation::Slice& operand,
|
||||
const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset,
|
||||
const BufferAllocation::Slice& mean,
|
||||
const BufferAllocation::Slice& variance, float epsilon, int64 feature_index,
|
||||
const BufferAllocation::Slice& output, const HloInstruction* hlo)
|
||||
: Thunk(Thunk::Kind::kCudnnBatchNormForwardInference, hlo),
|
||||
const BufferAllocation::Slice& output)
|
||||
: Thunk(Thunk::Kind::kCudnnBatchNormForwardInference, thunk_info),
|
||||
operand_(operand),
|
||||
scale_(scale),
|
||||
offset_(offset),
|
||||
@ -106,6 +106,7 @@ CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk(
|
||||
epsilon_(epsilon),
|
||||
feature_index_(feature_index),
|
||||
output_(output) {
|
||||
const auto* hlo = hlo_instruction();
|
||||
CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall);
|
||||
CHECK_EQ(hlo->custom_call_target(),
|
||||
kCudnnBatchNormForwardInferenceCallTarget);
|
||||
@ -118,7 +119,7 @@ Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream(
|
||||
const ExecuteParams& params) {
|
||||
auto& buffer_allocations = *params.buffer_allocations;
|
||||
auto op_profiler =
|
||||
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
|
||||
params.profiler->MakeScopedInstructionProfiler(profile_index());
|
||||
se::DeviceMemoryBase output_base =
|
||||
buffer_allocations.GetDeviceAddress(output_);
|
||||
se::DeviceMemoryBase operand = buffer_allocations.GetDeviceAddress(operand_);
|
||||
@ -139,14 +140,14 @@ Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream(
|
||||
}
|
||||
|
||||
CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk(
|
||||
const BufferAllocation::Slice& operand,
|
||||
ThunkInfo thunk_info, const BufferAllocation::Slice& operand,
|
||||
const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset,
|
||||
float epsilon, int64 feature_index,
|
||||
const BufferAllocation::Slice& output_data,
|
||||
const BufferAllocation::Slice& output_mean,
|
||||
const BufferAllocation::Slice& output_inv_stddev,
|
||||
const BufferAllocation::Slice& output_tuple, const HloInstruction* hlo)
|
||||
: Thunk(Thunk::Kind::kCudnnBatchNormForwardTraining, hlo),
|
||||
const BufferAllocation::Slice& output_tuple)
|
||||
: Thunk(Thunk::Kind::kCudnnBatchNormForwardTraining, thunk_info),
|
||||
operand_(operand),
|
||||
scale_(scale),
|
||||
offset_(offset),
|
||||
@ -156,6 +157,7 @@ CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk(
|
||||
output_mean_(output_mean),
|
||||
output_inv_stddev_(output_inv_stddev),
|
||||
output_tuple_(output_tuple) {
|
||||
const auto* hlo = hlo_instruction();
|
||||
CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall);
|
||||
CHECK_EQ(hlo->custom_call_target(), kCudnnBatchNormForwardTrainingCallTarget);
|
||||
CHECK_EQ(hlo->shape().tuple_shapes_size(), 3);
|
||||
@ -178,7 +180,7 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
|
||||
|
||||
se::DeviceMemory<float> null_device_ptr(nullptr);
|
||||
auto op_profiler =
|
||||
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
|
||||
params.profiler->MakeScopedInstructionProfiler(profile_index());
|
||||
auto& stream = *params.stream;
|
||||
TF_RETURN_IF_ERROR(RunCudnnBatchNormForwardTraining(
|
||||
hlo_instruction(), operand, output_data, output_mean, output_inv_stddev,
|
||||
@ -203,15 +205,15 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
|
||||
}
|
||||
|
||||
CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk(
|
||||
const BufferAllocation::Slice& operand,
|
||||
ThunkInfo thunk_info, const BufferAllocation::Slice& operand,
|
||||
const BufferAllocation::Slice& scale, const BufferAllocation::Slice& mean,
|
||||
const BufferAllocation::Slice& inv_stddev,
|
||||
const BufferAllocation::Slice& grad_output, float epsilon,
|
||||
int64 feature_index, const BufferAllocation::Slice& output_grad_data,
|
||||
const BufferAllocation::Slice& output_grad_scale,
|
||||
const BufferAllocation::Slice& output_grad_offset,
|
||||
const BufferAllocation::Slice& output_tuple, const HloInstruction* hlo)
|
||||
: Thunk(Thunk::Kind::kCudnnBatchNormBackward, hlo),
|
||||
const BufferAllocation::Slice& output_tuple)
|
||||
: Thunk(Thunk::Kind::kCudnnBatchNormBackward, thunk_info),
|
||||
operand_(operand),
|
||||
scale_(scale),
|
||||
mean_(mean),
|
||||
@ -223,6 +225,7 @@ CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk(
|
||||
output_grad_scale_(output_grad_scale),
|
||||
output_grad_offset_(output_grad_offset),
|
||||
output_tuple_(output_tuple) {
|
||||
const auto* hlo = hlo_instruction();
|
||||
CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall);
|
||||
CHECK_EQ(hlo->custom_call_target(), kCudnnBatchNormBackwardCallTarget);
|
||||
CHECK_EQ(hlo->shape().tuple_shapes_size(), 3);
|
||||
@ -247,7 +250,7 @@ Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
|
||||
buffer_allocations.GetDeviceAddress(output_grad_offset_));
|
||||
|
||||
auto op_profiler =
|
||||
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
|
||||
params.profiler->MakeScopedInstructionProfiler(profile_index());
|
||||
se::Stream* stream = params.stream;
|
||||
TF_RETURN_IF_ERROR(RunCudnnBatchNormBackward(
|
||||
hlo_instruction(), operand, output_grad_data, grad_output,
|
||||
|
||||
@ -46,14 +46,14 @@ namespace gpu {

class CudnnBatchNormForwardInferenceThunk : public Thunk {
public:
CudnnBatchNormForwardInferenceThunk(const BufferAllocation::Slice& operand,
CudnnBatchNormForwardInferenceThunk(ThunkInfo thunk_info,
const BufferAllocation::Slice& operand,
const BufferAllocation::Slice& scale,
const BufferAllocation::Slice& offset,
const BufferAllocation::Slice& mean,
const BufferAllocation::Slice& variance,
float epsilon, int64 feature_index,
const BufferAllocation::Slice& output,
const HloInstruction* hlo);
const BufferAllocation::Slice& output);

CudnnBatchNormForwardInferenceThunk(
const CudnnBatchNormForwardInferenceThunk&) = delete;
@ -76,13 +76,13 @@ class CudnnBatchNormForwardInferenceThunk : public Thunk {
class CudnnBatchNormForwardTrainingThunk : public Thunk {
public:
CudnnBatchNormForwardTrainingThunk(
const BufferAllocation::Slice& operand,
ThunkInfo thunk_info, const BufferAllocation::Slice& operand,
const BufferAllocation::Slice& scale,
const BufferAllocation::Slice& offset, float epsilon, int64 feature_index,
const BufferAllocation::Slice& output_data,
const BufferAllocation::Slice& output_mean,
const BufferAllocation::Slice& output_inv_stddev,
const BufferAllocation::Slice& output_tuple, const HloInstruction* hlo);
const BufferAllocation::Slice& output_tuple);

CudnnBatchNormForwardTrainingThunk(
const CudnnBatchNormForwardTrainingThunk&) = delete;
@ -105,7 +105,8 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk {

class CudnnBatchNormBackwardThunk : public Thunk {
public:
CudnnBatchNormBackwardThunk(const BufferAllocation::Slice& operand,
CudnnBatchNormBackwardThunk(ThunkInfo thunk_info,
const BufferAllocation::Slice& operand,
const BufferAllocation::Slice& scale,
const BufferAllocation::Slice& mean,
const BufferAllocation::Slice& inv_stddev,
@ -114,8 +115,7 @@ class CudnnBatchNormBackwardThunk : public Thunk {
const BufferAllocation::Slice& output_grad_data,
const BufferAllocation::Slice& output_grad_scale,
const BufferAllocation::Slice& output_grad_offset,
const BufferAllocation::Slice& output_tuple,
const HloInstruction* hlo);
const BufferAllocation::Slice& output_tuple);

CudnnBatchNormBackwardThunk(const CudnnBatchNormBackwardThunk&) = delete;
CudnnBatchNormBackwardThunk& operator=(const CudnnBatchNormBackwardThunk&) =

@ -22,15 +22,15 @@ namespace xla {
namespace gpu {

CustomCallThunk::CustomCallThunk(
void* call_target,
ThunkInfo thunk_info, void* call_target,
std::vector<ShapeTree<BufferAllocation::Slice>> operand_slices,
ShapeTree<BufferAllocation::Slice> result_slices, std::string opaque,
const HloInstruction* instr)
: Thunk(Thunk::kCustomCall, instr),
ShapeTree<BufferAllocation::Slice> result_slices, std::string opaque)
: Thunk(Thunk::kCustomCall, thunk_info),
call_target_(call_target),
operand_slices_(std::move(operand_slices)),
result_slices_(std::move(result_slices)),
opaque_(std::move(opaque)) {
const HloInstruction* instr = hlo_instruction();
CHECK_EQ(instr->operand_count(), operand_slices_.size());
for (int64 i = 0; i < instr->operand_count(); ++i) {
const auto& s1 = operand_slices_[i].shape();

@ -39,10 +39,9 @@ namespace gpu {
class CustomCallThunk : public Thunk {
public:
CustomCallThunk(
void* call_target,
ThunkInfo thunk_info, void* call_target,
std::vector<ShapeTree<BufferAllocation::Slice>> operand_slices,
ShapeTree<BufferAllocation::Slice> result_slices, std::string opaque,
const HloInstruction* instr);
ShapeTree<BufferAllocation::Slice> result_slices, std::string opaque);

Status ExecuteOnStream(const ExecuteParams& params) override;


@ -42,9 +42,9 @@ NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
struct NcclAllReduceThunk::AuxData {};

NcclAllReduceThunk::NcclAllReduceThunk(
int64 replica_count, std::vector<NcclAllReduceThunk::Buffer> buffers,
const HloInstruction* all_reduce)
: Thunk(Thunk::kNcclAllReduce, all_reduce),
ThunkInfo thunk_info, int64 replica_count,
std::vector<NcclAllReduceThunk::Buffer> buffers)
: Thunk(Thunk::kNcclAllReduce, thunk_info),
replica_count_(replica_count),
buffers_(std::move(buffers)) {}


@ -98,12 +98,12 @@ string FftTypeToString(se::fft::Type type) {

} // namespace

FftThunk::FftThunk(FftType fft_type, absl::Span<const int64> fft_length,
FftThunk::FftThunk(ThunkInfo thunk_info, FftType fft_type,
absl::Span<const int64> fft_length,
const BufferAllocation::Slice& input_buffer,
const BufferAllocation::Slice& output_buffer,
const Shape& input_shape, const Shape& output_shape,
const HloInstruction* hlo)
: Thunk(Kind::kFft, hlo),
const Shape& input_shape, const Shape& output_shape)
: Thunk(Kind::kFft, thunk_info),
fft_type_(
FftTypeToSeType(fft_type, input_shape.element_type() == F64 ||
input_shape.element_type() == C128)),
@ -127,7 +127,7 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
buffer_allocations.memory_allocator());

auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
if (fft_plan_ == nullptr) {
const int64 fft_rank = fft_length_.size();
CHECK_LE(fft_rank, 3);

@ -62,11 +62,11 @@ class FftThunk : public Thunk {
public:
// Constructs a thunk for launching an FFT on a stream.
// Semantics of null hlo_instruction argument are as in Thunk.
FftThunk(FftType fft_type, absl::Span<const int64> fft_length,
FftThunk(ThunkInfo thunk_info, FftType fft_type,
absl::Span<const int64> fft_length,
const BufferAllocation::Slice& input_buffer,
const BufferAllocation::Slice& output_buffer,
const Shape& input_shape, const Shape& output_shape,
const HloInstruction* hlo);
const Shape& input_shape, const Shape& output_shape);

FftThunk(const FftThunk&) = delete;  // Cannot share fft_plan_
FftThunk& operator=(const FftThunk&) = delete;  // Cannot share fft_plan_

@ -23,16 +23,15 @@ limitations under the License.
namespace xla {
namespace gpu {

ForThunk::ForThunk(const int64 loop_limit,
std::unique_ptr<ThunkSequence> body_thunk_sequence,
const HloInstruction* hlo)
: Thunk(Kind::kWhile, hlo),
ForThunk::ForThunk(ThunkInfo thunk_info, const int64 loop_limit,
std::unique_ptr<ThunkSequence> body_thunk_sequence)
: Thunk(Kind::kWhile, thunk_info),
loop_limit_(loop_limit),
body_thunk_sequence_(absl::make_unique<SequentialThunk>(
// Pass nullptr as the HloInstruction* to the body_thunk_sequence_
// constructor because this SequentialThunk is logically "part of"
// this ForThunk, and shouldn't be profiled separately from it.
std::move(*body_thunk_sequence), nullptr)) {}
ThunkInfo(), std::move(*body_thunk_sequence))) {}

void ForThunk::ComputeAnnotations() {
Thunk::ComputeAnnotations();
@ -49,7 +48,7 @@ Status ForThunk::ExecuteOnStream(const ExecuteParams& params) {
VLOG(2) << "Executing ForThunk with " << loop_limit_ << " iters for "
<< (hlo_instruction() ? hlo_instruction()->ToString() : "<null>");
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
for (int64 i = 0; i < loop_limit_; ++i) {
params.profiler->StartHloComputation();
// Invoke loop body thunk sequence.

@ -31,9 +31,8 @@ namespace gpu {
// ForThunk executes 'loop_limit' invocations of 'body_thunk_sequence'.
class ForThunk : public Thunk {
public:
ForThunk(const int64 loop_limit,
std::unique_ptr<ThunkSequence> body_thunk_sequence,
const HloInstruction* hlo);
ForThunk(ThunkInfo thunk_info, const int64 loop_limit,
std::unique_ptr<ThunkSequence> body_thunk_sequence);
ForThunk(const ForThunk&) = delete;
ForThunk& operator=(const ForThunk&) = delete;


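In the ForThunk change above, the loop body is still wrapped in a SequentialThunk, but it now receives a default-constructed ThunkInfo in place of the old nullptr, so the body is not profiled separately from the enclosing ForThunk. Restated as a small sketch of that member initializer, assuming the same names as the diff:

// The body sequence carries an empty ThunkInfo (no profile_index), so only
// the enclosing ForThunk shows up in the per-instruction profile.
body_thunk_sequence_ = absl::make_unique<SequentialThunk>(
    Thunk::ThunkInfo(), std::move(*body_thunk_sequence));
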
@ -132,6 +132,7 @@ static StatusOr<absl::optional<se::blas::AlgorithmType>> DoUncachedGemmAutotune(
CHECK(RunGemm(gemm, backend_config, lhs_buffer, rhs_buffer, output_buffer,
stream,
/*implements_whole_instruction=*/true,
/*profile_index=*/-1,
/*profiler=*/nullptr,
/*profile_result=*/&profile_result, algorithm)
.ok());

@ -33,13 +33,13 @@ limitations under the License.
namespace xla {
namespace gpu {

GemmThunk::GemmThunk(const BufferAllocation::Slice &lhs_buffer,
GemmThunk::GemmThunk(ThunkInfo thunk_info,
const BufferAllocation::Slice &lhs_buffer,
const BufferAllocation::Slice &rhs_buffer,
const BufferAllocation::Slice &output_buffer,
bool implements_whole_instruction,
const HloInstruction *hlo_instruction,
const GemmBackendConfig &backend_config)
: Thunk(Kind::kGemm, hlo_instruction),
: Thunk(Kind::kGemm, thunk_info),
lhs_buffer_(lhs_buffer),
rhs_buffer_(rhs_buffer),
output_buffer_(output_buffer),
@ -57,7 +57,7 @@ Status GemmThunk::ExecuteOnStream(const ExecuteParams &params) {
se::DeviceMemoryBase output_data = get_device_address(output_buffer_);
return RunGemm(hlo_instruction(), backend_config_, lhs_data, rhs_data,
output_data, params.stream, implements_whole_instruction_,
params.profiler);
profile_index(), params.profiler);
}

// This struct contains the metadata of a matrix, e.g., its base address and
@ -160,6 +160,7 @@ Status RunGemm(const HloInstruction *gemm,
se::DeviceMemoryBase lhs_buffer, se::DeviceMemoryBase rhs_buffer,
se::DeviceMemoryBase output_buffer, se::Stream *stream,
bool implements_whole_instruction,
absl::optional<int64> profile_index,
HloExecutionProfiler *profiler,
se::blas::ProfileResult *profile_result,
absl::optional<se::blas::AlgorithmType> algorithm) {
@ -240,7 +241,7 @@ Status RunGemm(const HloInstruction *gemm,
rhs_buffer, rhs_shape, dim_nums.rhs_contracting_dimensions(0) == col_dim);
std::unique_ptr<ScopedInstructionProfiler> op_profiler =
profiler ? profiler->MakeScopedInstructionProfiler(
implements_whole_instruction ? gemm : nullptr)
implements_whole_instruction ? profile_index : -1)
: nullptr;

if (LayoutUtil::Minor(output_shape.layout(), row_dim) != 0) {

@ -39,11 +39,10 @@ class GemmThunk : public Thunk {
public:
// Constructs a thunk that computes "output = (lhs <dot> rhs) * alpha" using
// BLAS gemm (alpha is stored in the instruction GemmBackendConfig).
GemmThunk(const BufferAllocation::Slice& lhs_buffer,
GemmThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& lhs_buffer,
const BufferAllocation::Slice& rhs_buffer,
const BufferAllocation::Slice& output_buffer,
bool implements_whole_instruction,
const HloInstruction* hlo_instruction,
const GemmBackendConfig& backend_config);

GemmThunk(const GemmThunk&) = delete;
@ -72,7 +71,8 @@ Status RunGemm(
const HloInstruction* gemm, const GemmBackendConfig& backend_config,
se::DeviceMemoryBase lhs_buffer, se::DeviceMemoryBase rhs_buffer,
se::DeviceMemoryBase output_buffer, se::Stream* stream,
bool implements_whole_instruction, HloExecutionProfiler* profiler = nullptr,
bool implements_whole_instruction, absl::optional<int64> profile_index,
HloExecutionProfiler* profiler = nullptr,
se::blas::ProfileResult* profile_result = nullptr,
absl::optional<se::blas::AlgorithmType> algorithm = absl::nullopt);


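RunGemm now takes an absl::optional<int64> profile_index between implements_whole_instruction and the profiler, as the declarations above show; the autotuner passes -1 because its measurement run is not attributed to a profile slot. A hedged sketch of an ordinary call site under the new signature (the variable names here are placeholders, not taken from the patch):

// Sketch of a RunGemm call with the added profile_index argument; the value
// is only consulted when a profiler is supplied.
TF_RETURN_IF_ERROR(RunGemm(gemm, backend_config, lhs_data, rhs_data,
                           output_data, stream,
                           /*implements_whole_instruction=*/true,
                           /*profile_index=*/thunk_profile_index,
                           /*profiler=*/execution_profiler));
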
@ -472,7 +472,8 @@ static Status CompileModuleToLlvmIrImpl(
|
||||
const std::string& platform_name, GpuDeviceInfo gpu_device_info,
|
||||
absl::optional<CudaComputeCapability> cuda_compute_capability,
|
||||
const HloDataflowAnalysis::CanShareBuffer& can_share_buffer_function,
|
||||
int pointer_size, std::unique_ptr<llvm::Module>* llvm_module,
|
||||
int pointer_size, const HloProfileIndexMap* profile_index_map,
|
||||
std::unique_ptr<llvm::Module>* llvm_module,
|
||||
std::unique_ptr<BufferAssignment>* buffer_assignment,
|
||||
std::unique_ptr<ThunkSchedule>* thunk_schedule) {
|
||||
*llvm_module = absl::make_unique<llvm::Module>("", *llvm_context);
|
||||
@ -509,7 +510,7 @@ static Status CompileModuleToLlvmIrImpl(
|
||||
|
||||
IrEmitterContext ir_emitter_context(
|
||||
hlo_module, buffer_assignment->get(), platform_name, gpu_device_info,
|
||||
cuda_compute_capability, llvm_module->get());
|
||||
cuda_compute_capability, profile_index_map, llvm_module->get());
|
||||
|
||||
HloComputation* entry_computation = hlo_module->entry_computation();
|
||||
IrEmitterUnnested ir_emitter(hlo_module->config(), entry_computation,
|
||||
@ -532,10 +533,14 @@ static Status CompileModuleToLlvmIrImpl(
|
||||
// not all explicitly checked, but at least we can document them here:
|
||||
// * The entry HloComputation shall not have dead code (all reachable from
|
||||
// ROOT).
|
||||
// * For each visit of HloInstruction, either none or one Thunk will be
|
||||
// returned.
|
||||
// * The visited instructions are all instructions in the entry
|
||||
// computation.
|
||||
// * For each visit of these HloInstructions, either none or one Thunk
|
||||
// will be returned.
|
||||
// * If there is a thunk returned, thunk->hlo_instruction() equals the
|
||||
// input HloInstruction*.
|
||||
// * A returned thunk may contain other sub-thunks. A sub-thunk may or may
|
||||
// not have an associated hlo_instruction().
|
||||
TF_RET_CHECK(thunks->size() <= 1) << instruction->ToString();
|
||||
if (!thunks->empty()) {
|
||||
auto thunk = std::move(thunks->front());
|
||||
@ -603,6 +608,25 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
|
||||
return cuda_compute_capability;
|
||||
}();
|
||||
|
||||
std::unique_ptr<HloProfileIndexMap> profile_index_map;
|
||||
std::unique_ptr<HloProfilePrinterData> profile_printer;
|
||||
|
||||
if (module->config().hlo_profiling_enabled() || VLOG_IS_ON(1)) {
|
||||
HloCostAnalysis cost_analysis(ShapeSizeBytesFunction());
|
||||
cost_analysis.set_bytes_per_second(
|
||||
stream_exec->GetDeviceDescription().memory_bandwidth());
|
||||
TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis));
|
||||
VLOG(1) << "HLO memory read+written: "
|
||||
<< tensorflow::strings::HumanReadableNumBytes(
|
||||
cost_analysis.bytes_accessed());
|
||||
if (module->config().hlo_profiling_enabled()) {
|
||||
profile_index_map = absl::make_unique<HloProfileIndexMap>(*module);
|
||||
profile_printer =
|
||||
CreateHloProfilePrinterData(*profile_index_map, cost_analysis,
|
||||
module->entry_computation()->name());
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<llvm::Module> llvm_module;
|
||||
std::unique_ptr<BufferAssignment> buffer_assignment;
|
||||
std::unique_ptr<ThunkSchedule> thunk_schedule;
|
||||
@ -610,8 +634,8 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
|
||||
TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl(
|
||||
module.get(), &llvm_context, target_triple_, data_layout_,
|
||||
stream_exec->platform()->Name(), gpu_device_info, cuda_compute_capability,
|
||||
GetCanShareBuffer(), pointer_size_, &llvm_module, &buffer_assignment,
|
||||
&thunk_schedule));
|
||||
GetCanShareBuffer(), pointer_size_, profile_index_map.get(), &llvm_module,
|
||||
&buffer_assignment, &thunk_schedule));
|
||||
|
||||
if (user_pre_optimization_hook_) {
|
||||
user_pre_optimization_hook_(*llvm_module);
|
||||
@ -653,25 +677,6 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
|
||||
thunk_schedule->ToString());
|
||||
}
|
||||
|
||||
std::unique_ptr<HloProfileIndexMap> profile_index_map;
|
||||
std::unique_ptr<HloProfilePrinterData> profile_printer;
|
||||
|
||||
if (module->config().hlo_profiling_enabled() || VLOG_IS_ON(1)) {
|
||||
HloCostAnalysis cost_analysis(ShapeSizeBytesFunction());
|
||||
cost_analysis.set_bytes_per_second(
|
||||
stream_exec->GetDeviceDescription().memory_bandwidth());
|
||||
TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis));
|
||||
VLOG(1) << "HLO memory read+written: "
|
||||
<< tensorflow::strings::HumanReadableNumBytes(
|
||||
cost_analysis.bytes_accessed());
|
||||
if (module->config().hlo_profiling_enabled()) {
|
||||
profile_index_map = absl::make_unique<HloProfileIndexMap>(*module);
|
||||
profile_printer =
|
||||
CreateHloProfilePrinterData(*profile_index_map, cost_analysis,
|
||||
module->entry_computation()->name());
|
||||
}
|
||||
}
|
||||
|
||||
auto* gpu_executable = new GpuExecutable(
|
||||
backend_result.first, backend_result.second, gpu_version,
|
||||
std::move(thunk_schedule), std::move(module),
|
||||
@ -709,7 +714,8 @@ StatusOr<std::unique_ptr<llvm::Module>> CompileModuleToLlvmIr(
|
||||
TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl(
|
||||
hlo_module, llvm_context, target_triple, data_layout, platform_name,
|
||||
gpu_device_info, cuda_compute_capability, DummyCanShareBufferFunction,
|
||||
pointer_size, &llvm_module, &buffer_assignment, &thunk_schedule));
|
||||
pointer_size, /*profile_index_map=*/nullptr, &llvm_module,
|
||||
&buffer_assignment, &thunk_schedule));
|
||||
return llvm_module;
|
||||
}
|
||||
} // namespace gpu
|
||||
|
||||
@ -23,7 +23,6 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@ -97,26 +96,24 @@ void HloExecutionProfiler::StartHloInstruction() {
}
}

void HloExecutionProfiler::FinishHloInstruction(
const HloInstruction* hlo_instruction) {
void HloExecutionProfiler::FinishHloInstruction(size_t index) {
if (do_profile_) {
hlo_instructions_.erase(hlo_instruction);
profile_->SetCyclesTakenBy(
hlo_instruction,
GetCyclesTaken(&timers_, sub_streams_, stream_, clock_rate_ghz_));
indices_.erase(index);
profile_->SetCyclesTakenBy(index, GetCyclesTaken(&timers_, sub_streams_,
stream_, clock_rate_ghz_));
}
}

std::unique_ptr<ScopedInstructionProfiler>
HloExecutionProfiler::MakeScopedInstructionProfiler(
const HloInstruction* hlo_instruction) {
if (do_profile_ && hlo_instruction != nullptr) {
absl::optional<int64> index) {
if (do_profile_ && index.has_value()) {
// Make sure that we are not already measuring the time for the same
// 'hlo_instruction'.
CHECK(hlo_instructions_.insert(hlo_instruction).second)
<< hlo_instruction->name();
// instruction.
// TODO(timshen): provide more useful printout.
CHECK(indices_.insert(*index).second) << *index;
}
return absl::make_unique<ScopedInstructionProfiler>(this, hlo_instruction);
return absl::make_unique<ScopedInstructionProfiler>(this, index);
}

} // namespace gpu

@ -23,7 +23,6 @@ limitations under the License.

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

@ -58,14 +57,17 @@ class HloExecutionProfiler {
void StartHloInstruction();

// If profiling is enabled, stops the per-operation timer and records the time
// that the hlo_instruction took to execute in the profile.
void FinishHloInstruction(const HloInstruction* hlo_instruction);
// that at `profile_index`. Profile indices can be looked up from
// HloProfileIndexMap.
void FinishHloInstruction(size_t profile_index);

// Returns a ScopedInstructionProfiler and triggers a call to
// StartHloInstruction(). Once the returned ScopedInstructionProfiler goes
// out of scope, it triggers a call to FinishHloInstruction().
//
// If profile_index < 0, it results in a no-op.
std::unique_ptr<ScopedInstructionProfiler> MakeScopedInstructionProfiler(
const HloInstruction* hlo_instruction);
absl::optional<int64> profile_index);

private:
const bool do_profile_;
@ -77,7 +79,7 @@ class HloExecutionProfiler {
std::stack<std::unique_ptr<se::Timer>> timers_;
// Contains the HLO instructions for which we are currently measuring the
// time.
std::unordered_set<const HloInstruction*> hlo_instructions_;
std::unordered_set<size_t> indices_;
bool finished_execution_ = false;
};

@ -87,21 +89,21 @@ class HloExecutionProfiler {
class ScopedInstructionProfiler {
public:
ScopedInstructionProfiler(HloExecutionProfiler* profiler,
const HloInstruction* hlo_instruction)
: profiler_(profiler), hlo_instruction_(hlo_instruction) {
if (hlo_instruction != nullptr) {
absl::optional<int64> index)
: profiler_(profiler), index_(index) {
if (index_.has_value()) {
profiler->StartHloInstruction();
}
}
~ScopedInstructionProfiler() {
if (hlo_instruction_ != nullptr) {
profiler_->FinishHloInstruction(hlo_instruction_);
if (index_.has_value()) {
profiler_->FinishHloInstruction(*index_);
}
}

private:
HloExecutionProfiler* profiler_;
const HloInstruction* hlo_instruction_;
absl::optional<int64> index_;
};

} // namespace gpu

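With the profiler changes above, a measurement scope is opened from a profile index rather than an HloInstruction*, and an empty optional turns the scope into a no-op. A short usage sketch under those assumptions (profiler, hlo, and profile_index_map are placeholders for values available at the call site):

// Sketch: look up the slot for an instruction and time the work launched
// while the scoped profiler is alive; nullopt means "do not measure".
absl::optional<int64> index;
if (profile_index_map != nullptr) {
  index = static_cast<int64>(profile_index_map->GetProfileIndexFor(*hlo));
}
auto scoped = profiler->MakeScopedInstructionProfiler(index);
// ... enqueue kernels on the stream; timing stops when `scoped` is destroyed.
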
@ -23,9 +23,9 @@ namespace xla {
namespace gpu {

InfeedThunk::InfeedThunk(
const ShapeTree<BufferAllocation::Slice>& infeed_slices,
const HloInstruction* hlo_instruction)
: Thunk(Kind::kInfeed, hlo_instruction), infeed_slices_(infeed_slices) {}
ThunkInfo thunk_info,
const ShapeTree<BufferAllocation::Slice>& infeed_slices)
: Thunk(Kind::kInfeed, thunk_info), infeed_slices_(infeed_slices) {}

Status InfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
auto& stream = *params.stream;
@ -34,7 +34,7 @@ Status InfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
VLOG(2) << "Infeeding to GPU: " << hlo_instruction()->ToString();

auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
ShapeTree<InfeedBuffer> infeed_buffers =
GetOrCreateInfeedManager()->BlockingGetNextDestination();


@ -34,8 +34,8 @@ class InfeedThunk : public Thunk {
public:
// Constructs a InfeedThunk that copies data from the on-device
// infeed queue into the buffers in the given shape tree.
InfeedThunk(const ShapeTree<BufferAllocation::Slice>& infeed_slices,
const HloInstruction* hlo_instruction);
InfeedThunk(ThunkInfo thunk_info,
const ShapeTree<BufferAllocation::Slice>& infeed_slices);

InfeedThunk(const InfeedThunk&) = delete;
InfeedThunk& operator=(const InfeedThunk&) = delete;

@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
|
||||
#include "tensorflow/compiler/xla/service/name_uniquer.h"
|
||||
|
||||
namespace xla {
|
||||
@ -33,12 +34,13 @@ class IrEmitterContext {
|
||||
const HloModule* hlo_module, const BufferAssignment* buffer_assignment,
|
||||
std::string platform_name, GpuDeviceInfo gpu_device_info,
|
||||
absl::optional<CudaComputeCapability> cuda_compute_capability,
|
||||
llvm::Module* llvm_module)
|
||||
const HloProfileIndexMap* profile_index_map, llvm::Module* llvm_module)
|
||||
: hlo_module_(hlo_module),
|
||||
buffer_assignment_(buffer_assignment),
|
||||
platform_name_(std::move(platform_name)),
|
||||
gpu_device_info_(gpu_device_info),
|
||||
cuda_compute_capability_(cuda_compute_capability),
|
||||
profile_index_map_(profile_index_map),
|
||||
llvm_module_(llvm_module) {}
|
||||
// Disallow copy and assign.
|
||||
IrEmitterContext(const IrEmitterContext&) = delete;
|
||||
@ -54,6 +56,7 @@ class IrEmitterContext {
|
||||
absl::optional<CudaComputeCapability> cuda_compute_capability() const {
|
||||
return cuda_compute_capability_;
|
||||
}
|
||||
const HloProfileIndexMap* profile_index_map() { return profile_index_map_; }
|
||||
llvm::Module* llvm_module() { return llvm_module_; }
|
||||
NameUniquer* name_uniquer() { return &name_uniquer_; }
|
||||
|
||||
@ -63,6 +66,7 @@ class IrEmitterContext {
|
||||
std::string platform_name_;
|
||||
GpuDeviceInfo gpu_device_info_;
|
||||
absl::optional<CudaComputeCapability> cuda_compute_capability_;
|
||||
const HloProfileIndexMap* profile_index_map_;
|
||||
llvm::Module* llvm_module_;
|
||||
NameUniquer name_uniquer_;
|
||||
};
|
||||
|
||||
@ -652,8 +652,8 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
|
||||
/*updates_gen=*/
|
||||
scatter_fused_emitter.GetGenerator(root->operand(2))));
|
||||
}
|
||||
AddThunkToThunkSequence(
|
||||
absl::make_unique<SequentialThunk>(std::move(thunks), fusion));
|
||||
AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
|
||||
GetThunkInfo(fusion), std::move(thunks)));
|
||||
return Status::OK();
|
||||
}
|
||||
// In the case of root tuple, it can be either reduce or slice input
|
||||
@ -739,10 +739,11 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
|
||||
auto destination_buffer = GetAllocationSlice(*copy);
|
||||
if (operand_buffer != destination_buffer) {
|
||||
AddThunkToThunkSequence(absl::make_unique<DeviceToDeviceCopyThunk>(
|
||||
GetThunkInfo(copy),
|
||||
/*source_address=*/operand_buffer,
|
||||
/*destination_buffer=*/destination_buffer,
|
||||
/*mem_size=*/
|
||||
ByteSizeOf(copy->operand(0)->shape()), copy));
|
||||
ByteSizeOf(copy->operand(0)->shape())));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -816,7 +817,8 @@ Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) {
|
||||
tuple_element_buffers.push_back(GetAllocationSlice(*tuple_element));
|
||||
}
|
||||
AddThunkToThunkSequence(absl::make_unique<TupleThunk>(
|
||||
tuple_element_buffers, GetAllocationSlice(*tuple), tuple));
|
||||
GetThunkInfo(tuple), tuple_element_buffers,
|
||||
GetAllocationSlice(*tuple)));
|
||||
return Status::OK();
|
||||
}
|
||||
AddThunkToThunkSequence(
|
||||
@ -848,7 +850,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
|
||||
thunks.push_back(BuildKernelThunk(select_and_scatter,
|
||||
/*implements_whole_instruction=*/false));
|
||||
std::unique_ptr<SequentialThunk> select_and_scatter_thunk =
|
||||
absl::make_unique<SequentialThunk>(std::move(thunks), select_and_scatter);
|
||||
absl::make_unique<SequentialThunk>(GetThunkInfo(select_and_scatter),
|
||||
std::move(thunks));
|
||||
|
||||
// TODO(b/31410564): Implement dilation rate for select-and-scatter.
|
||||
if (window_util::HasDilation(window)) {
|
||||
@ -1082,10 +1085,10 @@ Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) {
|
||||
auto destination_buffer = GetAllocationSlice(*scatter);
|
||||
if (operand_buffer != destination_buffer) {
|
||||
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
|
||||
Thunk::ThunkInfo(),
|
||||
/*source_address=*/operand_buffer,
|
||||
/*destination_buffer=*/destination_buffer,
|
||||
/*mem_size=*/ShapeUtil::ByteSizeOf(operand->shape()),
|
||||
/*hlo_instruction=*/nullptr));
|
||||
/*mem_size=*/ShapeUtil::ByteSizeOf(operand->shape())));
|
||||
}
|
||||
|
||||
thunks.push_back(
|
||||
@ -1109,8 +1112,8 @@ Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) {
|
||||
if (thunks.size() == 1) {
|
||||
AddThunkToThunkSequence(std::move(thunks[0]));
|
||||
} else {
|
||||
AddThunkToThunkSequence(
|
||||
absl::make_unique<SequentialThunk>(std::move(thunks), scatter));
|
||||
AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
|
||||
GetThunkInfo(scatter), std::move(thunks)));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
@ -1282,10 +1285,10 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
|
||||
// TODO(b/26783907): Figure out why we never seem to share buffers for
|
||||
// key/value sort.
|
||||
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
|
||||
Thunk::ThunkInfo(),
|
||||
/*source_address=*/source_address,
|
||||
/*destination_buffer=*/destination_buffer,
|
||||
/*mem_size=*/ShapeUtil::ByteSizeOf(sort->operand(i)->shape()),
|
||||
nullptr));
|
||||
/*mem_size=*/ShapeUtil::ByteSizeOf(sort->operand(i)->shape())));
|
||||
}
|
||||
}
|
||||
|
||||
@ -1419,8 +1422,8 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
|
||||
TF_RETURN_IF_ERROR(emit_kernel(xor_masks));
|
||||
}
|
||||
|
||||
AddThunkToThunkSequence(
|
||||
absl::make_unique<SequentialThunk>(std::move(thunks), sort));
|
||||
AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
|
||||
GetThunkInfo(sort), std::move(thunks)));
|
||||
if (sort->operand_count() > 1) {
|
||||
// Emit the tuple as part of the last stage of sorting.
|
||||
// We are currently in the block sorted.in_bounds.after.
|
||||
@ -1438,14 +1441,15 @@ Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) {
|
||||
}
|
||||
|
||||
Status IrEmitterUnnested::HandleReplicaId(HloInstruction* hlo) {
|
||||
AddThunkToThunkSequence(
|
||||
absl::make_unique<ReplicaIdThunk>(GetAllocationSlice(*hlo), hlo));
|
||||
AddThunkToThunkSequence(absl::make_unique<ReplicaIdThunk>(
|
||||
GetThunkInfo(hlo), GetAllocationSlice(*hlo)));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status IrEmitterUnnested::HandleCollectivePermute(HloInstruction* hlo) {
|
||||
AddThunkToThunkSequence(absl::make_unique<CollectivePermuteThunk>(
|
||||
GetAllocationSlice(*hlo->operand(0)), GetAllocationSlice(*hlo), hlo));
|
||||
GetThunkInfo(hlo), GetAllocationSlice(*hlo->operand(0)),
|
||||
GetAllocationSlice(*hlo)));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -1478,15 +1482,16 @@ Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) {
|
||||
tuple_element_buffers.push_back(buffers[i].destination_buffer);
|
||||
}
|
||||
auto all_reduce_thunk = absl::make_unique<NcclAllReduceThunk>(
|
||||
GetThunkInfo(crs),
|
||||
/*replica_count=*/hlo_module_config_.replica_count(),
|
||||
/*buffers=*/std::move(buffers), crs);
|
||||
/*buffers=*/std::move(buffers));
|
||||
if (crs->shape().IsTuple()) {
|
||||
std::vector<std::unique_ptr<Thunk>> thunks;
|
||||
thunks.push_back(std::move(all_reduce_thunk));
|
||||
thunks.push_back(absl::make_unique<TupleThunk>(
|
||||
tuple_element_buffers, GetAllocationSlice(*crs), nullptr));
|
||||
AddThunkToThunkSequence(
|
||||
absl::make_unique<SequentialThunk>(std::move(thunks), crs));
|
||||
Thunk::ThunkInfo(), tuple_element_buffers, GetAllocationSlice(*crs)));
|
||||
AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
|
||||
GetThunkInfo(crs), std::move(thunks)));
|
||||
} else {
|
||||
AddThunkToThunkSequence(std::move(all_reduce_thunk));
|
||||
}
|
||||
@ -1520,9 +1525,10 @@ Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) {
|
||||
CHECK(crs->operand(0)->shape().IsArray())
|
||||
<< "Operands to all-reduce must be arrays: " << crs->ToString();
|
||||
AddThunkToThunkSequence(absl::make_unique<DeviceToDeviceCopyThunk>(
|
||||
GetThunkInfo(crs),
|
||||
/*source_address=*/GetAllocationSlice(*crs->operand(0)),
|
||||
/*destination_buffer=*/GetAllocationSlice(*crs),
|
||||
/*mem_size=*/ShapeUtil::ByteSizeOf(crs->shape()), crs));
|
||||
/*mem_size=*/ShapeUtil::ByteSizeOf(crs->shape())));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -1535,16 +1541,17 @@ Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) {
|
||||
.GetUniqueSlice(crs, {i})
|
||||
.ValueOrDie());
|
||||
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
|
||||
Thunk::ThunkInfo(),
|
||||
/*source_address=*/GetAllocationSlice(*crs->operand(i)),
|
||||
/*destination_buffer=*/tuple_element_buffers.back(),
|
||||
/*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), nullptr));
|
||||
/*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape())));
|
||||
}
|
||||
|
||||
// Output a tuple of the buffers above.
|
||||
thunks.push_back(absl::make_unique<TupleThunk>(
|
||||
tuple_element_buffers, GetAllocationSlice(*crs), nullptr));
|
||||
Thunk::ThunkInfo(), tuple_element_buffers, GetAllocationSlice(*crs)));
|
||||
AddThunkToThunkSequence(
|
||||
absl::make_unique<SequentialThunk>(std::move(thunks), crs));
|
||||
absl::make_unique<SequentialThunk>(GetThunkInfo(crs), std::move(thunks)));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -1787,8 +1794,8 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
|
||||
}
|
||||
|
||||
return absl::make_unique<KernelThunk>(
|
||||
non_constant_buffers, std::string(kernel->getName()),
|
||||
implements_whole_instruction ? inst : nullptr);
|
||||
implements_whole_instruction ? GetThunkInfo(inst) : Thunk::ThunkInfo(),
|
||||
non_constant_buffers, std::string(kernel->getName()));
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
|
||||
@ -1838,8 +1845,8 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
|
||||
absl::Span<const uint8> literal_bytes(
|
||||
reinterpret_cast<const uint8*>(literal.untyped_data()), num_bytes);
|
||||
if (absl::c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) {
|
||||
return {absl::make_unique<MemzeroThunk>(GetAllocationSlice(*hlo, index),
|
||||
nullptr)};
|
||||
return {absl::make_unique<MemzeroThunk>(Thunk::ThunkInfo(),
|
||||
GetAllocationSlice(*hlo, index))};
|
||||
}
|
||||
|
||||
// If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by
|
||||
@ -1857,7 +1864,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
|
||||
}
|
||||
uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16);
|
||||
return {absl::make_unique<Memset32BitValueThunk>(
|
||||
pattern32, GetAllocationSlice(*hlo, index), nullptr)};
|
||||
Thunk::ThunkInfo(), pattern32, GetAllocationSlice(*hlo, index))};
|
||||
}
|
||||
|
||||
// If the literal is an even multiple of 32 bits wide, we can emit a 32-bit
|
||||
@ -1868,7 +1875,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
|
||||
uint32 word;
|
||||
memcpy(&word, literal_bytes.data(), sizeof(word));
|
||||
return {absl::make_unique<Memset32BitValueThunk>(
|
||||
word, GetAllocationSlice(*hlo, index), nullptr)};
|
||||
Thunk::ThunkInfo(), word, GetAllocationSlice(*hlo, index))};
|
||||
}
|
||||
}
|
||||
|
||||
@ -2014,9 +2021,10 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildWhileThunk(
|
||||
TF_CHECK_OK(body->Accept(&ir_emitter_body));
|
||||
|
||||
return absl::make_unique<WhileThunk>(
|
||||
GetThunkInfo(hlo),
|
||||
GetAllocationSlice(*condition->root_instruction()), // cond result
|
||||
ir_emitter_condition.ConsumeThunkSequence(),
|
||||
ir_emitter_body.ConsumeThunkSequence(), hlo);
|
||||
ir_emitter_body.ConsumeThunkSequence());
|
||||
}
|
||||
|
||||
std::unique_ptr<Thunk> IrEmitterUnnested::BuildForThunk(
|
||||
@ -2031,8 +2039,8 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildForThunk(
|
||||
ir_emitter_context_);
|
||||
TF_CHECK_OK(body->Accept(&ir_emitter_body));
|
||||
|
||||
return absl::make_unique<ForThunk>(
|
||||
loop_limit, ir_emitter_body.ConsumeThunkSequence(), hlo);
|
||||
return absl::make_unique<ForThunk>(GetThunkInfo(hlo), loop_limit,
|
||||
ir_emitter_body.ConsumeThunkSequence());
|
||||
}
|
||||
|
||||
std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk(
|
||||
@ -2054,8 +2062,8 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk(
|
||||
}
|
||||
|
||||
return absl::make_unique<ConditionalThunk>(
|
||||
GetAllocationSlice(*hlo->operand(0)), branch_operands,
|
||||
std::move(branch_thunks), hlo);
|
||||
GetThunkInfo(hlo), GetAllocationSlice(*hlo->operand(0)), branch_operands,
|
||||
std::move(branch_thunks));
|
||||
}
|
||||
|
||||
Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
|
||||
@ -3589,8 +3597,8 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
|
||||
ir_emitter_context_->llvm_module());
|
||||
|
||||
thunks.push_back(std::move(kernel_thunk));
|
||||
auto sequential_thunk =
|
||||
absl::make_unique<SequentialThunk>(std::move(thunks), unnested_hlo);
|
||||
auto sequential_thunk = absl::make_unique<SequentialThunk>(
|
||||
GetThunkInfo(unnested_hlo), std::move(thunks));
|
||||
AddThunkToThunkSequence(std::move(sequential_thunk));
|
||||
|
||||
return Status::OK();
|
||||
@ -3757,5 +3765,15 @@ Status IrEmitterUnnested::EmitInputFusibleNonStridedSlices(
|
||||
return emit_status;
|
||||
}
|
||||
|
||||
Thunk::ThunkInfo IrEmitterUnnested::GetThunkInfo(
|
||||
const HloInstruction* hlo) const {
|
||||
auto info = ThunkEmitter::EmissionContext::GetThunkInfo(hlo);
|
||||
if (const auto* index_map = ir_emitter_context_->profile_index_map()) {
|
||||
info.profile_index.emplace(
|
||||
static_cast<int64>(index_map->GetProfileIndexFor(*hlo)));
|
||||
}
|
||||
return info;
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
||||
@ -548,6 +548,8 @@ class IrEmitterUnnested : public IrEmitter,
|
||||
// Returns the last generated thunk.
|
||||
Thunk* LastThunk() const { return thunk_sequence_.back().get(); }
|
||||
|
||||
Thunk::ThunkInfo GetThunkInfo(const HloInstruction* hlo) const override;
|
||||
|
||||
// The thunk sequence this IrEmitter generates for the input computation.
|
||||
ThunkSequence thunk_sequence_;
|
||||
|
||||
|
||||
@ -33,10 +33,10 @@ limitations under the License.
|
||||
namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
KernelThunk::KernelThunk(absl::Span<const BufferAllocation* const> args,
|
||||
const string& kernel_name,
|
||||
const HloInstruction* hlo_instruction)
|
||||
: Thunk(Kind::kKernel, hlo_instruction),
|
||||
KernelThunk::KernelThunk(ThunkInfo thunk_info,
|
||||
absl::Span<const BufferAllocation* const> args,
|
||||
const string& kernel_name)
|
||||
: Thunk(Kind::kKernel, thunk_info),
|
||||
args_(args.begin(), args.end()),
|
||||
kernel_name_(kernel_name) {}
|
||||
|
||||
@ -114,7 +114,7 @@ Status KernelThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
}
|
||||
|
||||
auto op_profiler =
|
||||
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
|
||||
params.profiler->MakeScopedInstructionProfiler(profile_index());
|
||||
return ExecuteKernelOnStream(*kernel, buffer_args,
|
||||
launch_dimensions.threads_per_block(),
|
||||
launch_dimensions.block_count(), params.stream);
|
||||
|
||||
@ -47,8 +47,9 @@ class KernelThunk : public Thunk {
|
||||
// Constructs a thunk for the given kernel.
|
||||
//
|
||||
// `hlo_instruction` is as in Thunk. Other arguments are as the class members.
|
||||
KernelThunk(absl::Span<const BufferAllocation* const> args,
|
||||
const string& kernel_name, const HloInstruction* hlo_instruction);
|
||||
KernelThunk(ThunkInfo thunk_info,
|
||||
absl::Span<const BufferAllocation* const> args,
|
||||
const string& kernel_name);
|
||||
KernelThunk(const KernelThunk&) = delete;
|
||||
KernelThunk& operator=(const KernelThunk&) = delete;
|
||||
~KernelThunk() override = default;
|
||||
|
||||
@ -25,7 +25,7 @@ Status MemzeroThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
se::DeviceMemoryBase dest_data =
|
||||
params.buffer_allocations->GetDeviceAddress(dest_);
|
||||
auto op_profiler =
|
||||
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
|
||||
params.profiler->MakeScopedInstructionProfiler(profile_index());
|
||||
params.stream->ThenMemZero(&dest_data, dest_data.size());
|
||||
return Status::OK();
|
||||
}
|
||||
@ -34,7 +34,7 @@ Status Memset32BitValueThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
se::DeviceMemoryBase dest_data =
|
||||
params.buffer_allocations->GetDeviceAddress(dest_);
|
||||
auto op_profiler =
|
||||
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
|
||||
params.profiler->MakeScopedInstructionProfiler(profile_index());
|
||||
params.stream->ThenMemset32(&dest_data, value_, dest_data.size());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -32,9 +32,9 @@ namespace gpu {
|
||||
// Thunk that zeroes out a given chunk of memory.
|
||||
class MemzeroThunk : public Thunk {
|
||||
public:
|
||||
explicit MemzeroThunk(const BufferAllocation::Slice& dest,
|
||||
const HloInstruction* hlo)
|
||||
: Thunk(Kind::kMemzero, hlo), dest_(dest) {}
|
||||
explicit MemzeroThunk(ThunkInfo thunk_info,
|
||||
const BufferAllocation::Slice& dest)
|
||||
: Thunk(Kind::kMemzero, thunk_info), dest_(dest) {}
|
||||
|
||||
Status ExecuteOnStream(const ExecuteParams& params) override;
|
||||
|
||||
@ -46,10 +46,11 @@ class MemzeroThunk : public Thunk {
|
||||
// destination chunk must have size divisible by 32 bits.
|
||||
class Memset32BitValueThunk : public Thunk {
|
||||
public:
|
||||
explicit Memset32BitValueThunk(uint32 value,
|
||||
const BufferAllocation::Slice& dest,
|
||||
const HloInstruction* hlo)
|
||||
: Thunk(Kind::kMemset32BitValue, hlo), value_(value), dest_(dest) {}
|
||||
explicit Memset32BitValueThunk(ThunkInfo thunk_info, uint32 value,
|
||||
const BufferAllocation::Slice& dest)
|
||||
: Thunk(Kind::kMemset32BitValue, thunk_info),
|
||||
value_(value),
|
||||
dest_(dest) {}
|
||||
|
||||
Status ExecuteOnStream(const ExecuteParams& params) override;
|
||||
|
||||
|
||||
@ -541,9 +541,9 @@ NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
|
||||
}
|
||||
|
||||
NcclAllReduceThunk::NcclAllReduceThunk(
|
||||
int64 replica_count, std::vector<NcclAllReduceThunk::Buffer> buffers,
|
||||
const HloInstruction* all_reduce)
|
||||
: Thunk(Thunk::kNcclAllReduce, all_reduce),
|
||||
ThunkInfo thunk_info, int64 replica_count,
|
||||
std::vector<NcclAllReduceThunk::Buffer> buffers)
|
||||
: Thunk(Thunk::kNcclAllReduce, thunk_info),
|
||||
replica_count_(replica_count),
|
||||
buffers_(std::move(buffers)),
|
||||
aux_data_(absl::make_unique<AuxData>()) {
|
||||
@ -555,7 +555,7 @@ NcclAllReduceThunk::NcclAllReduceThunk(
|
||||
Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
VLOG(1) << "Starting NcclAllReduceThunk.";
|
||||
auto op_profiler =
|
||||
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
|
||||
params.profiler->MakeScopedInstructionProfiler(profile_index());
|
||||
|
||||
auto* instr = Cast<HloAllReduceInstruction>(hlo_instruction());
|
||||
int64 local_device_ordinal = params.stream->parent()->device_ordinal();
|
||||
|
||||
@ -56,8 +56,8 @@ class NcclAllReduceThunk : public Thunk {
|
||||
BufferAllocation::Slice source_buffer;
|
||||
BufferAllocation::Slice destination_buffer;
|
||||
};
|
||||
NcclAllReduceThunk(int64 replica_count, std::vector<Buffer> buffers,
|
||||
const HloInstruction* all_reduce);
|
||||
NcclAllReduceThunk(ThunkInfo thunk_info, int64 replica_count,
|
||||
std::vector<Buffer> buffers);
|
||||
~NcclAllReduceThunk() override;
|
||||
|
||||
Status ExecuteOnStream(const ExecuteParams& params) override;
|
||||
|
||||
@ -23,9 +23,9 @@ limitations under the License.
|
||||
namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
OutfeedThunk::OutfeedThunk(ShapeTree<BufferAllocation::Slice> outfeed_slices,
|
||||
const HloInstruction* hlo_instruction)
|
||||
: Thunk(Kind::kOutfeed, hlo_instruction),
|
||||
OutfeedThunk::OutfeedThunk(ThunkInfo thunk_info,
|
||||
ShapeTree<BufferAllocation::Slice> outfeed_slices)
|
||||
: Thunk(Kind::kOutfeed, thunk_info),
|
||||
outfeed_slices_(std::move(outfeed_slices)) {}
|
||||
|
||||
Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
@ -35,7 +35,7 @@ Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
VLOG(2) << "Outfeeding from GPU: " << hlo_instruction()->ToString();
|
||||
|
||||
auto op_profiler =
|
||||
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
|
||||
params.profiler->MakeScopedInstructionProfiler(profile_index());
|
||||
OutfeedManager* outfeed_manager = GetOrCreateOutfeedManager();
|
||||
ShapeTree<std::unique_ptr<OutfeedBuffer>>* outfeed_buffers =
|
||||
outfeed_manager->BlockingGetNextDestination();
|
||||
|
||||
@ -32,8 +32,8 @@ class OutfeedThunk : public Thunk {
|
||||
public:
|
||||
// Constructs a OutfeedThunk that copies data to the host-side
|
||||
// outfeed queue from the buffers in the given shape tree.
|
||||
OutfeedThunk(ShapeTree<BufferAllocation::Slice> outfeed_slices,
|
||||
const HloInstruction* hlo_instruction);
|
||||
OutfeedThunk(ThunkInfo thunk_info,
|
||||
ShapeTree<BufferAllocation::Slice> outfeed_slices);
|
||||
|
||||
OutfeedThunk(const OutfeedThunk&) = delete;
|
||||
OutfeedThunk& operator=(const OutfeedThunk&) = delete;
|
||||
|
||||
@ -18,13 +18,13 @@ limitations under the License.
|
||||
namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
ReplicaIdThunk::ReplicaIdThunk(const BufferAllocation::Slice& dest,
|
||||
const HloInstruction* instr)
|
||||
: Thunk(Kind::kReplicaId, instr), dest_(dest) {}
|
||||
ReplicaIdThunk::ReplicaIdThunk(ThunkInfo thunk_info,
|
||||
const BufferAllocation::Slice& dest)
|
||||
: Thunk(Kind::kReplicaId, thunk_info), dest_(dest) {}
|
||||
|
||||
Status ReplicaIdThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
auto op_profiler =
|
||||
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
|
||||
params.profiler->MakeScopedInstructionProfiler(profile_index());
|
||||
|
||||
auto dest_addr = params.buffer_allocations->GetDeviceAddress(dest_);
|
||||
TF_ASSIGN_OR_RETURN(int replica_id,
|
||||
|
||||
@ -26,8 +26,7 @@ namespace gpu {
|
||||
// Thunk that implements the ReplicaId HLO.
|
||||
class ReplicaIdThunk : public Thunk {
|
||||
public:
|
||||
ReplicaIdThunk(const BufferAllocation::Slice& dest,
|
||||
const HloInstruction* instr);
|
||||
ReplicaIdThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& dest);
|
||||
|
||||
Status ExecuteOnStream(const ExecuteParams& params) override;
|
||||
|
||||
|
||||
@ -24,9 +24,9 @@ namespace gpu {

using ::tensorflow::profiler::ScopedAnnotation;

SequentialThunk::SequentialThunk(std::vector<std::unique_ptr<Thunk>> thunks,
const HloInstruction* hlo)
: Thunk(Kind::kSequential, hlo), thunks_(std::move(thunks)) {}
SequentialThunk::SequentialThunk(ThunkInfo thunk_info,
std::vector<std::unique_ptr<Thunk>> thunks)
: Thunk(Kind::kSequential, thunk_info), thunks_(std::move(thunks)) {}

void SequentialThunk::ComputeAnnotations() {
for (const auto& thunk : thunks_) {
@ -44,7 +44,7 @@ Status SequentialThunk::Initialize(const GpuExecutable& executable,

Status SequentialThunk::ExecuteOnStream(const ExecuteParams& params) {
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
for (const auto& thunk : thunks_) {
ScopedAnnotation annotation([&] { return thunk->profile_annotation(); });
TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(params));

@ -32,8 +32,8 @@ namespace gpu {
// require multiple kernel launches or library calls.
class SequentialThunk : public Thunk {
public:
SequentialThunk(std::vector<std::unique_ptr<Thunk>> thunks,
const HloInstruction* hlo);
SequentialThunk(ThunkInfo thunk_info,
std::vector<std::unique_ptr<Thunk>> thunks);
SequentialThunk(const SequentialThunk&) = delete;
SequentialThunk& operator=(const SequentialThunk&) = delete;


@ -68,13 +68,21 @@ class Thunk {
kWhile,
};

struct ThunkInfo {
const HloInstruction* hlo_instruction = nullptr;
absl::optional<int64> profile_index;
// TODO(timshen): Remove hlo_instruction and add name(),
// profile_annotation() here.
};

// The hlo_instruction argument is meant to be the instruction this thunk was
// generated from, but Thunk never uses this argument other than to save it
// to Thunk::hlo_instruction, so it can be null.
explicit Thunk(Kind kind, const HloInstruction* hlo_instruction)
explicit Thunk(Kind kind, ThunkInfo thunk_info)
: kind_(kind),
hlo_instruction_(hlo_instruction),
name_(hlo_instruction_ ? hlo_instruction_->name() : "") {}
hlo_instruction_(thunk_info.hlo_instruction),
name_(hlo_instruction_ ? hlo_instruction_->name() : ""),
profile_index_(thunk_info.profile_index) {}
virtual ~Thunk() {}
Thunk(const Thunk&) = delete;
Thunk& operator=(const Thunk&) = delete;
@ -128,6 +136,8 @@ class Thunk {
protected:
const HloInstruction* hlo_instruction() const { return hlo_instruction_; }

absl::optional<int64> profile_index() const { return profile_index_; }

const HloModuleConfig& GetModuleConfig() const {
return hlo_instruction()->GetModule()->config();
}
@ -146,8 +156,12 @@ class Thunk {

private:
Kind kind_;

// Will be removed in the future, as Thunk is migrating away from the
// monolithic HloInstruction.
const HloInstruction* hlo_instruction_;
std::string name_;
absl::optional<int64> profile_index_;
string profile_annotation_;
};


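The ThunkInfo struct above is the hand-off point between the emitters and the thunks: emitters fill in the instruction and, when HLO profiling is enabled, the profile index taken from the module's HloProfileIndexMap (see IrEmitterUnnested::GetThunkInfo elsewhere in this patch). A sketch of that wiring, with index_map and hlo standing in for the emitter's members:

// Sketch: populate a ThunkInfo the way GetThunkInfo does before handing it
// to a thunk constructor.
Thunk::ThunkInfo info;
info.hlo_instruction = hlo;
if (index_map != nullptr) {
  info.profile_index =
      static_cast<int64>(index_map->GetProfileIndexFor(*hlo));
}
auto thunk = absl::make_unique<SequentialThunk>(info, std::move(thunks));
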
@ -40,11 +40,11 @@ namespace gpu {
|
||||
std::unique_ptr<Thunk> ThunkEmitter::BuildFftThunk(const HloInstruction* inst) {
|
||||
const HloInstruction* operand = inst->operand(0);
|
||||
return absl::make_unique<FftThunk>(
|
||||
inst->fft_type(), inst->fft_length(),
|
||||
context_->GetThunkInfo(inst), inst->fft_type(), inst->fft_length(),
|
||||
/*input_buffer=*/GetAllocationSlice(*operand),
|
||||
/*output_buffer=*/GetAllocationSlice(*inst),
|
||||
/*input_shape=*/operand->shape(),
|
||||
/*output_shape=*/inst->shape(), inst);
|
||||
/*output_shape=*/inst->shape());
|
||||
}
|
||||
|
||||
std::unique_ptr<Thunk> ThunkEmitter::BuildTriangularSolveThunk(
|
||||
@ -63,11 +63,11 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildTriangularSolveThunk(
|
||||
: n * n * elem_size;
|
||||
int64 b_batch_stride = m * n * elem_size;
|
||||
return absl::make_unique<TriangularSolveThunk>(
|
||||
inst->triangular_solve_options(),
|
||||
context_->GetThunkInfo(inst), inst->triangular_solve_options(),
|
||||
/*a_input_buffer=*/GetAllocationSlice(*a),
|
||||
/*b_input_buffer=*/GetAllocationSlice(*inst),
|
||||
inst->shape().element_type(), batch_size, m, n, a_batch_stride,
|
||||
b_batch_stride, inst);
|
||||
b_batch_stride);
|
||||
}
|
||||
|
||||
std::unique_ptr<Thunk> ThunkEmitter::BuildGemmThunk(
|
||||
@ -86,24 +86,27 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildGemmThunk(
|
||||
if (GetAllocationSlice(*bias) != GetAllocationSlice(*inst)) {
|
||||
std::vector<std::unique_ptr<Thunk>> thunks;
|
||||
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
|
||||
Thunk::ThunkInfo(),
|
||||
/*source_buffer=*/GetAllocationSlice(*bias),
|
||||
/*destination_buffer=*/GetAllocationSlice(*inst),
|
||||
/*mem_size=*/ShapeUtil::ByteSizeOf(inst->shape()), nullptr));
|
||||
/*mem_size=*/ShapeUtil::ByteSizeOf(inst->shape())));
|
||||
thunks.push_back(absl::make_unique<GemmThunk>(
|
||||
context_->GetThunkInfo(inst),
|
||||
GetAllocationSlice(*lhs), // The buffer assigned to LHS.
|
||||
GetAllocationSlice(*rhs), // The buffer assigned to RHS.
|
||||
GetAllocationSlice(*inst), // The output buffer.
|
||||
/*implements_whole_instruction=*/false, inst,
|
||||
std::move(gemm_config)));
|
||||
return absl::make_unique<SequentialThunk>(std::move(thunks), inst);
|
||||
/*implements_whole_instruction=*/false, std::move(gemm_config)));
|
||||
return absl::make_unique<SequentialThunk>(context_->GetThunkInfo(inst),
|
||||
std::move(thunks));
|
||||
}
|
||||
}
|
||||
|
||||
return absl::make_unique<GemmThunk>(
|
||||
context_->GetThunkInfo(inst),
|
||||
GetAllocationSlice(*lhs), // The buffer assigned to LHS.
|
||||
GetAllocationSlice(*rhs), // The buffer assigned to RHS.
|
||||
GetAllocationSlice(*inst), // The output buffer.
|
||||
/*implements_whole_instruction=*/true, inst, std::move(gemm_config));
|
||||
/*implements_whole_instruction=*/true, std::move(gemm_config));
|
||||
}
|
||||
|
||||
std::unique_ptr<Thunk> ThunkEmitter::BuildInfeedThunk(
|
||||
@ -115,7 +118,7 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildInfeedThunk(
|
||||
[&](const ShapeIndex& index, BufferAllocation::Slice* slice) {
|
||||
*slice = GetAllocationSlice(*inst, index);
|
||||
});
|
||||
return absl::make_unique<InfeedThunk>(slices, inst);
|
||||
return absl::make_unique<InfeedThunk>(context_->GetThunkInfo(inst), slices);
|
||||
}
|
||||
|
||||
std::unique_ptr<Thunk> ThunkEmitter::BuildOutfeedThunk(
|
||||
@ -130,7 +133,8 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildOutfeedThunk(
|
||||
*slice = status_or_slice.ValueOrDie();
|
||||
}
|
||||
});
|
||||
return absl::make_unique<OutfeedThunk>(std::move(slices), inst);
|
||||
return absl::make_unique<OutfeedThunk>(context_->GetThunkInfo(inst),
|
||||
std::move(slices));
|
||||
}
|
||||
|
||||
Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
@ -152,6 +156,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {

    AddThunkToThunkSequence(
        absl::make_unique<CudnnBatchNormForwardInferenceThunk>(
            context_->GetThunkInfo(custom_call),
            /*operand=*/GetAllocationSlice(*custom_call->operand(0)),
            /*scale=*/GetAllocationSlice(*custom_call->operand(1)),
            /*offset=*/GetAllocationSlice(*custom_call->operand(2)),
@ -159,8 +164,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
            /*variance=*/GetAllocationSlice(*custom_call->operand(4)),
            /*epsilon=*/epsilon_value,
            /*feature_index=*/feature_index_value,
            /*output=*/GetAllocationSlice(*custom_call),
            /*hlo=*/custom_call));
            /*output=*/GetAllocationSlice(*custom_call)));
    return Status::OK();
  }

@ -181,6 +185,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
    auto output_inv_stddev = GetAllocationSlice(*custom_call, {2});
    AddThunkToThunkSequence(
        absl::make_unique<CudnnBatchNormForwardTrainingThunk>(
            context_->GetThunkInfo(custom_call),
            /*operand=*/GetAllocationSlice(*custom_call->operand(0)),
            /*scale=*/GetAllocationSlice(*custom_call->operand(1)),
            /*offset=*/GetAllocationSlice(*custom_call->operand(2)),
@ -189,8 +194,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
            /*output_data=*/output_data,
            /*output_mean=*/output_mean,
            /*output_inv_stddev=*/output_inv_stddev,
            /*output_tuple=*/GetAllocationSlice(*custom_call),
            /*hlo=*/custom_call));
            /*output_tuple=*/GetAllocationSlice(*custom_call)));
    return Status::OK();
  }

@ -209,6 +213,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
    auto output_grad_scale = GetAllocationSlice(*custom_call, {1});
    auto output_grad_offset = GetAllocationSlice(*custom_call, {2});
    AddThunkToThunkSequence(absl::make_unique<CudnnBatchNormBackwardThunk>(
        context_->GetThunkInfo(custom_call),
        /*operand=*/GetAllocationSlice(*custom_call->operand(0)),
        /*scale=*/GetAllocationSlice(*custom_call->operand(1)),
        /*mean=*/GetAllocationSlice(*custom_call->operand(2)),
@ -219,8 +224,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
        /*output_grad_data=*/output_grad_data,
        /*output_grad_scale=*/output_grad_scale,
        /*output_grad_offset=*/output_grad_offset,
        /*output_tuple=*/GetAllocationSlice(*custom_call),
        /*hlo=*/custom_call));
        /*output_tuple=*/GetAllocationSlice(*custom_call)));
    return Status::OK();
  }

@ -235,7 +239,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
    auto scratch_slice = GetAllocationSlice(*custom_call, {1});

    AddThunkToThunkSequence(absl::make_unique<ConvolutionThunk>(
        Cast<HloCustomCallInstruction>(custom_call), std::move(operand_slices),
        context_->GetThunkInfo(custom_call), std::move(operand_slices),
        conv_result_slice, scratch_slice, tuple_result_slice));
    return Status::OK();
  }
@ -269,22 +273,23 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {

    if (operand_buffer != a_buffer) {
      thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
          context_->GetThunkInfo(custom_call),
          /*source_address=*/operand_buffer,
          /*destination_buffer=*/a_buffer,
          /*mem_size=*/ShapeUtil::ByteSizeOf(shape), custom_call));
          /*mem_size=*/ShapeUtil::ByteSizeOf(shape)));
    }

    thunks.push_back(absl::make_unique<CholeskyThunk>(
        options, a_buffer, workspace_buffer, info_buffer,
        custom_call->operand(0)->shape().element_type(), batch_size, n,
        custom_call));
        context_->GetThunkInfo(custom_call), options, a_buffer,
        workspace_buffer, info_buffer,
        custom_call->operand(0)->shape().element_type(), batch_size, n));

    // Elide the sequential thunk if there's no copy.
    if (thunks.size() == 1) {
      AddThunkToThunkSequence(std::move(thunks[0]));
    } else {
      AddThunkToThunkSequence(
          absl::make_unique<SequentialThunk>(std::move(thunks), custom_call));
      AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
          context_->GetThunkInfo(custom_call), std::move(thunks)));
    }

    return Status::OK();
@ -311,8 +316,9 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
    ShapeTree<BufferAllocation::Slice> result_slices =
        get_slices_for_instr(custom_call);
    AddThunkToThunkSequence(absl::make_unique<CustomCallThunk>(
        call_target, std::move(operand_slices), std::move(result_slices),
        Cast<HloCustomCallInstruction>(custom_call)->opaque(), custom_call));
        context_->GetThunkInfo(custom_call), call_target,
        std::move(operand_slices), std::move(result_slices),
        Cast<HloCustomCallInstruction>(custom_call)->opaque()));
    return Status::OK();
  }
#endif
@ -347,9 +353,10 @@ Status ThunkEmitter::HandleTriangularSolve(HloInstruction* hlo) {
  auto destination_buffer = GetAllocationSlice(*hlo);
  if (operand_buffer != destination_buffer) {
    thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
        context_->GetThunkInfo(hlo),
        /*source_address=*/operand_buffer,
        /*destination_buffer=*/destination_buffer,
        /*mem_size=*/ShapeUtil::ByteSizeOf(hlo->operand(1)->shape()), hlo));
        /*mem_size=*/ShapeUtil::ByteSizeOf(hlo->operand(1)->shape())));
  }

  thunks.push_back(BuildTriangularSolveThunk(hlo));
@ -358,8 +365,8 @@ Status ThunkEmitter::HandleTriangularSolve(HloInstruction* hlo) {
  if (thunks.size() == 1) {
    AddThunkToThunkSequence(std::move(thunks[0]));
  } else {
    AddThunkToThunkSequence(
        absl::make_unique<SequentialThunk>(std::move(thunks), hlo));
    AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
        context_->GetThunkInfo(hlo), std::move(thunks)));
  }
  return Status::OK();
}
@ -374,5 +381,12 @@ Status ThunkEmitter::HandleOutfeed(HloInstruction* outfeed) {
  return Status::OK();
}

Thunk::ThunkInfo ThunkEmitter::EmissionContext::GetThunkInfo(
    const HloInstruction* hlo) const {
  CHECK(hlo);
  Thunk::ThunkInfo info;
  info.hlo_instruction = hlo;
  return info;
}
}  // namespace gpu
}  // namespace xla

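For readers following the refactoring in this diff: every Thunk constructor stops taking a bare `const HloInstruction*` and instead takes a `Thunk::ThunkInfo` built by the emission context's `GetThunkInfo` helper, which call sites obtain via `context_->GetThunkInfo(...)`. The following is a minimal, self-contained sketch of that pattern only; `FakeHloInstruction`, `ExampleThunk`, and `ExampleEmissionContext` are hypothetical stand-ins invented for illustration and are not the real XLA classes.

// thunk_info_pattern_sketch.cc -- illustrative only, not part of the diff.
#include <cassert>
#include <iostream>
#include <memory>
#include <string>

// Stand-in for xla::HloInstruction (hypothetical).
struct FakeHloInstruction {
  std::string name;
};

class Thunk {
 public:
  enum class Kind { kExample };

  // Metadata bundle that replaces the old bare HloInstruction* parameter.
  struct ThunkInfo {
    const FakeHloInstruction* hlo_instruction = nullptr;
  };

  Thunk(Kind kind, ThunkInfo info) : kind_(kind), info_(info) {}
  virtual ~Thunk() = default;

  Kind kind() const { return kind_; }
  const FakeHloInstruction* hlo() const { return info_.hlo_instruction; }

 private:
  Kind kind_;
  ThunkInfo info_;
};

// A thunk subclass now takes ThunkInfo first, followed by its own arguments.
class ExampleThunk : public Thunk {
 public:
  ExampleThunk(ThunkInfo info, int payload)
      : Thunk(Kind::kExample, info), payload_(payload) {}
  int payload() const { return payload_; }

 private:
  int payload_;
};

// Mirrors the shape of ThunkEmitter::EmissionContext::GetThunkInfo above.
struct ExampleEmissionContext {
  Thunk::ThunkInfo GetThunkInfo(const FakeHloInstruction* hlo) const {
    assert(hlo != nullptr);
    Thunk::ThunkInfo info;
    info.hlo_instruction = hlo;
    return info;
  }
};

int main() {
  FakeHloInstruction hlo{"custom-call.1"};
  ExampleEmissionContext context;
  // Old style would have been ExampleThunk(42, &hlo); new style threads
  // the ThunkInfo through instead of the raw instruction pointer.
  auto thunk = std::make_unique<ExampleThunk>(context.GetThunkInfo(&hlo),
                                              /*payload=*/42);
  std::cout << thunk->hlo()->name << " " << thunk->payload() << "\n";
  return 0;
}
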
@ -36,6 +36,7 @@ class ThunkEmitter {
        const HloInstruction& hlo, const ShapeIndex& index) const = 0;
    virtual int64 ByteSizeOf(const Shape& shape) const = 0;
    virtual absl::string_view platform_name() const = 0;
    virtual Thunk::ThunkInfo GetThunkInfo(const HloInstruction* hlo) const;

    virtual ~EmissionContext() = default;
  };

@ -32,12 +32,12 @@ namespace xla {
namespace gpu {

TriangularSolveThunk::TriangularSolveThunk(
    const TriangularSolveOptions& options,
    ThunkInfo thunk_info, const TriangularSolveOptions& options,
    const BufferAllocation::Slice& a_buffer,
    const BufferAllocation::Slice& b_buffer, PrimitiveType type,
    int64 batch_size, int64 m, int64 n, int64 a_batch_stride,
    int64 b_batch_stride, const HloInstruction* hlo)
    : Thunk(Kind::kTriangularSolve, hlo),
    int64 b_batch_stride)
    : Thunk(Kind::kTriangularSolve, thunk_info),
      uplo_(options.lower() ? se::blas::UpperLower::kLower
                            : se::blas::UpperLower::kUpper),
      side_(options.left_side() ? se::blas::Side::kLeft

@ -38,12 +38,12 @@ namespace gpu {
// Thread-compatible.
class TriangularSolveThunk : public Thunk {
 public:
  TriangularSolveThunk(const TriangularSolveOptions& options,
  TriangularSolveThunk(ThunkInfo thunk_info,
                       const TriangularSolveOptions& options,
                       const BufferAllocation::Slice& a_buffer,
                       const BufferAllocation::Slice& b_buffer,
                       PrimitiveType type, int64 batch_size, int64 m, int64 n,
                       int64 a_batch_stride, int64 b_batch_stride,
                       const HloInstruction* hlo);
                       int64 a_batch_stride, int64 b_batch_stride);

  TriangularSolveThunk(const TriangularSolveThunk&) = delete;
  TriangularSolveThunk& operator=(const TriangularSolveThunk&) = delete;
