Merge pull request #5 from tensorflow/master

downstream merge ~3
tg-at-google 2020-07-13 09:44:39 -04:00 committed by GitHub
commit 9e475fecd9
276 changed files with 7539 additions and 2007 deletions

@@ -11,10 +11,6 @@
* C-API functions `TF_StringDecode`, `TF_StringEncode`, and
`TF_StringEncodedSize` are no longer relevant and have been removed; see
core/platform/ctstring.h for string access/modification in C (a brief sketch follows this list).
* In the batching library, the parameter
`SharedBatchScheduler::QueueOptions::max_batch_size` has been renamed to the
more accurate `input_batch_size_limit`, reflecting a recent feature that
enables splitting of large batches.
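As a companion to the `TF_StringDecode`/`TF_StringEncode` removal noted above, the following is a minimal sketch (not part of this commit) of handling strings through the `TF_TString` type from core/platform/ctstring.h. It assumes the `TF_TString_Init`, `TF_TString_Copy`, `TF_TString_GetSize`, `TF_TString_GetDataPointer`, and `TF_TString_Dealloc` helpers exposed by that header.

#include <stdio.h>

#include "tensorflow/core/platform/ctstring.h"

int main(void) {
  TF_TString s;
  TF_TString_Init(&s);              // start with an empty string
  TF_TString_Copy(&s, "hello", 5);  // copy 5 bytes into the string
  printf("%.*s (%zu bytes)\n", (int)TF_TString_GetSize(&s),
         TF_TString_GetDataPointer(&s), TF_TString_GetSize(&s));
  TF_TString_Dealloc(&s);           // release any owned storage
  return 0;
}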
## Known Caveats

@@ -171,6 +171,87 @@ cc_library(
],
)
cc_library(
name = "gradients",
srcs = [
"gradients.cc",
"gradients_internal.h",
],
hdrs = [
"gradients.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":abstract_context",
":abstract_operation",
":abstract_tensor_handle",
":c_api_unified_internal",
":tape",
"//tensorflow/core/common_runtime/eager:attr_builder",
"//tensorflow/core/lib/llvm_rtti",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "gradients_internal",
srcs = [
"gradients.cc",
],
hdrs = [
"gradients.h",
"gradients_internal.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":abstract_context",
":abstract_operation",
":abstract_tensor_handle",
":c_api_unified_internal",
":tape",
"//tensorflow/core/common_runtime/eager:attr_builder",
"//tensorflow/core/lib/llvm_rtti",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
],
)
tf_cuda_cc_test(
name = "gradients_test",
size = "small",
srcs = [
"gradients_test.cc",
],
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + ["nomac"],
deps = [
":abstract_tensor_handle",
":c_api_experimental",
":c_api_test_util",
":c_api_unified_internal",
":gradients_internal",
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/c:tf_status_helper",
"//tensorflow/cc/profiler",
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/lib/llvm_rtti",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "abstract_tensor_handle",
hdrs = ["abstract_tensor_handle.h"],
@@ -747,6 +828,7 @@ filegroup(
"c_api_unified_experimental_eager.cc",
"c_api_unified_experimental_graph.cc",
"c_api_unified_experimental_internal.h",
"gradients.cc", # Uses RTTI.
"*test*",
"*dlpack*",
],

@@ -725,13 +725,7 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
if (opts->use_tfrt) {
#ifdef PLATFORM_GOOGLE
tfrt::SmallVector<std::string, 4> op_handler_chains;
tfrt::SmallVector<tensorflow::DeviceAttributes, 4> device_attributes;
status->status = tfrt::ListOpHandlerChains(
opts->session_options.options, &op_handler_chains, &device_attributes);
if (!status->status.ok()) return nullptr;
return tensorflow::wrap(new tfrt::ContextInterface(
op_handler_chains, device_attributes, opts->async));
return tensorflow::wrap(new tfrt::ContextInterface(opts->async));
#else
status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
return nullptr;

@@ -0,0 +1,400 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/gradients.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
namespace tensorflow {
namespace gradients {
Status GradientRegistry::Register(const string& op_name,
GradientFunctionFactory factory) {
auto iter = registry_.find(op_name);
if (iter != registry_.end()) {
const string error_msg = "Gradient already exists for op: " + op_name + ".";
return errors::AlreadyExists(error_msg);
}
registry_.insert({op_name, factory});
return Status::OK();
}
Status GradientRegistry::Lookup(
const ForwardOperation& op,
std::unique_ptr<GradientFunction>* grad_fn) const {
auto iter = registry_.find(op.op_name);
if (iter == registry_.end()) {
const string error_msg = "No gradient defined for op: " + op.op_name + ".";
return errors::NotFound(error_msg);
}
grad_fn->reset(iter->second(op));
return Status::OK();
}
int64 ToId(AbstractTensorHandle* t) {
return static_cast<int64>(reinterpret_cast<uintptr_t>(t));
}
TapeTensor::TapeTensor(AbstractTensorHandle* handle, AbstractContext* ctx)
: handle_(handle), ctx_(ctx) {
// TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely
// on the client to keep this tensor live for the duration of the gradient
// computation.
// handle_->Ref();
}
TapeTensor::TapeTensor(const TapeTensor& other) {
handle_ = other.handle_;
// TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely
// on the client to keep this tensor live for the duration of the gradient
// computation.
// handle_->Ref();
ctx_ = other.ctx_;
}
TapeTensor::~TapeTensor() {
// TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely
// on the client to keep this tensor live for the duration of the gradient
// computation.
// handle_->Unref();
}
tensorflow::int64 TapeTensor::GetID() const { return ToId(handle_); }
tensorflow::DataType TapeTensor::GetDType() const {
return handle_->DataType();
}
AbstractTensorHandle* TapeTensor::OnesLike() const {
AbstractOperationPtr op(ctx_->CreateOperation());
Status s = op->Reset("OnesLike", /*raw_device_name=*/nullptr);
if (!s.ok()) {
return nullptr;
}
if (isa<tracing::TracingOperation>(op.get())) {
s = dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
absl::StrCat("OnesLike", ToId(handle_)).c_str());
if (!s.ok()) {
return nullptr;
}
}
s = op->AddInput(handle_);
if (!s.ok()) {
return nullptr;
}
int num_outputs = 1;
// TODO(srbs): Figure out who is in charge of releasing this.
std::vector<AbstractTensorHandle*> outputs(num_outputs);
s = op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs);
if (!s.ok()) {
return nullptr;
}
return outputs[0];
}
AbstractTensorHandle* TapeTensor::ZerosLike() const {
AbstractOperationPtr op(ctx_->CreateOperation());
// TODO(srbs): Consider adding a TF_RETURN_NULLPTR_IF_ERROR.
Status s = op->Reset("ZerosLike", /*raw_device_name=*/nullptr);
if (!s.ok()) {
return nullptr;
}
if (isa<tracing::TracingOperation>(op.get())) {
s = dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
absl::StrCat("OnesLike", ToId(handle_)).c_str());
if (!s.ok()) {
return nullptr;
}
}
s = op->AddInput(handle_);
if (!s.ok()) {
return nullptr;
}
int num_outputs = 1;
// TODO(srbs): Figure out who is in charge of releasing this.
std::vector<AbstractTensorHandle*> outputs(num_outputs);
s = op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs);
if (!s.ok()) {
return nullptr;
}
return outputs[0];
}
// Returns the number of elements in the gradient tensor.
int64 TapeVSpace::NumElements(AbstractTensorHandle* tensor) const {
// TODO(srbs): It seems like this is used only for performance optimization
// and not for correctness. The only downside of keeping this 1 seems to be
// that the gradient accumulation is unbounded and we will never
// aggressively aggregate accumulated gradients to recover memory.
// Revisit and fix.
return 1;
}
// Consumes references to the tensors in the gradient_tensors list and returns
// a tensor with the result.
AbstractTensorHandle* TapeVSpace::AggregateGradients(
gtl::ArraySlice<AbstractTensorHandle*> gradient_tensors) const {
if (gradient_tensors.size() == 1) {
return gradient_tensors[0];
}
AbstractOperationPtr op(ctx_->CreateOperation());
Status s = op->Reset("AddN", /*raw_device_name=*/nullptr);
if (!s.ok()) {
return nullptr;
}
s = op->AddInputList(gradient_tensors);
if (!s.ok()) {
return nullptr;
}
int num_outputs = 1;
std::vector<AbstractTensorHandle*> outputs(num_outputs);
s = op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs);
if (!s.ok()) {
return nullptr;
}
return outputs[0];
}
// Calls the passed-in backward function.
Status TapeVSpace::CallBackwardFunction(
GradientFunction* backward_function,
const std::vector<int64>& unneeded_gradients,
gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
std::vector<AbstractTensorHandle*>* result) const {
if (backward_function == nullptr) return Status::OK();
return backward_function->Compute(output_gradients, result);
}
// Looks up the ID of a Gradient.
int64 TapeVSpace::TensorId(AbstractTensorHandle* tensor) const {
return ToId(tensor);
}
// Converts a Gradient to a TapeTensor.
TapeTensor TapeVSpace::TapeTensorFromGradient(AbstractTensorHandle* g) const {
return TapeTensor(g, ctx_);
}
void TapeVSpace::MarkAsResult(AbstractTensorHandle* gradient) const {}
void TapeVSpace::DeleteGradient(AbstractTensorHandle* gradient) const {
gradient->Release();
}
// Helper functions which delegate to `AbstractOperation`, update
// the state of the ForwardOperation and call the tape as appropriate.
// These APIs are mainly to facilitate testing and are subject to change.
namespace internal {
Status Reset(AbstractOperation* op_, const char* op,
const char* raw_device_name, ForwardOperation* forward_op_) {
forward_op_->op_name = op;
return op_->Reset(op, raw_device_name);
}
Status AddInput(AbstractOperation* op_, AbstractTensorHandle* input,
ForwardOperation* forward_op_) {
TF_RETURN_IF_ERROR(op_->AddInput(input));
forward_op_->inputs.push_back(input);
return Status::OK();
}
Status AddInputList(AbstractOperation* op_,
absl::Span<AbstractTensorHandle* const> inputs,
ForwardOperation* forward_op_) {
TF_RETURN_IF_ERROR(op_->AddInputList(inputs));
for (auto input : inputs) {
forward_op_->inputs.push_back(input);
}
return Status::OK();
}
Status SetAttrString(AbstractOperation* op_, const char* attr_name,
const char* data, size_t length,
ForwardOperation* forward_op_) {
forward_op_->attrs.Set(attr_name, StringPiece(data, length));
return op_->SetAttrString(attr_name, data, length);
}
Status SetAttrInt(AbstractOperation* op_, const char* attr_name, int64_t value,
ForwardOperation* forward_op_) {
forward_op_->attrs.Set(attr_name, static_cast<int64>(value));
return op_->SetAttrInt(attr_name, value);
}
Status SetAttrFloat(AbstractOperation* op_, const char* attr_name, float value,
ForwardOperation* forward_op_) {
forward_op_->attrs.Set(attr_name, value);
return op_->SetAttrFloat(attr_name, value);
}
Status SetAttrBool(AbstractOperation* op_, const char* attr_name, bool value,
ForwardOperation* forward_op_) {
forward_op_->attrs.Set(attr_name, value);
return op_->SetAttrBool(attr_name, value);
}
Status SetAttrType(AbstractOperation* op_, const char* attr_name,
DataType value, ForwardOperation* forward_op_) {
forward_op_->attrs.Set(attr_name, value);
return op_->SetAttrType(attr_name, value);
}
Status SetAttrShape(AbstractOperation* op_, const char* attr_name,
const int64_t* dims, const int num_dims,
ForwardOperation* forward_op_) {
if (num_dims > TensorShape::MaxDimensions()) {
return errors::InvalidArgument("Value specified for `", attr_name, "` has ",
num_dims,
" dimensions which is over the limit of ",
TensorShape::MaxDimensions(), ".");
}
TensorShapeProto proto;
if (num_dims < 0) {
proto.set_unknown_rank(true);
} else {
for (int d = 0; d < num_dims; ++d) {
proto.add_dim()->set_size(dims[d]);
}
}
forward_op_->attrs.Set(attr_name, proto);
return op_->SetAttrShape(attr_name, dims, num_dims);
}
Status SetAttrFunction(AbstractOperation* op_, const char* attr_name,
const AbstractOperation* value,
ForwardOperation* forward_op_) {
return tensorflow::errors::Unimplemented(
"SetAttrFunction has not been implemented yet.");
}
Status SetAttrFunctionName(AbstractOperation* op_, const char* attr_name,
const char* value, size_t length,
ForwardOperation* forward_op_) {
return tensorflow::errors::Unimplemented(
"SetAttrFunctionName has not been implemented "
"yet.");
}
Status SetAttrTensor(AbstractOperation* op_, const char* attr_name,
AbstractTensorInterface* tensor,
ForwardOperation* forward_op_) {
return tensorflow::errors::Unimplemented(
"SetAttrTensor has not been implemented yet.");
}
Status SetAttrStringList(AbstractOperation* op_, const char* attr_name,
const void* const* values, const size_t* lengths,
int num_values, ForwardOperation* forward_op_) {
std::vector<StringPiece> v(num_values);
for (int i = 0; i < num_values; ++i) {
v[i] = StringPiece(static_cast<const char*>(values[i]), lengths[i]);
}
forward_op_->attrs.Set(attr_name, v);
return op_->SetAttrStringList(attr_name, values, lengths, num_values);
}
Status SetAttrFloatList(AbstractOperation* op_, const char* attr_name,
const float* values, int num_values,
ForwardOperation* forward_op_) {
forward_op_->attrs.Set(attr_name,
gtl::ArraySlice<const float>(values, num_values));
return op_->SetAttrFloatList(attr_name, values, num_values);
}
Status SetAttrIntList(AbstractOperation* op_, const char* attr_name,
const int64_t* values, int num_values,
ForwardOperation* forward_op_) {
forward_op_->attrs.Set(
attr_name, gtl::ArraySlice<const int64>(
reinterpret_cast<const int64*>(values), num_values));
return op_->SetAttrIntList(attr_name, values, num_values);
}
Status SetAttrTypeList(AbstractOperation* op_, const char* attr_name,
const DataType* values, int num_values,
ForwardOperation* forward_op_) {
forward_op_->attrs.Set(attr_name,
gtl::ArraySlice<const DataType>(values, num_values));
return op_->SetAttrTypeList(attr_name, values, num_values);
}
Status SetAttrBoolList(AbstractOperation* op_, const char* attr_name,
const unsigned char* values, int num_values,
ForwardOperation* forward_op_) {
std::unique_ptr<bool[]> b(new bool[num_values]);
for (int i = 0; i < num_values; ++i) {
b[i] = values[i];
}
forward_op_->attrs.Set(attr_name,
gtl::ArraySlice<const bool>(b.get(), num_values));
return op_->SetAttrBoolList(attr_name, values, num_values);
}
Status SetAttrShapeList(AbstractOperation* op_, const char* attr_name,
const int64_t** dims, const int* num_dims,
int num_values, ForwardOperation* forward_op_) {
std::unique_ptr<TensorShapeProto[]> proto(new TensorShapeProto[num_values]);
for (int i = 0; i < num_values; ++i) {
const auto num_dims_i = num_dims[i];
if (num_dims_i > TensorShape::MaxDimensions()) {
return errors::InvalidArgument(
strings::StrCat("Value specified for `", attr_name, "` has ",
num_dims_i, " dimensions which is over the limit of ",
TensorShape::MaxDimensions(), "."));
}
if (num_dims_i < 0) {
proto[i].set_unknown_rank(true);
} else {
const int64_t* dims_i = dims[i];
auto proto_i = &proto[i];
for (int d = 0; d < num_dims_i; ++d) {
proto_i->add_dim()->set_size(dims_i[d]);
}
}
}
forward_op_->attrs.Set(
attr_name, gtl::ArraySlice<TensorShapeProto>(proto.get(), num_values));
return op_->SetAttrShapeList(attr_name, dims, num_dims, num_values);
}
Status SetAttrFunctionList(AbstractOperation* op_, const char* attr_name,
absl::Span<const AbstractOperation*> values,
ForwardOperation* forward_op_) {
return tensorflow::errors::Unimplemented(
"SetAttrFunctionList has not been "
"implemented yet.");
}
Status Execute(AbstractOperation* op_, AbstractContext* ctx,
absl::Span<AbstractTensorHandle*> retvals, int* num_retvals,
ForwardOperation* forward_op_, Tape* tape,
const GradientRegistry& registry) {
TF_RETURN_IF_ERROR(op_->Execute(retvals, num_retvals));
std::vector<int64> input_ids(forward_op_->inputs.size());
std::vector<tensorflow::DataType> input_dtypes(forward_op_->inputs.size());
for (int i = 0; i < forward_op_->inputs.size(); i++) {
input_ids[i] = ToId(forward_op_->inputs[i]);
input_dtypes[i] = forward_op_->inputs[i]->DataType();
}
std::vector<TapeTensor> tape_tensors;
for (auto t : retvals) {
tape_tensors.push_back(TapeTensor(t, ctx));
}
tape->RecordOperation(
op_->Name(), tape_tensors, input_ids, input_dtypes,
[registry, forward_op_]() -> GradientFunction* {
std::unique_ptr<GradientFunction> grad_fn;
Status s = registry.Lookup(*forward_op_, &grad_fn);
if (!s.ok()) {
return nullptr;
}
return grad_fn.release();
},
[](GradientFunction* ptr) {
if (ptr) {
delete ptr;
}
});
return Status::OK();
}
} // namespace internal
} // namespace gradients
} // namespace tensorflow

@@ -0,0 +1,171 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_GRADIENTS_H_
#define TENSORFLOW_C_EAGER_GRADIENTS_H_
#include "absl/container/flat_hash_map.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/tape.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
namespace tensorflow {
namespace gradients {
// =============== Experimental C++ API for computing gradients ===============
// Sample gradient function:
//
// class AddGradientFunction : public GradientFunction {
// public:
// Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs,
// std::vector<AbstractTensorHandle*>* grad_outputs) override {
// grad_outputs->resize(2);
// (*grad_outputs)[0] = grad_inputs[0];
// (*grad_outputs)[1] = grad_inputs[0];
// return Status::OK();
// }
// ~AddGradientFunction() override {}
// };
//
// GradientFunction* AddRegisterer(const ForwardOperation& op) {
// // More complex gradient functions can use inputs/attrs etc. from the
// // forward `op`.
// return new AddGradientFunction;
// }
//
// Status RegisterGradients(GradientRegistry* registry) {
// return registry->Register("Add", AddRegisterer);
// }
class GradientFunction {
public:
// TODO(srbs): Figure out how to support CompositeTensors, e.g. IndexedSlices,
// in `grad_inputs`.
virtual Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs,
std::vector<AbstractTensorHandle*>* grad_outputs) = 0;
virtual ~GradientFunction() {}
};
// Metadata from the forward operation that is made available to the
// gradient registerer to instantiate a GradientFunction.
struct ForwardOperation {
public:
string op_name;
std::vector<AbstractTensorHandle*> inputs;
std::vector<AbstractTensorHandle*> outputs;
AttrBuilder attrs;
AbstractContext* ctx;
};
using GradientFunctionFactory =
std::function<GradientFunction*(const ForwardOperation& op)>;
// Map from op name to a `GradientFunctionFactory`.
class GradientRegistry {
public:
Status Register(const string& op, GradientFunctionFactory factory);
Status Lookup(const ForwardOperation& op,
std::unique_ptr<GradientFunction>* grad_fn) const;
private:
absl::flat_hash_map<string, GradientFunctionFactory> registry_;
};
// Returns a unique id for the tensor which is used by the tape to build
// the gradient graph. See documentation of `TapeTensor` for more details.
int64 ToId(AbstractTensorHandle* t);
// Wrapper for a tensor output of an operation executing under a tape.
//
// `GetID` returns a unique id for the wrapped tensor which is used to maintain
// a map (`tensorflow::eager::TensorTape`) from the wrapped tensor to the id of
// the op that produced it (or -1 if this tensor was watched using
// `GradientTape::Watch`). The op_id is simply a unique index assigned to each
// op executed under the tape. A separate map (`tensorflow::eager::OpTape`)
// maintains the mapping from `op_id` to an `OpTapeEntry`, which stores the
// `op_type`, inputs, outputs, and the gradient function. These data structures
// combined allow us to trace the data dependencies between operations and
// hence compute gradients.
//
// This also implements `ZerosLike` and `OnesLike` to create the default
// incoming gradients for tensors which do not already have an incoming
// gradient.
class TapeTensor {
public:
TapeTensor(AbstractTensorHandle* handle, AbstractContext* ctx);
TapeTensor(const TapeTensor& other);
~TapeTensor();
tensorflow::int64 GetID() const;
tensorflow::DataType GetDType() const;
AbstractTensorHandle* OnesLike() const;
AbstractTensorHandle* ZerosLike() const;
private:
AbstractTensorHandle* handle_;
// The context where OnesLike and ZerosLike ops are to be created.
AbstractContext* ctx_;
};
// Vector space for actually computing gradients. Implements methods for calling
// the backward function with incoming gradients and returning the outgoing
// gradient and for performing gradient aggregation.
// See `tensorflow::eager::VSpace` for more details.
class TapeVSpace
: public eager::VSpace<AbstractTensorHandle, GradientFunction, TapeTensor> {
public:
explicit TapeVSpace(AbstractContext* ctx) : ctx_(ctx) {}
~TapeVSpace() override {}
// Returns the number of elements in the gradient tensor.
int64 NumElements(AbstractTensorHandle* tensor) const override;
// Consumes references to the tensors in the gradient_tensors list and returns
// a tensor with the result.
AbstractTensorHandle* AggregateGradients(
gtl::ArraySlice<AbstractTensorHandle*> gradient_tensors) const override;
// Calls the passed-in backward function.
Status CallBackwardFunction(
GradientFunction* backward_function,
const std::vector<int64>& unneeded_gradients,
gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
std::vector<AbstractTensorHandle*>* result) const override;
// Looks up the ID of a Gradient.
int64 TensorId(AbstractTensorHandle* tensor) const override;
// Converts a Gradient to a TapeTensor.
TapeTensor TapeTensorFromGradient(AbstractTensorHandle* g) const override;
void MarkAsResult(AbstractTensorHandle* gradient) const override;
void DeleteGradient(AbstractTensorHandle* gradient) const override;
private:
// The context where the aggregation op `Add` is to be created.
AbstractContext* ctx_;
};
// A tracing/immediate-execution agnostic tape.
using Tape = tensorflow::eager::GradientTape<AbstractTensorHandle,
GradientFunction, TapeTensor>;
} // namespace gradients
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_GRADIENTS_H_
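To make the intended call pattern of this header concrete, here is a condensed sketch (not part of the commit) mirroring AddGradModel in gradients_test.cc later in this change. It assumes a GradientRegistry already populated via something like RegisterGradients, and an Add helper like the one in that test, which executes the forward op through the internal::* wrappers so the tape records it.

// Sketch only: compute d(x + y)/dx and d(x + y)/dy under a tape.
Status GradOfAdd(AbstractContext* ctx, AbstractTensorHandle* x,
                 AbstractTensorHandle* y, const GradientRegistry& registry,
                 std::vector<AbstractTensorHandle*>* grads) {
  TapeVSpace vspace(ctx);
  Tape tape(/*persistent=*/false);
  tape.Watch(ToId(x));  // Watch the sources we want gradients for.
  tape.Watch(ToId(y));
  std::vector<AbstractTensorHandle*> add_outputs(1);
  // `Add` (defined in gradients_test.cc) runs the forward op via the
  // internal helpers so the tape records it.
  TF_RETURN_IF_ERROR(
      Add(ctx, &tape, {x, y}, absl::MakeSpan(add_outputs), registry));
  std::unordered_map<tensorflow::int64, TapeTensor> sources_that_are_targets;
  TF_RETURN_IF_ERROR(tape.ComputeGradient(
      vspace, /*target_tensor_ids=*/{ToId(add_outputs[0])},
      /*source_tensor_ids=*/{ToId(x), ToId(y)}, sources_that_are_targets,
      /*output_gradients=*/{}, grads));
  add_outputs[0]->Release();
  return Status::OK();
}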

@@ -0,0 +1,87 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_GRADIENTS_INTERNAL_H_
#define TENSORFLOW_C_EAGER_GRADIENTS_INTERNAL_H_
#include "tensorflow/c/eager/gradients.h"
namespace tensorflow {
namespace gradients {
namespace internal {
// Helper functions which delegate to `AbstractOperation`, update
// the state of the ForwardOperation and call the tape as appropriate.
// These APIs are mainly to facilitate testing and are subject to change.
// Records the op name in the `ForwardOperation`.
Status Reset(AbstractOperation*, const char* op, const char* raw_device_name,
ForwardOperation*);
// Records the inputs in the `ForwardOperation`.
Status AddInput(AbstractOperation*, AbstractTensorHandle*, ForwardOperation*);
Status AddInputList(AbstractOperation*,
absl::Span<AbstractTensorHandle* const> inputs,
ForwardOperation*);
// Sets the attrs in the `ForwardOperation`.
Status SetAttrString(AbstractOperation*, const char* attr_name,
const char* data, size_t length, ForwardOperation*);
Status SetAttrInt(AbstractOperation*, const char* attr_name, int64_t value,
ForwardOperation*);
Status SetAttrFloat(AbstractOperation*, const char* attr_name, float value,
ForwardOperation*);
Status SetAttrBool(AbstractOperation*, const char* attr_name, bool value,
ForwardOperation*);
Status SetAttrType(AbstractOperation*, const char* attr_name, DataType value,
ForwardOperation*);
Status SetAttrShape(AbstractOperation*, const char* attr_name,
const int64_t* dims, const int num_dims, ForwardOperation*);
Status SetAttrFunction(AbstractOperation*, const char* attr_name,
const AbstractOperation* value, ForwardOperation*);
Status SetAttrFunctionName(AbstractOperation*, const char* attr_name,
const char* value, size_t length, ForwardOperation*);
Status SetAttrTensor(AbstractOperation*, const char* attr_name,
AbstractTensorInterface* tensor, ForwardOperation*);
Status SetAttrStringList(AbstractOperation*, const char* attr_name,
const void* const* values, const size_t* lengths,
int num_values, ForwardOperation*);
Status SetAttrFloatList(AbstractOperation*, const char* attr_name,
const float* values, int num_values, ForwardOperation*);
Status SetAttrIntList(AbstractOperation*, const char* attr_name,
const int64_t* values, int num_values, ForwardOperation*);
Status SetAttrTypeList(AbstractOperation*, const char* attr_name,
const DataType* values, int num_values,
ForwardOperation*);
Status SetAttrBoolList(AbstractOperation*, const char* attr_name,
const unsigned char* values, int num_values,
ForwardOperation*);
Status SetAttrShapeList(AbstractOperation*, const char* attr_name,
const int64_t** dims, const int* num_dims,
int num_values, ForwardOperation*);
Status SetAttrFunctionList(AbstractOperation*, const char* attr_name,
absl::Span<const AbstractOperation*> values,
ForwardOperation*);
// Make the call to `Tape::RecordOperation`.
Status Execute(AbstractOperation*, AbstractContext*,
absl::Span<AbstractTensorHandle*> retvals, int* num_retvals,
ForwardOperation*, Tape*, const GradientRegistry&);
} // namespace internal
} // namespace gradients
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_GRADIENTS_INTERNAL_H_
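For orientation, here is a hedged sketch (not part of the commit) of how these helpers chain together when running a forward op under a tape. It follows the same pattern as the Add helper in gradients_test.cc later in this change, generalized to a caller-supplied single-input op name, an assumption made purely for illustration.

// Sketch only: execute a single-input op `op_name` under `tape` so that its
// gradient can later be computed via the registry.
Status RunUnaryOpUnderTape(AbstractContext* ctx, Tape* tape,
                           const char* op_name, AbstractTensorHandle* input,
                           absl::Span<AbstractTensorHandle*> outputs,
                           const GradientRegistry& registry) {
  AbstractOperationPtr op(ctx->CreateOperation());
  ForwardOperation forward_op;
  forward_op.ctx = ctx;
  // Each helper forwards to AbstractOperation and also records state in
  // `forward_op` for later use by the gradient function.
  TF_RETURN_IF_ERROR(internal::Reset(op.get(), op_name,
                                     /*raw_device_name=*/nullptr, &forward_op));
  TF_RETURN_IF_ERROR(internal::AddInput(op.get(), input, &forward_op));
  int num_retvals = static_cast<int>(outputs.size());
  // Execute runs the op and calls Tape::RecordOperation with a factory that
  // looks up the gradient function in `registry`.
  return internal::Execute(op.get(), ctx, outputs, &num_retvals, &forward_op,
                           tape, registry);
}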

@@ -0,0 +1,328 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/gradients.h"
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace gradients {
namespace internal {
namespace {
class CppGradients
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
protected:
void SetUp() override {
TF_SetTracingImplementation(std::get<0>(GetParam()));
}
};
// Creates an Identity op.
Status Identity(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr identity_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(
identity_op->Reset("Identity", /*raw_device_name=*/nullptr));
if (isa<tracing::TracingOperation>(identity_op.get())) {
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(identity_op.get())
->SetOpName(name));
}
TF_RETURN_IF_ERROR(identity_op->AddInput(inputs[0]));
int num_retvals = 1;
TF_RETURN_IF_ERROR(identity_op->Execute(outputs, &num_retvals));
return Status::OK();
}
// =================== Register gradients for Add ============================
class AddGradientFunction : public GradientFunction {
public:
explicit AddGradientFunction(AbstractContext* ctx) : ctx_(ctx) {}
Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs,
std::vector<AbstractTensorHandle*>* grad_outputs) override {
grad_outputs->resize(2);
std::vector<AbstractTensorHandle*> identity_outputs(1);
TF_RETURN_IF_ERROR(Identity(ctx_, {grad_inputs[0]},
absl::MakeSpan(identity_outputs), "Id0"));
(*grad_outputs)[0] = identity_outputs[0];
TF_RETURN_IF_ERROR(Identity(ctx_, {grad_inputs[0]},
absl::MakeSpan(identity_outputs), "Id1"));
(*grad_outputs)[1] = identity_outputs[0];
return Status::OK();
}
~AddGradientFunction() override {}
private:
AbstractContext* ctx_;
};
GradientFunction* AddRegisterer(const ForwardOperation& op) {
return new AddGradientFunction(op.ctx);
}
Status RegisterGradients(GradientRegistry* registry) {
return registry->Register("Add", AddRegisterer);
}
// =================== End gradient registrations ============================
// Computes `inputs[0] + inputs[1]` and records it on the tape.
Status Add(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractOperationPtr add_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(
Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op));
if (isa<tracing::TracingOperation>(add_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<tracing::TracingOperation>(add_op.get())->SetOpName("my_add"));
}
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op));
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op));
int num_retvals = 1;
return Execute(add_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
registry);
}
// Computes
// y = inputs[0] + inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status AddGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
tape->Watch(ToId(inputs[1])); // Watch y.
std::vector<AbstractTensorHandle*> add_outputs(1);
TF_RETURN_IF_ERROR(Add(ctx, tape, inputs, absl::MakeSpan(add_outputs),
registry)); // Compute x+y.
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
std::vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(add_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads));
for (auto add_output : add_outputs) {
add_output->Release();
}
outputs[0] = out_grads[0];
outputs[1] = out_grads[1];
delete tape;
return Status::OK();
}
AbstractContext* BuildFunction(const char* fn_name) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name, status.get());
return unwrap(graph_ctx);
}
Status CreateParamsForInputs(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
std::vector<AbstractTensorHandle*>* params) {
tracing::TracingTensorHandle* handle = nullptr;
for (auto input : inputs) {
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
input->DataType(), &handle));
params->emplace_back(handle);
}
return Status::OK();
}
using Model = std::function<Status(
AbstractContext*, absl::Span<AbstractTensorHandle* const>,
absl::Span<AbstractTensorHandle*>, const GradientRegistry&)>;
// Runs `model` maybe wrapped in a function.
Status RunModel(Model model, AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, bool use_function,
const GradientRegistry& registry) {
if (use_function) {
const char* fn_name = "test_fn";
std::unique_ptr<AbstractFunction> scoped_func;
{
AbstractContextPtr func_ctx(BuildFunction(fn_name));
std::vector<AbstractTensorHandle*> func_inputs;
func_inputs.reserve(inputs.size());
TF_RETURN_IF_ERROR(
CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs));
OutputList output_list;
output_list.expected_num_outputs = outputs.size();
output_list.outputs.resize(outputs.size());
TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs),
absl::MakeSpan(output_list.outputs), registry));
for (auto func_input : func_inputs) {
func_input->Release();
}
AbstractFunction* func = nullptr;
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get())
->Finalize(&output_list, &func));
scoped_func.reset(func);
output_list.outputs[0]->Release();
output_list.outputs[1]->Release();
TF_RETURN_IF_ERROR(ctx->RegisterFunction(func));
}
AbstractOperationPtr fn_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(fn_op->Reset(fn_name, /*raw_device_name=*/nullptr));
for (auto input : inputs) {
TF_RETURN_IF_ERROR(fn_op->AddInput(input));
}
int retvals = outputs.size();
TF_RETURN_IF_ERROR(fn_op->Execute(outputs, &retvals));
TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name));
return Status::OK();
} else {
return model(ctx, inputs, outputs, registry);
}
}
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
*ctx = unwrap(TF_NewEagerExecutionContext(opts, status.get()));
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_DeleteContextOptions(opts);
return Status::OK();
}
Status TestScalarTensorHandle(AbstractContext* ctx, float value,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, value);
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return Status::OK();
}
Status getValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_TensorHandle* result_t =
TF_AbstractTensorGetEagerTensor(wrap(t), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
*result_tensor = TFE_TensorHandleResolve(result_t, status.get());
return Status::OK();
}
TEST_P(CppGradients, TestAddGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
AbstractTensorHandlePtr y;
{
AbstractTensorHandle* y_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &y_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
y.reset(y_raw);
}
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Pseudo-code:
//
// tape.watch(x)
// tape.watch(y)
// y = x + y
// outputs = tape.gradient(y, [x, y])
std::vector<AbstractTensorHandle*> outputs(2);
s = RunModel(AddGradModel, ctx.get(), {x.get(), y.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* result_tensor;
s = getValue(outputs[0], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, 1.0);
outputs[0]->Release();
TF_DeleteTensor(result_tensor);
result_tensor = nullptr;
s = getValue(outputs[1], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, 1.0);
outputs[1]->Release();
TF_DeleteTensor(result_tensor);
}
// TODO(b/160888630): Enable this test with mlir after AddInputList is
// supported. It is needed for the AddN op, which is used for gradient aggregation.
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients,
::testing::Combine(::testing::Values("graphdef"),
/*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients,
::testing::Combine(::testing::Values("graphdef"),
/*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#endif
} // namespace
} // namespace internal
} // namespace gradients
} // namespace tensorflow

@@ -66,11 +66,26 @@ cc_library(
":file_block_cache",
"//tensorflow/c:env",
"//tensorflow/c:tf_status",
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/synchronization",
],
)
tf_cc_test(
name = "ram_file_block_cache_test",
size = "small",
srcs = ["ram_file_block_cache_test.cc"],
deps = [
":ram_file_block_cache",
"//tensorflow/c:tf_status_internal",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/platform:blocking_counter",
"//tensorflow/core/platform/cloud:now_seconds_env",
],
)
tf_cc_test(
name = "gcs_filesystem_test",
srcs = [

@@ -45,8 +45,6 @@ limitations under the License.
#include <type_traits>
#include <utility>
#include "tensorflow/core/platform/macros.h"
namespace tf_gcs_filesystem {
// A move-only RAII object that calls a stored cleanup functor when

@@ -18,6 +18,7 @@ limitations under the License.
#include <string.h>
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "google/cloud/storage/client.h"
#include "tensorflow/c/env.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"
@@ -556,6 +557,111 @@ void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst,
TF_SetStatusFromGCSStatus(metadata.status(), status);
}
// TODO(vnvo2409): This approach can cause a problem when our path is
// `path/to/dir` and there is an object with key `path/to/directory`. Will be
// fixed when refactoring.
void PathExists(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
std::string bucket, object;
ParseGCSPath(path, true, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
for (auto&& metadata :
gcs_file->gcs_client.ListObjects(bucket, gcs::Prefix(object))) {
if (!metadata) {
TF_SetStatusFromGCSStatus(metadata.status(), status);
return;
}
// We consider a path to exist if there is at least one object whose key
// contains the path.
return TF_SetStatus(status, TF_OK, "");
}
return TF_SetStatus(
status, TF_NOT_FOUND,
absl::StrCat("The path ", path, " does not exist.").c_str());
}
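As a hedged illustration of the pitfall flagged in the TODO above (not part of the commit): a bare prefix listing cannot distinguish `path/to/dir` from `path/to/directory`. A stricter check would accept only an exact object key or a key that continues with a path separator, for example:

// Sketch only: existence check that does not treat `path/to/directory` as a
// match for `path/to/dir`. Reuses the `gcs` namespace alias from this file.
static bool ObjectOrDirExists(gcs::Client& client, const std::string& bucket,
                              const std::string& object) {
  for (auto&& metadata : client.ListObjects(bucket, gcs::Prefix(object))) {
    if (!metadata) return false;  // real code would surface the error status
    const std::string& name = metadata->name();
    // ListObjects guarantees `name` starts with `object`; require either an
    // exact match or a '/' immediately after the prefix.
    if (name.size() == object.size() ||
        (name.size() > object.size() && name[object.size()] == '/'))
      return true;
  }
  return false;
}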
bool IsDirectory(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
std::string bucket, object;
ParseGCSPath(path, true, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return false;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
if (object.empty()) {
auto bucket_metadata = gcs_file->gcs_client.GetBucketMetadata(bucket);
TF_SetStatusFromGCSStatus(bucket_metadata.status(), status);
if (TF_GetCode(status) == TF_OK)
return true;
else
return false;
}
// We check if there is an object with this key on the GCS server.
auto metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object);
if (metadata) {
TF_SetStatus(status, TF_OK, "");
if (metadata->name().back() == '/')
return true;
else
return false;
}
// If there is no object with this key on the GCS server, we check if there is
// any object whose key contains that path.
MaybeAppendSlash(&object);
for (auto&& metadata :
gcs_file->gcs_client.ListObjects(bucket, gcs::Prefix(object))) {
if (!metadata) {
TF_SetStatusFromGCSStatus(metadata.status(), status);
return false;
}
TF_SetStatus(status, TF_OK, "");
return true;
}
TF_SetStatus(status, TF_NOT_FOUND,
absl::StrCat("The path ", path, " does not exist.").c_str());
return false;
}
void Stat(const TF_Filesystem* filesystem, const char* path,
TF_FileStatistics* stats, TF_Status* status) {
std::string bucket, object;
ParseGCSPath(path, true, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
if (object.empty()) {
auto bucket_metadata = gcs_file->gcs_client.GetBucketMetadata(bucket);
TF_SetStatusFromGCSStatus(bucket_metadata.status(), status);
if (TF_GetCode(status) == TF_OK) {
stats->is_directory = true;
stats->length = 0;
stats->mtime_nsec = 0;
}
return;
}
if (IsDirectory(filesystem, path, status)) {
stats->is_directory = true;
stats->length = 0;
stats->mtime_nsec = 0;
return TF_SetStatus(status, TF_OK, "");
}
if (TF_GetCode(status) == TF_OK) {
auto metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object);
if (metadata) {
stats->is_directory = false;
stats->length = metadata.value().size();
stats->mtime_nsec = metadata.value()
.time_storage_class_updated()
.time_since_epoch()
.count();
}
TF_SetStatusFromGCSStatus(metadata.status(), status);
}
}
} // namespace tf_gcs_filesystem
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,

@@ -51,7 +51,7 @@ class RamFileBlockCache : public FileBlockCache {
RamFileBlockCache(size_t block_size, size_t max_bytes, uint64_t max_staleness,
BlockFetcher block_fetcher,
std::function<uint64_t()> timer_seconds)
std::function<uint64_t()> timer_seconds = TF_NowSeconds)
: block_size_(block_size),
max_bytes_(max_bytes),
max_staleness_(max_staleness),

@@ -0,0 +1,601 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h"
#include <cstring>
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/cloud/now_seconds_env.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
Status ReadCache(tf_gcs_filesystem::RamFileBlockCache* cache,
const string& filename, size_t offset, size_t n,
std::vector<char>* out) {
out->clear();
out->resize(n, 0);
size_t bytes_transferred = 0;
TF_Status status;
cache->Read(filename, offset, n, out->data(), &bytes_transferred, &status);
EXPECT_LE(bytes_transferred, n);
out->resize(bytes_transferred, n);
return status.status;
}
TEST(RamFileBlockCacheTest, IsCacheEnabled) {
auto fetcher = [](const string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred,
TF_Status* status) {
// Do nothing.
return TF_SetStatus(status, TF_OK, "");
};
tf_gcs_filesystem::RamFileBlockCache cache1(0, 0, 0, fetcher);
tf_gcs_filesystem::RamFileBlockCache cache2(16, 0, 0, fetcher);
tf_gcs_filesystem::RamFileBlockCache cache3(0, 32, 0, fetcher);
tf_gcs_filesystem::RamFileBlockCache cache4(16, 32, 0, fetcher);
EXPECT_FALSE(cache1.IsCacheEnabled());
EXPECT_FALSE(cache2.IsCacheEnabled());
EXPECT_FALSE(cache3.IsCacheEnabled());
EXPECT_TRUE(cache4.IsCacheEnabled());
}
TEST(RamFileBlockCacheTest, ValidateAndUpdateFileSignature) {
int calls = 0;
auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred,
TF_Status* status) {
calls++;
memset(buffer, 'x', n);
*bytes_transferred = n;
return TF_SetStatus(status, TF_OK, "");
};
string filename = "file";
tf_gcs_filesystem::RamFileBlockCache cache(16, 32, 0, fetcher);
std::vector<char> out;
// First read.
EXPECT_TRUE(cache.ValidateAndUpdateFileSignature(filename, 123));
TF_EXPECT_OK(ReadCache(&cache, filename, 0, 16, &out));
EXPECT_EQ(calls, 1);
// Second read. Hit cache.
EXPECT_TRUE(cache.ValidateAndUpdateFileSignature(filename, 123));
TF_EXPECT_OK(ReadCache(&cache, filename, 0, 16, &out));
EXPECT_EQ(calls, 1);
// Third read. File signatures are different.
EXPECT_FALSE(cache.ValidateAndUpdateFileSignature(filename, 321));
TF_EXPECT_OK(ReadCache(&cache, filename, 0, 16, &out));
EXPECT_EQ(calls, 2);
}
TEST(RamFileBlockCacheTest, PassThrough) {
const string want_filename = "foo/bar";
const size_t want_offset = 42;
const size_t want_n = 1024;
int calls = 0;
auto fetcher = [&calls, want_filename, want_offset, want_n](
const string& got_filename, size_t got_offset,
size_t got_n, char* buffer, size_t* bytes_transferred,
TF_Status* status) {
EXPECT_EQ(got_filename, want_filename);
EXPECT_EQ(got_offset, want_offset);
EXPECT_EQ(got_n, want_n);
calls++;
memset(buffer, 'x', got_n);
*bytes_transferred = got_n;
return TF_SetStatus(status, TF_OK, "");
};
// If block_size, max_bytes, or both are zero, or want_n is larger than
// max_bytes, the cache is a pass-through.
tf_gcs_filesystem::RamFileBlockCache cache1(1, 0, 0, fetcher);
tf_gcs_filesystem::RamFileBlockCache cache2(0, 1, 0, fetcher);
tf_gcs_filesystem::RamFileBlockCache cache3(0, 0, 0, fetcher);
tf_gcs_filesystem::RamFileBlockCache cache4(1000, 1000, 0, fetcher);
std::vector<char> out;
TF_EXPECT_OK(ReadCache(&cache1, want_filename, want_offset, want_n, &out));
EXPECT_EQ(calls, 1);
TF_EXPECT_OK(ReadCache(&cache2, want_filename, want_offset, want_n, &out));
EXPECT_EQ(calls, 2);
TF_EXPECT_OK(ReadCache(&cache3, want_filename, want_offset, want_n, &out));
EXPECT_EQ(calls, 3);
TF_EXPECT_OK(ReadCache(&cache4, want_filename, want_offset, want_n, &out));
EXPECT_EQ(calls, 4);
}
TEST(RamFileBlockCacheTest, BlockAlignment) {
// Initialize a 256-byte buffer. This is the file underlying the reads we'll
// do in this test.
const size_t size = 256;
std::vector<char> buf;
for (int i = 0; i < size; i++) {
buf.push_back(i);
}
// The fetcher just fetches slices of the buffer.
auto fetcher = [&buf](const string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred,
TF_Status* status) {
if (offset < buf.size()) {
size_t bytes_to_copy = std::min<size_t>(buf.size() - offset, n);
memcpy(buffer, buf.data() + offset, bytes_to_copy);
*bytes_transferred = bytes_to_copy;
} else {
*bytes_transferred = 0;
}
return TF_SetStatus(status, TF_OK, "");
};
for (size_t block_size = 2; block_size <= 4; block_size++) {
// Make a cache of N-byte block size (1 block) and verify that reads of
// varying offsets and lengths return correct data.
tf_gcs_filesystem::RamFileBlockCache cache(block_size, block_size, 0,
fetcher);
for (size_t offset = 0; offset < 10; offset++) {
for (size_t n = block_size - 2; n <= block_size + 2; n++) {
std::vector<char> got;
TF_EXPECT_OK(ReadCache(&cache, "", offset, n, &got));
// Verify the size of the read.
if (offset + n <= size) {
// Expect a full read.
EXPECT_EQ(got.size(), n) << "block size = " << block_size
<< ", offset = " << offset << ", n = " << n;
} else {
// Expect a partial read.
EXPECT_EQ(got.size(), size - offset)
<< "block size = " << block_size << ", offset = " << offset
<< ", n = " << n;
}
// Verify the contents of the read.
std::vector<char>::const_iterator begin = buf.begin() + offset;
std::vector<char>::const_iterator end =
offset + n > buf.size() ? buf.end() : begin + n;
std::vector<char> want(begin, end);
EXPECT_EQ(got, want) << "block size = " << block_size
<< ", offset = " << offset << ", n = " << n;
}
}
}
}
TEST(RamFileBlockCacheTest, CacheHits) {
const size_t block_size = 16;
std::set<size_t> calls;
auto fetcher = [&calls, block_size](const string& filename, size_t offset,
size_t n, char* buffer,
size_t* bytes_transferred,
TF_Status* status) {
EXPECT_EQ(n, block_size);
EXPECT_EQ(offset % block_size, 0);
EXPECT_EQ(calls.find(offset), calls.end()) << "at offset " << offset;
calls.insert(offset);
memset(buffer, 'x', n);
*bytes_transferred = n;
return TF_SetStatus(status, TF_OK, "");
};
const uint32 block_count = 256;
tf_gcs_filesystem::RamFileBlockCache cache(
block_size, block_count * block_size, 0, fetcher);
std::vector<char> out;
out.resize(block_count, 0);
// The cache has space for `block_count` blocks. The loop with i = 0 should
// fill the cache, and the loop with i = 1 should be all cache hits. The
// fetcher checks that it is called once and only once for each offset (to
// fetch the corresponding block).
for (int i = 0; i < 2; i++) {
for (int j = 0; j < block_count; j++) {
TF_EXPECT_OK(ReadCache(&cache, "", block_size * j, block_size, &out));
}
}
}
TEST(RamFileBlockCacheTest, OutOfRange) {
// Tests reads of a 24-byte file with block size 16.
const size_t block_size = 16;
const size_t file_size = 24;
bool first_block = false;
bool second_block = false;
auto fetcher = [block_size, file_size, &first_block, &second_block](
const string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred,
TF_Status* status) {
EXPECT_EQ(n, block_size);
EXPECT_EQ(offset % block_size, 0);
size_t bytes_to_copy = 0;
if (offset == 0) {
// The first block (16 bytes) of the file.
memset(buffer, 'x', n);
bytes_to_copy = n;
first_block = true;
} else if (offset == block_size) {
// The second block (8 bytes) of the file.
bytes_to_copy = file_size - block_size;
memset(buffer, 'x', bytes_to_copy);
second_block = true;
}
*bytes_transferred = bytes_to_copy;
return TF_SetStatus(status, TF_OK, "");
};
tf_gcs_filesystem::RamFileBlockCache cache(block_size, block_size, 0,
fetcher);
std::vector<char> out;
// Reading the first 16 bytes should be fine.
TF_EXPECT_OK(ReadCache(&cache, "", 0, block_size, &out));
EXPECT_TRUE(first_block);
EXPECT_EQ(out.size(), block_size);
// Reading at offset file_size + 4 will read the second block (since the read
// at file_size + 4 = 28 will be aligned to an offset of 16) but will return
// OutOfRange because the offset is past the end of the 24-byte file.
Status status = ReadCache(&cache, "", file_size + 4, 4, &out);
EXPECT_EQ(status.code(), error::OUT_OF_RANGE);
EXPECT_TRUE(second_block);
// Reading the second full block will return 8 bytes, from a cache hit.
second_block = false;
TF_EXPECT_OK(ReadCache(&cache, "", block_size, block_size, &out));
EXPECT_FALSE(second_block);
EXPECT_EQ(out.size(), file_size - block_size);
}
TEST(RamFileBlockCacheTest, Inconsistent) {
// Tests the detection of interrupted reads leading to partially filled blocks
// where we expected complete blocks.
const size_t block_size = 16;
// This fetcher returns OK but only fills in one byte for any offset.
auto fetcher = [block_size](const string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred,
TF_Status* status) {
EXPECT_EQ(n, block_size);
EXPECT_EQ(offset % block_size, 0);
EXPECT_GE(n, 1);
memset(buffer, 'x', 1);
*bytes_transferred = 1;
return TF_SetStatus(status, TF_OK, "");
};
tf_gcs_filesystem::RamFileBlockCache cache(block_size, 2 * block_size, 0,
fetcher);
std::vector<char> out;
// Read the second block; this should yield an OK status and a single byte.
TF_EXPECT_OK(ReadCache(&cache, "", block_size, block_size, &out));
EXPECT_EQ(out.size(), 1);
// Now read the first block; this should yield an INTERNAL error because we
// had already cached a partial block at a later position.
Status status = ReadCache(&cache, "", 0, block_size, &out);
EXPECT_EQ(status.code(), error::INTERNAL);
}
TEST(RamFileBlockCacheTest, LRU) {
const size_t block_size = 16;
std::list<size_t> calls;
auto fetcher = [&calls, block_size](const string& filename, size_t offset,
size_t n, char* buffer,
size_t* bytes_transferred,
TF_Status* status) {
EXPECT_EQ(n, block_size);
EXPECT_FALSE(calls.empty()) << "at offset = " << offset;
if (!calls.empty()) {
EXPECT_EQ(offset, calls.front());
calls.pop_front();
}
memset(buffer, 'x', n);
*bytes_transferred = n;
return TF_SetStatus(status, TF_OK, "");
};
const uint32 block_count = 2;
tf_gcs_filesystem::RamFileBlockCache cache(
block_size, block_count * block_size, 0, fetcher);
std::vector<char> out;
// Read blocks from the cache, and verify the LRU behavior based on the
// fetcher calls that the cache makes.
calls.push_back(0);
// Cache miss - drains an element from `calls`.
TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out));
// Cache hit - does not drain an element from `calls`.
TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out));
calls.push_back(block_size);
// Cache miss followed by cache hit.
TF_EXPECT_OK(ReadCache(&cache, "", block_size, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "", block_size, 1, &out));
calls.push_back(2 * block_size);
// Cache miss followed by cache hit. Causes eviction of LRU element.
TF_EXPECT_OK(ReadCache(&cache, "", 2 * block_size, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "", 2 * block_size, 1, &out));
// LRU element was at offset 0. Cache miss.
calls.push_back(0);
TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out));
// Element at 2 * block_size is still in cache, and this read should update
// its position in the LRU list so it doesn't get evicted by the next read.
TF_EXPECT_OK(ReadCache(&cache, "", 2 * block_size, 1, &out));
// Element at block_size was evicted. Reading this element will also cause
// the LRU element (at 0) to be evicted.
calls.push_back(block_size);
TF_EXPECT_OK(ReadCache(&cache, "", block_size, 1, &out));
// Element at 0 was evicted again.
calls.push_back(0);
TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out));
}
TEST(RamFileBlockCacheTest, MaxStaleness) {
int calls = 0;
auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred,
TF_Status* status) {
calls++;
memset(buffer, 'x', n);
*bytes_transferred = n;
return TF_SetStatus(status, TF_OK, "");
};
std::vector<char> out;
std::unique_ptr<NowSecondsEnv> env(new NowSecondsEnv);
// Create a cache with max staleness of 2 seconds, and verify that it works as
// expected.
tf_gcs_filesystem::RamFileBlockCache cache1(
8, 16, 2 /* max staleness */, fetcher,
[&env]() { return env->NowSeconds(); });
// Execute the first read to load the block.
TF_EXPECT_OK(ReadCache(&cache1, "", 0, 1, &out));
EXPECT_EQ(calls, 1);
// Now advance the clock one second at a time and redo the read. The call
// count should advance every 3 seconds (i.e. every time the staleness is
// greater than 2).
for (int i = 1; i <= 10; i++) {
env->SetNowSeconds(i + 1);
TF_EXPECT_OK(ReadCache(&cache1, "", 0, 1, &out));
EXPECT_EQ(calls, 1 + i / 3);
}
// Now create a cache with max staleness of 0, and verify that it also works
// as expected.
calls = 0;
env->SetNowSeconds(0);
tf_gcs_filesystem::RamFileBlockCache cache2(
8, 16, 0 /* max staleness */, fetcher,
[&env]() { return env->NowSeconds(); });
// Execute the first read to load the block.
TF_EXPECT_OK(ReadCache(&cache2, "", 0, 1, &out));
EXPECT_EQ(calls, 1);
// Advance the clock by a huge amount and verify that the cached block is
// used to satisfy the read.
env->SetNowSeconds(365 * 24 * 60 * 60); // ~1 year, just for fun.
TF_EXPECT_OK(ReadCache(&cache2, "", 0, 1, &out));
EXPECT_EQ(calls, 1);
}
TEST(RamFileBlockCacheTest, RemoveFile) {
int calls = 0;
auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred,
TF_Status* status) {
calls++;
char c = (filename == "a") ? 'a' : (filename == "b") ? 'b' : 'x';
if (offset > 0) {
// The first block is lower case and all subsequent blocks are upper case.
c = toupper(c);
}
memset(buffer, c, n);
*bytes_transferred = n;
return TF_SetStatus(status, TF_OK, "");
};
// This cache has space for 4 blocks; we'll read from two files.
const size_t n = 3;
tf_gcs_filesystem::RamFileBlockCache cache(8, 32, 0, fetcher);
std::vector<char> out;
std::vector<char> a(n, 'a');
std::vector<char> b(n, 'b');
std::vector<char> A(n, 'A');
std::vector<char> B(n, 'B');
// Fill the cache.
TF_EXPECT_OK(ReadCache(&cache, "a", 0, n, &out));
EXPECT_EQ(out, a);
EXPECT_EQ(calls, 1);
TF_EXPECT_OK(ReadCache(&cache, "a", 8, n, &out));
EXPECT_EQ(out, A);
EXPECT_EQ(calls, 2);
TF_EXPECT_OK(ReadCache(&cache, "b", 0, n, &out));
EXPECT_EQ(out, b);
EXPECT_EQ(calls, 3);
TF_EXPECT_OK(ReadCache(&cache, "b", 8, n, &out));
EXPECT_EQ(out, B);
EXPECT_EQ(calls, 4);
// All four blocks should be in the cache now.
TF_EXPECT_OK(ReadCache(&cache, "a", 0, n, &out));
EXPECT_EQ(out, a);
TF_EXPECT_OK(ReadCache(&cache, "a", 8, n, &out));
EXPECT_EQ(out, A);
TF_EXPECT_OK(ReadCache(&cache, "b", 0, n, &out));
EXPECT_EQ(out, b);
TF_EXPECT_OK(ReadCache(&cache, "b", 8, n, &out));
EXPECT_EQ(out, B);
EXPECT_EQ(calls, 4);
// Remove the blocks from "a".
cache.RemoveFile("a");
// Both blocks from "b" should still be there.
TF_EXPECT_OK(ReadCache(&cache, "b", 0, n, &out));
EXPECT_EQ(out, b);
TF_EXPECT_OK(ReadCache(&cache, "b", 8, n, &out));
EXPECT_EQ(out, B);
EXPECT_EQ(calls, 4);
// The blocks from "a" should not be there.
TF_EXPECT_OK(ReadCache(&cache, "a", 0, n, &out));
EXPECT_EQ(out, a);
EXPECT_EQ(calls, 5);
TF_EXPECT_OK(ReadCache(&cache, "a", 8, n, &out));
EXPECT_EQ(out, A);
EXPECT_EQ(calls, 6);
}
TEST(RamFileBlockCacheTest, Prune) {
int calls = 0;
auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred,
TF_Status* status) {
calls++;
memset(buffer, 'x', n);
*bytes_transferred = n;
return TF_SetStatus(status, TF_OK, "");
};
std::vector<char> out;
// Our fake environment is initialized with the current timestamp.
std::unique_ptr<NowSecondsEnv> env(new NowSecondsEnv);
uint64 now = Env::Default()->NowSeconds();
env->SetNowSeconds(now);
tf_gcs_filesystem::RamFileBlockCache cache(
8, 32, 1 /* max staleness */, fetcher,
[&env]() { return env->NowSeconds(); });
// Read three blocks into the cache, and advance the timestamp by one second
// with each read. Start with a block of "a" at the current timestamp `now`.
TF_EXPECT_OK(ReadCache(&cache, "a", 0, 1, &out));
// Now load a block of a different file "b" at timestamp `now` + 1
env->SetNowSeconds(now + 1);
TF_EXPECT_OK(ReadCache(&cache, "b", 0, 1, &out));
// Now load a different block of file "a" at timestamp `now` + 1. When the
// first block of "a" expires, this block should also be removed because it
// also belongs to file "a".
TF_EXPECT_OK(ReadCache(&cache, "a", 8, 1, &out));
// Ensure that all blocks are in the cache (i.e. reads are cache hits).
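  // Three 8-byte blocks are resident ("a" at 0, "b" at 0, "a" at 8), hence 24.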
EXPECT_EQ(cache.CacheSize(), 24);
EXPECT_EQ(calls, 3);
TF_EXPECT_OK(ReadCache(&cache, "a", 0, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "b", 0, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "a", 8, 1, &out));
EXPECT_EQ(calls, 3);
// Advance the fake timestamp so that "a" becomes stale via its first block.
env->SetNowSeconds(now + 2);
// The pruning thread periodically compares env->NowSeconds() with the oldest
// block's timestamp to see if it should evict any files. At the current fake
// timestamp of `now` + 2, file "a" is stale because its first block is stale,
// but file "b" is not stale yet. Thus, once the pruning thread wakes up (in
// one second of wall time), it should remove "a" and leave "b" alone.
uint64 start = Env::Default()->NowSeconds();
do {
Env::Default()->SleepForMicroseconds(100000);
} while (cache.CacheSize() == 24 && Env::Default()->NowSeconds() - start < 3);
// There should be one block left in the cache, and it should be the first
// block of "b".
EXPECT_EQ(cache.CacheSize(), 8);
TF_EXPECT_OK(ReadCache(&cache, "b", 0, 1, &out));
EXPECT_EQ(calls, 3);
// Advance the fake time to `now` + 3, at which point "b" becomes stale.
env->SetNowSeconds(now + 3);
// Wait for the pruner to remove "b".
start = Env::Default()->NowSeconds();
do {
Env::Default()->SleepForMicroseconds(100000);
} while (cache.CacheSize() == 8 && Env::Default()->NowSeconds() - start < 3);
// The cache should now be empty.
EXPECT_EQ(cache.CacheSize(), 0);
}
TEST(RamFileBlockCacheTest, ParallelReads) {
// This fetcher won't respond until either `callers` threads are calling it
// concurrently (at which point it will respond with success to all callers),
// or 10 seconds have elapsed (at which point it will respond with an error).
const int callers = 4;
BlockingCounter counter(callers);
auto fetcher = [&counter](const string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred,
TF_Status* status) {
counter.DecrementCount();
if (!counter.WaitFor(std::chrono::seconds(10))) {
// This avoids having the test time out, which is harder to debug.
return TF_SetStatus(status, TF_FAILED_PRECONDITION,
"desired concurrency not reached");
}
memset(buffer, 'x', n);
*bytes_transferred = n;
return TF_SetStatus(status, TF_OK, "");
};
const int block_size = 8;
tf_gcs_filesystem::RamFileBlockCache cache(
block_size, 2 * callers * block_size, 0, fetcher);
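  // The cache can hold 2 * callers blocks, so no thread's block is evicted
  // while the other threads are still fetching theirs.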
std::vector<std::unique_ptr<Thread>> threads;
threads.reserve(callers);
for (int i = 0; i < callers; i++) {
threads.emplace_back(
Env::Default()->StartThread({}, "caller", [block_size, &cache, i]() {
std::vector<char> out;
TF_EXPECT_OK(
ReadCache(&cache, "a", i * block_size, block_size, &out));
std::vector<char> x(block_size, 'x');
EXPECT_EQ(out, x);
}));
}
  // The `threads` destructor blocks until the threads can be joined, which
  // happens once their respective reads finish (i.e. once they are all
  // executing concurrently, or 10 seconds have passed).
}
TEST(RamFileBlockCacheTest, CoalesceConcurrentReads) {
// Concurrent reads to the same file blocks should be de-duplicated.
const size_t block_size = 16;
int num_requests = 0;
Notification notification;
auto fetcher = [&num_requests, &notification, block_size](
const string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred,
TF_Status* status) {
EXPECT_EQ(n, block_size);
EXPECT_EQ(offset, 0);
num_requests++;
memset(buffer, 'x', n);
*bytes_transferred = n;
notification.Notify();
// Wait for other thread to issue read.
Env::Default()->SleepForMicroseconds(100000); // 0.1 secs
return TF_SetStatus(status, TF_OK, "");
};
tf_gcs_filesystem::RamFileBlockCache cache(block_size, block_size, 0,
fetcher);
// Fork off thread for parallel read.
std::unique_ptr<Thread> concurrent(
Env::Default()->StartThread({}, "concurrent", [block_size, &cache] {
std::vector<char> out;
TF_EXPECT_OK(ReadCache(&cache, "", 0, block_size / 2, &out));
EXPECT_EQ(out.size(), block_size / 2);
}));
notification.WaitForNotification();
std::vector<char> out;
TF_EXPECT_OK(ReadCache(&cache, "", block_size / 2, block_size / 2, &out));
EXPECT_EQ(out.size(), block_size / 2);
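  // Both reads fall within the single 16-byte block at offset 0, so the
  // in-flight fetch is shared and the fetcher runs exactly once.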
EXPECT_EQ(1, num_requests);
}
TEST(RamFileBlockCacheTest, Flush) {
int calls = 0;
auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred,
TF_Status* status) {
calls++;
memset(buffer, 'x', n);
*bytes_transferred = n;
return TF_SetStatus(status, TF_OK, "");
};
tf_gcs_filesystem::RamFileBlockCache cache(16, 32, 0, fetcher);
std::vector<char> out;
TF_EXPECT_OK(ReadCache(&cache, "", 0, 16, &out));
TF_EXPECT_OK(ReadCache(&cache, "", 0, 16, &out));
EXPECT_EQ(calls, 1);
cache.Flush();
TF_EXPECT_OK(ReadCache(&cache, "", 0, 16, &out));
EXPECT_EQ(calls, 2);
}
} // namespace
} // namespace tensorflow

View File

@ -97,6 +97,11 @@ void TF_KernelBuilder_HostMemory(TF_KernelBuilder* kernel_builder,
kernel_builder->cc_builder->HostMemory(arg_name);
}
void TF_KernelBuilder_Priority(TF_KernelBuilder* kernel_builder,
int32_t priority_number) {
kernel_builder->cc_builder->Priority(priority_number);
}
namespace tensorflow {
namespace {

View File

@ -107,6 +107,10 @@ TF_CAPI_EXPORT extern void TF_KernelBuilder_TypeConstraint(
TF_CAPI_EXPORT extern void TF_KernelBuilder_HostMemory(
TF_KernelBuilder* kernel_builder, const char* arg_name);
// Specify a priority number for this kernel.
TF_CAPI_EXPORT extern void TF_KernelBuilder_Priority(
TF_KernelBuilder* kernel_builder, int32_t priority_number);
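//
// Illustrative usage sketch (not part of this change). It assumes the existing
// C kernel registration entry points from tensorflow/c/kernels.h and
// tensorflow/c/tf_status.h, and that a higher priority wins when multiple
// registered kernels match the same op and device; the MyOp / MyKernel* names
// are hypothetical.
//
//   void* MyKernelCreate(TF_OpKernelConstruction* ctx) { return nullptr; }
//   void MyKernelCompute(void* kernel, TF_OpKernelContext* ctx) {}
//   void MyKernelDelete(void* kernel) {}
//
//   void RegisterMyKernel() {
//     TF_KernelBuilder* builder = TF_NewKernelBuilder(
//         "MyOp", "CPU", &MyKernelCreate, &MyKernelCompute, &MyKernelDelete);
//     TF_KernelBuilder_Priority(builder, 1);
//     TF_Status* status = TF_NewStatus();
//     TF_RegisterKernelBuilder("MyOpKernel", builder, status);
//     if (TF_GetCode(status) != TF_OK) {
//       // Handle the registration failure reported via `status`.
//     }
//     TF_DeleteStatus(status);
//   }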
// Register the given kernel builder with the TensorFlow runtime. If
// registration fails, the given status will be populated.
//

View File

@ -359,6 +359,7 @@ cc_library(
":map_lmhlo_to_scalar_op",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:Pass",

View File

@ -16,6 +16,7 @@ limitations under the License.
// This file implements logic for lowering HLO/LHLO dialect to Linalg dialect.
#include "absl/memory/memory.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project
#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
@ -692,7 +693,8 @@ class ConstConverter : public OpConversionPattern<lmhlo::ConstOp> {
if (valueAttr.getType().getRank() != 0) return failure();
auto stdConstOp =
rewriter.create<mlir::ConstantOp>(loc, valueAttr.getValue({}));
rewriter.create<mlir::StoreOp>(loc, stdConstOp, constOp.getOperand());
rewriter.create<mlir::AffineStoreOp>(loc, stdConstOp, constOp.getOperand(),
ValueRange());
rewriter.eraseOp(constOp);
return success();
}
@ -827,7 +829,8 @@ struct LhloLegalizeToLinalg
void runOnFunction() override {
OwningRewritePatternList patterns;
ConversionTarget target(getContext());
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
AffineDialect>();
auto func = getFunction();
populateLHLOToLinalgConversionPattern(func.getContext(), &patterns);

View File

@ -329,7 +329,7 @@ func @constant(%value: memref<i32>) {
return
}
// CHECK: %[[CONSTANT:.*]] = constant 10 : i32
// CHECK: store %[[CONSTANT]], %{{.*}}[] : memref<i32>
// CHECK: affine.store %[[CONSTANT]], %{{.*}}[] : memref<i32>
// -----

View File

@ -152,6 +152,14 @@ StatusOr<QuantizedType> GetQuantizedType(const TensorT& tensor, Builder builder,
uint32_t flags =
is_signed ? mlir::quant::QuantizationFlags::FlagValue::Signed : 0;
// Rejects if quantized tensors have zero scales.
for (float scale : quant_params.scale) {
if (scale == 0) {
return errors::InvalidArgument(
"Quantized tensors must have non-zero scales");
}
}
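  // (A zero scale would collapse affine dequantization,
  // real_value = scale * (quantized_value - zero_point), to zero for every
  // element, so such tensors cannot be represented meaningfully.)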
  // The scale size can't be zero here, as it was checked earlier.
if (quant_params.scale.size() != 1) {
llvm::SmallVector<double, 4> scales(quant_params.scale.begin(),

View File

@ -29,7 +29,7 @@ limitations under the License.
#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace mlir {

View File

@ -1676,12 +1676,7 @@ def TFL_HardSwishOp: TFL_Op<"hard_swish", [
}
def TFL_L2NormalizationOp : TFL_Op<"l2_normalization", [NoSideEffect,
FixedOutputRangeInterface,
// central_value = min_value / 2 + (max_value - 1) / 2 + 1
// zero_point = central_value
// scale = 1. / (central_value - min_value)
FixedResultScale<Int8UniformQuantizedType<0, 78125, -7>>,
FixedResultScale<UInt8UniformQuantizedType<128, 78125, -7>>]> {
FixedOutputRangeInterface]> {
let summary = "L2 Normalize Operator";
let description = [{
@ -1703,29 +1698,12 @@ def TFL_L2NormalizationOp : TFL_Op<"l2_normalization", [NoSideEffect,
// FixedOutputRangeInterface:
quant::UniformQuantizedType GetFixedOutputRange(
bool is_signed, int bit_width) {
auto result_type = output().getType().cast<ShapedType>();
if (!result_type.getElementType().isa<FloatType>()) return {};
Builder builder(result_type.getContext());
// Only support 8-bits
if (bit_width != 8) return {};
IntegerType storage_type = builder.getIntegerType(bit_width);
double scale = 1.0 / 128;
int64_t zero_point, storage_min, storage_max;
if (is_signed) {
zero_point = 0;
storage_min = -128;
storage_max = 127;
} else {
zero_point = 128;
storage_min = 0;
storage_max = 255;
}
return quant::UniformQuantizedType::getChecked(
is_signed, storage_type, result_type.getElementType(), scale,
zero_point, storage_min, storage_max, builder.getUnknownLoc());
auto result_type = output().getType();
// central_value = min_value / 2 + (max_value - 1) / 2 + 1
// zero_point = central_value
// scale = 1. / (central_value - min_value)
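      // For int8 (min_value = -128, max_value = 127) this gives
      // central_value = -64 + 63 + 1 = 0, so zero_point = 0 and
      // scale = 1 / (0 - (-128)) = 1/128, matching the values passed below.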
return quant::GetFixedOutputRange(is_signed, bit_width, result_type,
/*scale=*/1.0 / 128, /*zero_point=*/0);
}
}];
}
@ -1834,10 +1812,6 @@ def TFL_LogisticOp: TFL_Op<"logistic", [
PredOpTrait<"x and y must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
SameOperandsAndResultShape,
// zero_point = 0
// scale = 1. / (max_value + 1)
FixedResultScale<Int8UniformQuantizedType<-128, 390625, -8>>,
FixedResultScale<UInt8UniformQuantizedType<0, 390625, -8>>,
FixedOutputRangeInterface,
TFL_GpuTargetOp]> {
let summary = "Logistic operator";
@ -1854,29 +1828,11 @@ def TFL_LogisticOp: TFL_Op<"logistic", [
// FixedOutputRangeInterface:
quant::UniformQuantizedType GetFixedOutputRange(
bool is_signed, int bit_width) {
auto result_type = y().getType().cast<ShapedType>();
if (!result_type.getElementType().isa<FloatType>()) return {};
Builder builder(result_type.getContext());
// Only support 8-bits
if (bit_width != 8) return {};
IntegerType storage_type = builder.getIntegerType(bit_width);
double scale = 1.0 / 256;
int64_t zero_point, storage_min, storage_max;
if (is_signed) {
zero_point = -128;
storage_min = -128;
storage_max = 127;
} else {
zero_point = 0;
storage_min = 0;
storage_max = 255;
}
return quant::UniformQuantizedType::getChecked(
is_signed, storage_type, result_type.getElementType(), scale,
zero_point, storage_min, storage_max, builder.getUnknownLoc());
auto result_type = y().getType();
// zero_point = 0
// scale = 1. / (max_value + 1)
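      // (These are the uint8 values: zero_point 0, scale 1/256. The call below
      // passes the int8 equivalent, zero_point -128; GetFixedOutputRange
      // shifts it back by +128 when is_signed is false.)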
return quant::GetFixedOutputRange(is_signed, bit_width, result_type,
/*scale=*/1.0 / 256, /*zero_point=*/-128);
}
}];
}
@ -1905,10 +1861,7 @@ def TFL_LogSoftmaxOp : TFL_Op<"log_softmax", [
SameOperandsAndResultShape,
PredOpTrait<"x and y must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
// zero_point = max_value
// scale = -log_softmax_output_min / (max_value + 1)
FixedResultScale<Int8UniformQuantizedType<127, 625, -4>>,
FixedResultScale<UInt8UniformQuantizedType<255, 625, -4>>]> {
FixedOutputRangeInterface]> {
let summary = "Log softmax operator";
let description = [{
@ -1922,6 +1875,18 @@ def TFL_LogSoftmaxOp : TFL_Op<"log_softmax", [
let results = (outs TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$output);
let hasOptions = 1;
let extraClassDeclaration = [{
// FixedOutputRangeInterface:
quant::UniformQuantizedType GetFixedOutputRange(
bool is_signed, int bit_width) {
auto result_type = output().getType();
// zero_point = max_value
// scale = -log_softmax_output_min / (max_value + 1)
return quant::GetFixedOutputRange(is_signed, bit_width, result_type,
/*scale=*/16.0 / 256, /*zero_point=*/127);
}
}];
}
// TODO(ashwinm): Revisit the granularity of the PredOpTraits. We could
@ -2833,10 +2798,7 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [
TFL_TCresVTEtIsSameAsOp<0, 0>>,
TFL_OperandHasRankRange<0, 1, 4>,
SameOperandsAndResultShape,
// zero_point = 0
// scale = 1. / (max_value + 1)
FixedResultScale<Int8UniformQuantizedType<-128, 390625, -8>>,
FixedResultScale<UInt8UniformQuantizedType<0, 390625, -8>>,
FixedOutputRangeInterface,
TFL_GpuTargetOp]> {
let summary = "Softmax operator";
@ -2854,6 +2816,18 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [
let results = (outs TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$output);
let hasOptions = 1;
let extraClassDeclaration = [{
// FixedOutputRangeInterface:
quant::UniformQuantizedType GetFixedOutputRange(
bool is_signed, int bit_width) {
auto result_type = output().getType();
// zero_point = 0
// scale = 1. / (max_value + 1)
return quant::GetFixedOutputRange(is_signed, bit_width, result_type,
/*scale=*/1.0 / 256, /*zero_point=*/-128);
}
}];
}
def TFL_SqrtOp: TFL_Op<"sqrt", [
@ -2959,11 +2933,7 @@ def TFL_TanhOp: TFL_Op<"tanh", [
SameOperandsAndResultShape,
PredOpTrait<"input and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
// central_value = min_value / 2 + (max_value - 1) / 2 + 1
// zero_point = central_value
// scale = 1. / (central_value - min_value)
FixedResultScale<Int8UniformQuantizedType<0, 78125, -7>>,
FixedResultScale<UInt8UniformQuantizedType<128, 78125, -7>>,
FixedOutputRangeInterface,
TFL_GpuTargetOp]> {
let summary = "Hyperbolic tangent operator";
@ -2985,6 +2955,19 @@ def TFL_TanhOp: TFL_Op<"tanh", [
state.addTypes(input.getType());
}]>
];
let extraClassDeclaration = [{
// FixedOutputRangeInterface:
quant::UniformQuantizedType GetFixedOutputRange(
bool is_signed, int bit_width) {
auto result_type = output().getType();
// central_value = min_value / 2 + (max_value - 1) / 2 + 1
// zero_point = central_value
// scale = 1. / (central_value - min_value)
return quant::GetFixedOutputRange(is_signed, bit_width, result_type,
/*scale=*/1.0 / 128, /*zero_point=*/0);
}
}];
}
def TFL_TileOp: TFL_Op<"tile", [

View File

@ -794,16 +794,18 @@ bool QuantizationDriver::PropagateParams() {
}
// TODO(fengliuai): make the bit width configurable.
auto spec = GetQuantSpec(op);
auto key = std::make_pair(8, is_signed_);
auto &restricted_outputs = spec->restricted_output_params[key];
for (int i = 0, e = restricted_outputs.size(); i != e; ++i) {
// The restrict can be nullptr if the result has been quantized.
if (auto params = restricted_outputs[i]) {
changed |= SetResultParams(op, i, params);
if (auto restricted = llvm::dyn_cast<FixedOutputRangeInterface>(op)) {
    // TODO(fengliuai): different results can have different fixed ranges.
auto params = restricted.GetFixedOutputRange(is_signed_, /*bit_width=*/8);
for (auto i = 0; i < op->getNumResults(); ++i) {
// The range is null if the result has been quantized.
if (params) {
changed |= SetResultParams(op, i, params);
}
}
}
auto spec = GetQuantSpec(op);
for (auto &it : spec->biases_params) {
auto params =
GetBiasParams(op, it.first, it.second.first, it.second.second);

View File

@ -449,7 +449,7 @@ static bool PreferResultScale(Operation* op) {
// only considers the ops with a fixed output range.
static bool IsStatsRedundant(Operation* op,
OpQuantSpecGetter op_quant_spec_getter) {
return !op_quant_spec_getter(op)->restricted_output_params.empty();
return llvm::isa<FixedOutputRangeInterface>(op);
}
bool RemoveRedundantStatsOps(mlir::FuncOp func,
@ -469,7 +469,7 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
// Step 1: forward pass: propagate any value scales which are not produced
// by `SameOperandsAndResultsScale`. Additionally, remove the value scales
// which are produced by the `restricted_output_params`.
// which are produced by the ops with the `FixedOutputRangeInterface`.
// Note that we don't propagate across the multiple-operands
// `SameOperandsAndResultsScale` ops like `concatenation`.
func.walk(
@ -594,5 +594,27 @@ LogicalResult VerifySameScales(Operation* op) {
}
return success();
}
quant::UniformQuantizedType GetFixedOutputRange(bool is_signed, int bit_width,
Type tensor_type, double scale,
int64_t zero_point,
int64_t storage_min,
int64_t storage_max) {
auto result_type = tensor_type.cast<ShapedType>();
if (!result_type.getElementType().isa<FloatType>()) return {};
Builder builder(result_type.getContext());
// Only support 8-bits
if (bit_width != 8) return {};
IntegerType storage_type = builder.getIntegerType(bit_width);
if (!is_signed) {
zero_point += 128;
storage_min += 128;
storage_max += 128;
}
return quant::UniformQuantizedType::getChecked(
is_signed, storage_type, result_type.getElementType(), scale, zero_point,
storage_min, storage_max, builder.getUnknownLoc());
}
} // namespace quant
} // namespace mlir

View File

@ -395,8 +395,6 @@ struct FoldTrivalRequantizeOp : public OpRewritePattern<RQ> {
llvm::SmallVector<Type, 4> new_output_types;
for (auto result : def->getResults()) {
result.getUsers().begin()->dump();
op.dump();
if (result.hasOneUse() && *result.getUsers().begin() == op) {
new_output_types.push_back(op.qtype());
} else {
@ -502,6 +500,13 @@ void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed,
bool RemoveRedundantStatsOps(mlir::FuncOp func,
OpQuantSpecGetter op_quant_spec_getter);
// Given quantization parameters for int8, compute the corresponding parameters
// for uint8 if required, and wrap the result in a UniformQuantizedType.
quant::UniformQuantizedType GetFixedOutputRange(bool is_signed, int bit_width,
Type tensor_type, double scale,
int64_t zero_point,
int64_t storage_min = -128,
int64_t storage_max = 127);
} // namespace quant
} // namespace mlir

View File

@ -1,4 +1,4 @@
// RUN: tf-opt -tfl-prepare-composite-funcs-tf %s -split-input-file | FileCheck %s
// RUN: tf-opt -tfl-prepare-composite-funcs-tf %s -split-input-file -verify-diagnostics | FileCheck %s
module{
func @embedding(%arg0: tensor<*xf32>, %arg1: tensor<*xi32>) -> tensor<*xf32> attributes {tf._implements = "embedding_matmul", tf._reference = "mlir"} {
@ -453,3 +453,31 @@ func @inference_standard_lstm_time_major_cannot_fuse(%arg0: tensor<?x8x8xf32>, %
// CHECK: return [[VAL_11]], [[VAL_10]], [[VAL_11]], [[VAL_11]], [[VAL_12]] : tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
// CHECK: }
}
// -----
module {
func @nms_padded(%arg0: tensor<100x4xf32>, %arg1: tensor<100xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<i1>, %arg6: tensor<i1>, %arg7: tensor<i1>, %arg8: tensor<i32>) -> (tensor<1x10xi32>, tensor<i32>) attributes {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"} {
%0 = "tf.Const"() {value = dense<1> : tensor<1x10xi32>} : () -> tensor<1x10xi32>
%1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<i32>
return %0, %1 : tensor<1x10xi32>, tensor<i32>
}
// CHECK: func @nms_padded(%[[VAL_119:.*]]: tensor<100x4xf32>, %[[VAL_120:.*]]: tensor<100xf32>, %[[VAL_121:.*]]: tensor<i32>, %[[VAL_122:.*]]: tensor<f32>, %[[VAL_123:.*]]: tensor<f32>, %[[VAL_124:.*]]: tensor<i1>, %[[VAL_125:.*]]: tensor<i1>, %[[VAL_126:.*]]: tensor<i1>, %[[VAL_127:.*]]: tensor<i32>) -> (tensor<1x10xi32>, tensor<i32>) attributes {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"} {
// CHECK: %[[VAL_128:.*]], %[[VAL_129:.*]] = "tfl.non_max_suppression_v4"(%[[VAL_119]], %[[VAL_120]], %[[VAL_121]], %[[VAL_122]], %[[VAL_123]]) : (tensor<100x4xf32>, tensor<100xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<1x10xi32>, tensor<i32>)
// CHECK: return %[[VAL_128]], %[[VAL_129]] : tensor<1x10xi32>, tensor<i32>
// CHECK: }
}
// -----
module {
// expected-error @+1 {{Invalid number of results from non_max_suppression_padded_v2}}
func @nms_padded_invalid_num_results(%arg0: tensor<100x4xf32>, %arg1: tensor<100xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<i1>, %arg6: tensor<i1>, %arg7: tensor<i1>, %arg8: tensor<i32>) -> () attributes {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"}
// expected-error @+1 {{Invalid number of arguments to non_max_suppression_padded_v2}}
func @nms_padded_invalid_num_args(%arg0: tensor<100x4xf32>, %arg1: tensor<100xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>) -> (tensor<1x10xi32>, tensor<i32>) attributes {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"}
// expected-error @+1 {{TFLite does not support batched input for non_max_suppression_padded}}
func @nms_padded_with_batches(%arg0: tensor<2x100x4xf32>, %arg1: tensor<2x100xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<i1>, %arg6: tensor<i1>, %arg7: tensor<i1>, %arg8: tensor<i32>) -> (tensor<2x10xi32>, tensor<i32>) attributes {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"}
}

View File

@ -188,20 +188,16 @@ StatusOr<mlir::OwningModuleRef> ImportSavedModel(
const std::unordered_set<std::string>& tags,
absl::Span<std::string> exported_names, mlir::MLIRContext* context) {
if (saved_model_version == 2) {
auto module = tensorflow::SavedModelObjectGraphToMlirImport(
auto module_or = tensorflow::SavedModelObjectGraphToMlirImport(
input_filename, tags, exported_names, context);
if (!module)
return tensorflow::errors::InvalidArgument("fail to open input file");
return module;
if (!module_or.status().ok()) return module_or.status();
return module_or.ConsumeValueOrDie();
} else if (saved_model_version == 1) {
auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport(
input_filename, tags, exported_names, context);
if (!module)
return tensorflow::errors::InvalidArgument("fail to open input file");
return module;
if (!module_or.status().ok()) return module_or.status();
return module_or.ConsumeValueOrDie();
} else {
return tensorflow::errors::InvalidArgument(
"Should be either saved model v1 or v2");

View File

@ -57,6 +57,7 @@ namespace {
constexpr char kTFAPIImplements[] = "tf.api_implements";
constexpr char kTfTextAPIPRefix[] = "tftext:";
constexpr char kTfNMSPadded[] = "non_max_suppression_padded_v2";
// Abstracts the conversion of the embedded lookup composite function.
class ConvertEmbeddedLookupFunc {
@ -94,6 +95,59 @@ class ConvertEmbeddedLookupFunc {
FuncOp func_;
};
// Abstracts the conversion of the padded NMS composite function.
class ConvertNMSPaddedFunc {
public:
explicit ConvertNMSPaddedFunc(FuncOp func) : func_(func) {}
void RewriteFunc() {
func_.setAttr(kTFImplements,
StringAttr::get(kTfNMSPadded, func_.getContext()));
Value boxes = func_.getArgument(0);
Value scores = func_.getArgument(1);
Value max_output_size = func_.getArgument(2);
Value iou_threshold = func_.getArgument(3);
Value score_threshold = func_.getArgument(4);
auto output_type0 = func_.getType().getResult(0);
auto output_type1 = func_.getType().getResult(1);
OpBuilder builder(func_.getBody());
auto op = builder.create<mlir::TFL::NonMaxSuppressionV4Op>(
func_.getLoc(), output_type0, output_type1, boxes, scores,
max_output_size, iou_threshold, score_threshold);
builder.create<mlir::ReturnOp>(func_.getLoc(), op.getResults());
}
LogicalResult VerifySignature() {
// Verify high-level function signature.
// Relevant argument characteristics are checked by the TFL op definition.
if (func_.getNumArguments() < 5) {
return func_.emitError()
<< "Invalid number of arguments to "
"non_max_suppression_padded_v2 (need atleast 5): "
<< func_.getNumArguments();
}
if (func_.getType().getNumResults() != 2) {
return func_.emitError() << "Invalid number of results from "
"non_max_suppression_padded_v2 (need 2): "
<< func_.getType().getNumResults();
}
// The TFLite fused op does not support batching yet.
// TODO(b/158709815): Add support for batches with padded NMS.
    auto boxes_type =
        func_.getArgument(0).getType().dyn_cast<RankedTensorType>();
    if (!boxes_type || boxes_type.getRank() != 2) {
return func_.emitError() << "TFLite does not support batched input for "
"non_max_suppression_padded";
}
return success();
}
private:
FuncOp func_;
};
// This pass uses mechanisms listed in RFC:
// https://github.com/tensorflow/community/pull/113
// It prepares composite functions that are attributed to indicate
@ -139,6 +193,14 @@ void PrepareCompositeFunctionsPass::ConvertTFImplements(FuncOp func,
if (failed(convert_layer_norm_lstm_cell_simple.RewriteFunc())) {
return signalPassFailure();
}
} else if (attr.getValue() == kTfNMSPadded) {
func.eraseBody();
func.addEntryBlock();
ConvertNMSPaddedFunc convert_nms_padded(func);
if (failed(convert_nms_padded.VerifySignature())) {
return signalPassFailure();
}
convert_nms_padded.RewriteFunc();
}
}

View File

@ -188,12 +188,12 @@ Status MlirV1CompatGraphOptimizationPass::Run(
if (!is_enabled) {
VLOG(0) << "None of the MLIR optimization passes are enabled "
<< "(registered" << registry_->passes().size() << " passes)";
<< "(registered " << registry_->passes().size() << " passes)";
return Status::OK();
}
VLOG(0) << "Running MLIR Graph Optimization V1 Compat Passes "
<< "(registered" << registry_->passes().size() << " passes)";
<< "(registered " << registry_->passes().size() << " passes)";
GraphDebugInfo debug_info;
RegisterDialects();

View File

@ -22,6 +22,7 @@ tf_python_pybind_extension(
"//tensorflow/python:pybind11_status",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:StandardOps",
"@pybind11",
],

View File

@ -15,7 +15,11 @@ limitations under the License.
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
#include "llvm/Support/SourceMgr.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Verifier.h" // from @llvm-project
#include "mlir/Parser.h" // from @llvm-project
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
@ -29,6 +33,21 @@ PYBIND11_MODULE(mlir_wrapper, m) {
mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
mlir::registerDialect<mlir::StandardOpsDialect>();
});
m.def("verify", [](std::string input) {
    llvm::SourceMgr SM;
SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(input),
llvm::SMLoc());
mlir::MLIRContext ctx;
auto module = mlir::parseSourceFile(SM, &ctx);
if (!module) {
return false;
}
if (failed(mlir::verify(*module))) {
module->emitError("Invalid MLIR module: failed verification.");
return false;
}
return true;
});
init_basic_classes(m);
init_types(m);

View File

@ -1194,6 +1194,7 @@ cc_library(
"//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler/utils:transitive_fanin",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",

View File

@ -3540,7 +3540,7 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors.
}];
}
def TF_FusedBatchNormV3Op : TF_Op<"FusedBatchNormV3", [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> {
class TF_FusedBatchNormOpBase<string Name> : TF_Op<Name, [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> {
let summary = "Batch normalization.";
let description = [{
@ -3561,15 +3561,6 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors.
DefaultValuedAttr<BoolAttr, "true">:$is_training
);
let results = (outs
TensorOf<[BF16, F16, F32]>:$y,
F32Tensor:$batch_mean,
F32Tensor:$batch_variance,
F32Tensor:$reserve_space_1,
F32Tensor:$reserve_space_2,
F32Tensor:$reserve_space_3
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>;
@ -3585,6 +3576,27 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors.
}];
}
def TF_FusedBatchNormV2Op : TF_FusedBatchNormOpBase<"FusedBatchNormV2"> {
let results = (outs
TensorOf<[BF16, F16, F32]>:$y,
F32Tensor:$batch_mean,
F32Tensor:$batch_variance,
F32Tensor:$reserve_space_1,
F32Tensor:$reserve_space_2
);
}
def TF_FusedBatchNormV3Op : TF_FusedBatchNormOpBase<"FusedBatchNormV3"> {
let results = (outs
TensorOf<[BF16, F16, F32]>:$y,
F32Tensor:$batch_mean,
F32Tensor:$batch_variance,
F32Tensor:$reserve_space_1,
F32Tensor:$reserve_space_2,
F32Tensor:$reserve_space_3
);
}
def TF_GatherOp : TF_Op<"Gather", [NoSideEffect]> {
let summary = "Gather slices from `params` according to `indices`.";
@ -8266,6 +8278,39 @@ def TF_SelectV2Op : TF_Op<"SelectV2", [NoSideEffect, ResultsBroadcastableShape]>
];
}
def TF_SelfAdjointEigV2Op : TF_Op<"SelfAdjointEigV2", [NoSideEffect]> {
let summary = [{
Computes the eigen decomposition of one or more square self-adjoint matrices.
}];
let description = [{
Computes the eigenvalues and (optionally) eigenvectors of each inner matrix in
`input` such that `input[..., :, :] = v[..., :, :] * diag(e[..., :])`. The eigenvalues
are sorted in non-decreasing order.
```python
# a is a tensor.
# e is a tensor of eigenvalues.
# v is a tensor of eigenvectors.
e, v = self_adjoint_eig(a)
e = self_adjoint_eig(a, compute_v=False)
```
}];
let arguments = (ins
TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$input,
DefaultValuedAttr<BoolAttr, "true">:$compute_v
);
let results = (outs
TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$e,
TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$v
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_SeluOp : TF_Op<"Selu", [NoSideEffect, SameOperandsAndResultType]> {
let summary = [{
Computes scaled exponential linear: `scale * alpha * (exp(features) - 1)`

View File

@ -1977,13 +1977,55 @@ static LogicalResult Verify(FusedBatchNormOp op) {
return success();
}
LogicalResult FusedBatchNormV3Op::FoldOperandsPermutation(
ArrayRef<int64_t> permutation) {
//===----------------------------------------------------------------------===//
// FusedBatchNormV2Op / FusedBatchNormV3Op
//===----------------------------------------------------------------------===//
template <class Op>
static LogicalResult InferenceFoldOperandsPermutation(
ArrayRef<int64_t> permutation, Op *op) {
  // FusedBatchNorm in training mode is a layout sensitive operation, and
  // should already have an optimal data format assigned.
if (is_training()) return failure();
if (op->is_training()) return failure();
return ::mlir::TF::FoldOperandsPermutation(permutation, op);
}
return ::mlir::TF::FoldOperandsPermutation(permutation, this);
template <class Op>
static StringRef GetOptimalLayout(const RuntimeDevices &devices, Op *op) {
// In inference mode FusedBatchNorm is not sensitive to data layout.
if (!op->is_training()) return op->data_format();
// Keep current data format if no GPUs are available or if explicit placement
  // does not allow using a GPU for this operation.
if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(op->getOperation()))
return op->data_format();
  // For the f16 data type on devices with Tensor Core support, the NHWC data
  // format is up to ~2x faster.
auto x_ty = op->x().getType().template cast<TensorType>();
const bool is_f16 = x_ty.getElementType().isF16();
if (is_f16 && CanUseTensorCores(devices)) return "NHWC";
// For all other data types prefer NCHW.
return "NCHW";
}
LogicalResult FusedBatchNormV2Op::FoldOperandsPermutation(
ArrayRef<int64_t> permutation) {
return ::mlir::TF::InferenceFoldOperandsPermutation(permutation, this);
}
LogicalResult FusedBatchNormV2Op::UpdateDataFormat(StringRef data_format) {
return ::mlir::TF::UpdateDataFormat(data_format, this);
}
StringRef FusedBatchNormV2Op::GetOptimalLayout(const RuntimeDevices &devices) {
return ::mlir::TF::GetOptimalLayout(devices, this);
}
LogicalResult FusedBatchNormV3Op::FoldOperandsPermutation(
ArrayRef<int64_t> permutation) {
return ::mlir::TF::InferenceFoldOperandsPermutation(permutation, this);
}
LogicalResult FusedBatchNormV3Op::UpdateDataFormat(StringRef data_format) {
@ -1991,22 +2033,7 @@ LogicalResult FusedBatchNormV3Op::UpdateDataFormat(StringRef data_format) {
}
StringRef FusedBatchNormV3Op::GetOptimalLayout(const RuntimeDevices &devices) {
// In inference mode FusedBatchNorm is not sensitive to data layout.
if (!is_training()) return data_format();
// Keep current data format if no GPUs are available or if explicit placement
// does not allow to use GPU for this operation.
if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
return data_format();
// For f16 data type on devices with Tensor Cores support NHWC data format
// is up to ~2x faster.
auto x_ty = x().getType().cast<TensorType>();
const bool is_f16 = x_ty.getElementType().isF16();
if (is_f16 && CanUseTensorCores(devices)) return "NHWC";
// For all other data types prefer NCHW.
return "NCHW";
return ::mlir::TF::GetOptimalLayout(devices, this);
}
//===----------------------------------------------------------------------===//
@ -2156,10 +2183,6 @@ static LogicalResult Verify(IfRegionOp op) {
return failure();
if (failed(VerifyRegionResults(op, op.else_branch(), "else")))
return failure();
if (op.then_branch().front().getNumArguments() != 0)
return op.emitOpError() << "then region cannot have any arguments";
if (op.else_branch().front().getNumArguments() != 0)
return op.emitOpError() << "else region cannot have any arguments";
return success();
}

View File

@ -247,7 +247,7 @@ def TF_YieldOp : TF_Op<"Yield",
}
def TF_IfRegionOp : TF_Op<"IfRegion",
[SingleBlockImplicitTerminator<"YieldOp">]> {
[SingleBlockImplicitTerminator<"YieldOp">, NoRegionArguments]> {
let summary = "output = cond ? then_branch output : else_branch output";
let description = [{
@ -524,8 +524,8 @@ are sparse matrices.
}];
let arguments = (ins
TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$a,
TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$b,
TensorOf<[BF16, F32]>:$a,
TensorOf<[BF16, F32]>:$b,
DefaultValuedAttr<BoolAttr, "true">:$a_is_sparse,
DefaultValuedAttr<BoolAttr, "false">:$b_is_sparse,
@ -535,7 +535,7 @@ are sparse matrices.
);
let results = (outs
TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$product
TensorOf<[F32]>:$product
);
TF_DerivedOperandTypeAttr Ta = TF_DerivedOperandTypeAttr<0>;

View File

@ -443,7 +443,7 @@ func @DontRemoveTrivialMul(%arg0: tensor<1x6x8x1xf32>) -> tensor<1x6x8x1xf32> {
// CHECK: return %[[RESULT]] : tensor<1x6x8x1xf32>
}
// Do not fold if total result size is large (>128 KB) and more than 2 times
// Do not fold if total result size is large (>256 KB) and more than 2 times
// the size of operands.
// LINT.IfChange(folding-policy-test)

View File

@ -1055,7 +1055,7 @@ func @testIfRegionThenConsumingElse(%arg0: tensor<i1>, %arg1: tensor<2xf32>) ->
// The regions for IfRegion themselves cannot have any arguments
func @testInvalidIfRegionThenArg(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
%neg = "tf.Neg"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
// expected-error @+1 {{then region cannot have any arguments}}
// expected-error @+1 {{'tf.IfRegion' op region #0 should have no arguments}}
%0 = "tf.IfRegion"(%arg0) ({
^bb(%arg_bb: tensor<2xf32>):
%t = "tf.Abs"(%arg_bb) : (tensor<2xf32>) -> tensor<2xf32>
@ -1072,7 +1072,7 @@ func @testInvalidIfRegionThenArg(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> ten
func @testInvalidIfRegionElseArg(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
%neg = "tf.Neg"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
// expected-error @+1 {{else region cannot have any arguments}}
// expected-error @+1 {{'tf.IfRegion' op region #1 should have no arguments}}
%0 = "tf.IfRegion"(%arg0) ({
%t = "tf.Abs"(%neg) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%t) : (tensor<2xf32>) -> ()

View File

@ -40,7 +40,7 @@ namespace TF {
// LINT.IfChange(folding-policy)
static bool ShouldBeFolded(Operation* inst) {
constexpr int kSizeFactor = 2;
constexpr int64_t kSizeThreshold = (1 << 20); // 128 KB
constexpr int64_t kSizeThreshold = (1 << 21); // 256 KB
bool has_unknown_shape = false;
auto get_size = [&](TypeRange types) {
int64_t size = 0;

View File

@ -40,9 +40,6 @@ limitations under the License.
namespace tensorflow {
using stream_executor::port::Status;
using stream_executor::port::StatusOr;
static StatusOr<mlir::OwningModuleRef> GraphdefToMlirImport(
llvm::StringRef input, absl::string_view debug_info_file,
absl::string_view input_arrays, absl::string_view input_dtypes,
@ -98,7 +95,7 @@ static StatusOr<mlir::OwningModuleRef> GraphdefToMlirImport(
context);
}
mlir::OwningModuleRef GraphdefToMlirTranslateFunction(
StatusOr<mlir::OwningModuleRef> GraphdefToMlirTranslateFunction(
llvm::StringRef input, absl::string_view debug_info_file,
absl::string_view input_arrays, absl::string_view input_dtypes,
absl::string_view input_shapes, absl::string_view output_arrays,
@ -112,13 +109,11 @@ mlir::OwningModuleRef GraphdefToMlirTranslateFunction(
enable_shape_inference, context);
if (!module_or.status().ok()) {
LOG(ERROR) << "Graph import failed: " << module_or.status();
return nullptr;
}
return module_or.ConsumeValueOrDie();
return module_or;
}
mlir::OwningModuleRef SavedModelObjectGraphToMlirImport(
StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphToMlirImport(
absl::string_view saved_model_dir,
const std::unordered_set<std::string>& tags,
absl::Span<std::string> exported_names, mlir::MLIRContext* context) {
@ -128,18 +123,17 @@ mlir::OwningModuleRef SavedModelObjectGraphToMlirImport(
if (!load_status.ok()) {
LOG(ERROR) << "Failed to load saved model '" << saved_model_dir
<< "': " << load_status;
return nullptr;
return load_status;
}
auto module_or = ConvertSavedModelToMlir(&bundle, context, exported_names);
if (!module_or.status().ok()) {
LOG(ERROR) << "SavedModel import failed: " << module_or.status();
return nullptr;
}
return module_or.ConsumeValueOrDie();
return module_or;
}
mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefsToMlirImport(
absl::string_view saved_model_dir,
const std::unordered_set<std::string>& tags,
absl::Span<std::string> exported_names, mlir::MLIRContext* context,
@ -154,19 +148,18 @@ mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
if (!load_status.ok()) {
LOG(ERROR) << "Failed to load saved model v1 '" << saved_model_dir
<< "': " << load_status;
return nullptr;
return load_status;
}
auto module_or = ConvertSavedModelV1ToMlir(bundle, exported_names, context,
upgrade_legacy);
if (!module_or.status().ok()) {
LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status();
return nullptr;
}
return module_or.ConsumeValueOrDie();
return module_or;
}
mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
StatusOr<mlir::OwningModuleRef> GraphdefToSplattedMlirTranslateFunction(
llvm::StringRef input, absl::string_view debug_info_file,
absl::string_view input_arrays, absl::string_view input_dtypes,
absl::string_view input_shapes, absl::string_view output_arrays,
@ -180,7 +173,7 @@ mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
enable_shape_inference, context);
if (!module_or.status().ok()) {
LOG(ERROR) << "Graph import failed: " << module_or.status();
return nullptr;
return module_or.status();
}
auto& module = module_or.ValueOrDie();
std::srand(0);
@ -215,7 +208,7 @@ mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
}
}
}
return module_or.ConsumeValueOrDie();
return module_or;
}
} // namespace tensorflow

View File

@ -23,15 +23,20 @@ limitations under the License.
#include "absl/types/span.h"
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow {
using stream_executor::port::Status;
using stream_executor::port::StatusOr;
// TODO(antiagainst): Directly manipulating files in library functions is not
// a good idea. We should pass in a string/stream here.
// Converts a TensorFlow GraphDef stored in the file with the given
// `input_filename` into an MLIR module. Creates MLIR entities in the
// given MLIR `context`.
mlir::OwningModuleRef GraphdefToMlirTranslateFunction(
StatusOr<mlir::OwningModuleRef> GraphdefToMlirTranslateFunction(
llvm::StringRef input, absl::string_view debug_info_file,
absl::string_view input_arrays, absl::string_view input_dtypes,
absl::string_view input_shapes, absl::string_view output_arrays,
@ -42,7 +47,7 @@ mlir::OwningModuleRef GraphdefToMlirTranslateFunction(
// Similar to the above function, but replaces all constant tensors
// with randomly generated splat values.
mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
StatusOr<mlir::OwningModuleRef> GraphdefToSplattedMlirTranslateFunction(
llvm::StringRef input, absl::string_view debug_info_file,
absl::string_view input_arrays, absl::string_view input_dtypes,
absl::string_view input_shapes, absl::string_view output_arrays,
@ -54,7 +59,7 @@ mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
// Converts a TensorFlow SavedModel stored in the directory with the given
// `saved_model_dir` into an MLIR module. Creates MLIR entities in the
// given MLIR `context`.
mlir::OwningModuleRef SavedModelObjectGraphToMlirImport(
StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphToMlirImport(
absl::string_view saved_model_dir,
const std::unordered_set<std::string>& tags,
absl::Span<std::string> exported_names, mlir::MLIRContext* context);
@ -62,7 +67,7 @@ mlir::OwningModuleRef SavedModelObjectGraphToMlirImport(
// Converts a TensorFlow V1 SavedModel stored in the directory with the given
// `saved_model_dir` into an MLIR module. Creates MLIR entities in the
// given MLIR `context`.
mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefsToMlirImport(
absl::string_view saved_model_dir,
const std::unordered_set<std::string>& tags,
absl::Span<std::string> exported_names, mlir::MLIRContext* context,

View File

@ -42,11 +42,13 @@ inline absl::string_view StringRefToView(llvm::StringRef ref) {
static OwningModuleRef GraphdefToMlirTranslateFunction(llvm::StringRef input,
MLIRContext* context) {
return tensorflow::GraphdefToMlirTranslateFunction(
auto module_or = tensorflow::GraphdefToMlirTranslateFunction(
input, debug_info_file, input_arrays, input_dtypes, input_shapes,
output_arrays, control_output_arrays, prune_unused_nodes,
convert_legacy_fed_inputs, graph_as_function, upgrade_legacy,
enable_shape_inference, context);
if (!module_or.status().ok()) return nullptr;
return module_or.ConsumeValueOrDie();
}
static TranslateToMLIRRegistration GraphdefToMlirTranslate(
@ -54,11 +56,13 @@ static TranslateToMLIRRegistration GraphdefToMlirTranslate(
static OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
llvm::StringRef input, MLIRContext* context) {
return tensorflow::GraphdefToSplattedMlirTranslateFunction(
auto module_or = tensorflow::GraphdefToSplattedMlirTranslateFunction(
input, debug_info_file, input_arrays, input_dtypes, input_shapes,
output_arrays, control_output_arrays, prune_unused_nodes,
convert_legacy_fed_inputs, graph_as_function, upgrade_legacy,
enable_shape_inference, context);
if (!module_or.status().ok()) return nullptr;
return module_or.ConsumeValueOrDie();
}
static TranslateToMLIRRegistration GraphdefToSplattedMlirTranslate(

View File

@ -112,19 +112,19 @@ int main(int argc, char** argv) {
if (import_saved_model_object_graph) {
mlir::MLIRContext context;
auto module = tensorflow::SavedModelObjectGraphToMlirImport(
auto module_or = tensorflow::SavedModelObjectGraphToMlirImport(
input_filename, tags, exported_names, &context);
if (!module) return 1;
if (!module_or.status().ok()) return 1;
module->print(output->os());
module_or.ConsumeValueOrDie()->print(output->os());
} else if (import_saved_model_signature_defs) {
mlir::MLIRContext context;
auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport(
input_filename, tags, exported_names, &context);
if (!module) return 1;
if (!module_or.status().ok()) return 1;
module->print(output->os());
module_or.ConsumeValueOrDie()->print(output->os());
} else {
auto input = mlir::openInputFile(input_filename, &error_message);

View File

@ -129,20 +129,18 @@ StatusOr<mlir::OwningModuleRef> ImportSavedModel(
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
absl::Span<std::string> exported_names(exported_names_in_vector);
if (import_saved_model) {
auto module = tensorflow::SavedModelObjectGraphToMlirImport(
auto module_or = tensorflow::SavedModelObjectGraphToMlirImport(
input_filename, tags, absl::Span<std::string>(exported_names), context);
if (!module)
return tensorflow::errors::InvalidArgument("fail to open input file");
if (!module_or.status().ok()) return module_or.status();
TF_RETURN_IF_ERROR(RegisterCustomOps(extra_tf_opdefs));
return module;
return module_or.ConsumeValueOrDie();
} else if (import_saved_model_v1) {
auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport(
input_filename, tags, exported_names, context);
if (!module)
return tensorflow::errors::InvalidArgument("fail to open input file");
if (!module_or.status().ok()) return module_or.status();
TF_RETURN_IF_ERROR(RegisterCustomOps(extra_tf_opdefs));
return module;
return module_or.ConsumeValueOrDie();
} else {
return tensorflow::errors::InvalidArgument(
"Should be either saved model v1 or v2");

View File

@ -26,6 +26,28 @@ func @fusedBatchNorm_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>,
return %0#0 : tensor<8x8x8x8xf32>
}
// fusedBatchNormV2 is almost identical to fusedBatchNormV3 (and uses the same
// code), so only do a couple of basic checks.
// CHECK-LABEL: fusedBatchNormV2_noTraining
func @fusedBatchNormV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
// CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
%0:5 = "tf.FusedBatchNormV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
return %0#0 : tensor<8x8x8x8xf32>
}
// CHECK-LABEL: fusedBatchNormV2_training
func @fusedBatchNormV2_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
// CHECK: %[[RESULT0:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
%0:5 = "tf.FusedBatchNormV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
// CHECK: "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
// CHECK: "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
// CHECK: %[[VAR:.*]] = "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
// CHECK: mhlo.constant
// CHECK: chlo.broadcast_multiply %[[VAR]], {{.*}} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
return %0#0 : tensor<8x8x8x8xf32>
}
// CHECK-LABEL: fusedBatchNormV3_noTraining
func @fusedBatchNormV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
// CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
@ -956,6 +978,41 @@ func @test_sparse_mat_mul(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) -> ten
return %0: tensor<3x5xf32>
}
// SparseMatMul where one operand needs to be transposed and the other one not.
//
// CHECK-LABEL: func @test_sparse_mat_mul_with_transpose
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x4xf32>
// CHECK-SAME: %[[ARG1:.*]]: tensor<5x4xf32>
// CHECK-SAME: -> tensor<3x5xf32>
// CHECK: %[[TRANSPOSE:.*]] = "mhlo.transpose"(%[[ARG1]])
// CHECK-SAME: permutation = dense<[1, 0]>
// CHECK-SAME: -> tensor<4x5xf32>
// CHECK: %[[RESULT:.*]] = "mhlo.dot"(%[[ARG0]], %[[TRANSPOSE]])
// CHECK-SAME: -> tensor<3x5xf32>
// CHECK: return %[[RESULT]]
// CHECK: }
func @test_sparse_mat_mul_with_transpose(%arg0: tensor<3x4xf32>, %arg1: tensor<5x4xf32>) -> tensor<3x5xf32> {
%0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = true} : (tensor<3x4xf32>, tensor<5x4xf32>) -> tensor<3x5xf32>
return %0: tensor<3x5xf32>
}
// SparseMatMul where one operand needs to be casted and the other one not.
//
// CHECK-LABEL: func @test_sparse_mat_mul_with_cast
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x4xf32>
// CHECK-SAME: %[[ARG1:.*]]: tensor<4x5xbf16>
// CHECK-SAME: -> tensor<3x5xf32>
// CHECK: %[[CAST:.*]] = "mhlo.convert"(%[[ARG1]])
// CHECK-SAME: -> tensor<4x5xf32>
// CHECK: %[[RESULT:.*]] = "mhlo.dot"(%[[ARG0]], %[[CAST]])
// CHECK-SAME: -> tensor<3x5xf32>
// CHECK: return %[[RESULT]]
// CHECK: }
func @test_sparse_mat_mul_with_cast(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xbf16>) -> tensor<3x5xf32> {
%0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = false} : (tensor<3x4xf32>, tensor<4x5xbf16>) -> tensor<3x5xf32>
return %0: tensor<3x5xf32>
}
//===----------------------------------------------------------------------===//
// MatrixBandPart op legalizations.
//===----------------------------------------------------------------------===//

View File

@ -1531,23 +1531,23 @@ using ConvertFusedBatchNormGradV3Op =
// Converts a TensorFlow FusedBatchNormV2Op or FusedBatchNormV3Op to either an
// HLO BatchNormTrainingOp or an HLO BatchNormInferenceOp, depending on the
// value of the 'is_training' parameter.
class ConvertFusedBatchNormV3Op
: public OpRewritePattern<TF::FusedBatchNormV3Op> {
template <typename FusedBatchNormOpT>
class ConvertFusedBatchNormBase : public OpRewritePattern<FusedBatchNormOpT> {
public:
using OpRewritePattern::OpRewritePattern;
using OpRewritePattern<FusedBatchNormOpT>::OpRewritePattern;
LogicalResult matchAndRewrite(TF::FusedBatchNormV3Op op,
LogicalResult matchAndRewrite(FusedBatchNormOpT op,
PatternRewriter &rewriter) const override {
auto feature_dim =
getFeatureDimensionAttr(rewriter, op.data_formatAttr(), op.x());
auto input_type_tensor = op.x().getType().cast<TensorType>();
auto input_type_tensor = op.x().getType().template cast<TensorType>();
auto input_element_type = input_type_tensor.getElementType();
auto scale_type_tensor = op.scale().getType().cast<TensorType>();
auto scale_type_tensor = op.scale().getType().template cast<TensorType>();
auto scale_element_type = scale_type_tensor.getElementType();
auto mean_type_tensor = op.mean().getType().cast<TensorType>();
auto mean_type_tensor = op.mean().getType().template cast<TensorType>();
auto mean_element_type = mean_type_tensor.getElementType();
// In the training case, dimensions of input tensors must be static.
if (op.is_training() && (!input_type_tensor.hasStaticShape() ||
@ -1561,7 +1561,7 @@ class ConvertFusedBatchNormV3Op
Value bn_train_input = rewriter.create<mhlo::ConvertOp>(op.getLoc(), op.x(),
scale_element_type);
TensorType bn_train_input_type_tensor =
bn_train_input.getType().cast<TensorType>();
bn_train_input.getType().template cast<TensorType>();
if (op.is_training()) {
// Training case.
@ -1643,17 +1643,25 @@ class ConvertFusedBatchNormV3Op
/*broadcast_dimensions=*/DenseIntElementsAttr());
}
// TF FusedBatchNormV3 op expects 5 outputs. Outputs 3 and 4 are
// currently marked as "reserved spaces 1 and 2". They are used to
      // pass the per-batch mean and variance to the gradient. Here we
// maintain the same behavior by setting them to the mean and variance
// calculated by BatchNormTraining. Output 5 is unused; it doesn't
// matter what we pass there.
rewriter.replaceOp(op, {y_out, /*batch_mean=*/batch_mean,
/*batch_variance=*/corrected_variance,
/*reserve_space_1=*/reserve_space_1,
/*reserve_space_2=*/batch_variance,
/*reserve_space_3=*/op.x()});
if (std::is_same<FusedBatchNormOpT, TF::FusedBatchNormV2Op>::value) {
// FusedBatchNormV2 expects 4 outputs.
// Outputs 3 and 4 are currently marked as "reserved spaces 1 and 2".
// They are used to pass the per-batch mean and variance to the
// gradient. Here we maintain the same behavior by setting them to the
// mean and variance calculated by BatchNormTraining.
rewriter.replaceOp(op, {y_out, /*batch_mean=*/batch_mean,
/*batch_variance=*/corrected_variance,
/*reserve_space_1=*/reserve_space_1,
/*reserve_space_2=*/batch_variance});
} else { // TF::FusedBatchNormV3Op
// FusedBatchNormV3 expects a 5th output, but the output is unused; it
// doesn't matter what we pass there.
rewriter.replaceOp(op, {y_out, /*batch_mean=*/batch_mean,
/*batch_variance=*/corrected_variance,
/*reserve_space_1=*/reserve_space_1,
/*reserve_space_2=*/batch_variance,
/*reserve_space_3=*/op.x()});
}
} else { // Inference case.
auto bn_train_op = rewriter.create<BatchNormInferenceOp>(
op.getLoc(),
@ -1670,31 +1678,45 @@ class ConvertFusedBatchNormV3Op
// not used for inference. It doesn't matter what values we provide for
// the last 5 results as long as they are of the same type. Forward
// input mean and variance to output mean, variance, reserved_space_1 and
// reserve_space_2. Create a constant tensor to forward to last
// reserve_space_3 output.
auto reserve_space_3_type = op.getResult(5).getType().cast<TensorType>();
int num_elements = reserve_space_3_type.hasStaticShape()
? reserve_space_3_type.getNumElements()
: 0;
auto const_attr_type = RankedTensorType::get(
{num_elements}, getElementTypeOrSelf(reserve_space_3_type));
Value dummy_const = rewriter.create<ConstOp>(
op.getLoc(), DenseElementsAttr::get<float>(const_attr_type, 0.0));
if (const_attr_type != reserve_space_3_type)
dummy_const = rewriter.create<TensorCastOp>(
op.getLoc(), reserve_space_3_type, dummy_const);
rewriter.replaceOp(op, {/*y=*/y_out,
/*batch_mean=*/op.mean(),
/*batch_variance=*/op.variance(),
/*reserve_space_1=*/op.mean(),
/*reserve_space_2=*/op.variance(),
/*reserve_space_3=*/dummy_const});
// reserved_space_2.
if (std::is_same<FusedBatchNormOpT, TF::FusedBatchNormV2Op>::value) {
rewriter.replaceOp(op, {/*y=*/y_out,
/*batch_mean=*/op.mean(),
/*batch_variance=*/op.variance(),
/*reserve_space_1=*/op.mean(),
/*reserve_space_2=*/op.variance()});
} else {
// For FusedBatchNormV3Op, also create a constant tensor to forward to
// last reserve_space_3 output.
auto reserve_space_3_type =
op.getResult(5).getType().template cast<TensorType>();
int num_elements = reserve_space_3_type.hasStaticShape()
? reserve_space_3_type.getNumElements()
: 0;
auto const_attr_type = RankedTensorType::get(
{num_elements}, getElementTypeOrSelf(reserve_space_3_type));
Value dummy_const = rewriter.create<ConstOp>(
op.getLoc(), DenseElementsAttr::get<float>(const_attr_type, 0.0));
if (const_attr_type != reserve_space_3_type)
dummy_const = rewriter.create<TensorCastOp>(
op.getLoc(), reserve_space_3_type, dummy_const);
rewriter.replaceOp(op, {/*y=*/y_out,
/*batch_mean=*/op.mean(),
/*batch_variance=*/op.variance(),
/*reserve_space_1=*/op.mean(),
/*reserve_space_2=*/op.variance(),
/*reserve_space_3=*/dummy_const});
}
}
return success();
}
};
using ConvertFusedBatchNormV2Op =
ConvertFusedBatchNormBase<TF::FusedBatchNormV2Op>;
using ConvertFusedBatchNormV3Op =
ConvertFusedBatchNormBase<TF::FusedBatchNormV3Op>;
using PaddingArray =
std::vector<std::pair<tensorflow::int64, tensorflow::int64>>;
@ -5358,6 +5380,50 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
}
};
// Converts `TF::SparseMatMulOp` to `TF::MatMulOp`, ignoring the sparseness
// hints, since we currently don't have an implementation that can use this
// information. Adds appropriate casts where necessary to align element types
// of operands and result for `TF::MatMulOp`.
class ConvertSparseMatMulOp : public OpRewritePattern<TF::SparseMatMulOp> {
public:
using OpRewritePattern<TF::SparseMatMulOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TF::SparseMatMulOp op,
PatternRewriter &rewriter) const override {
// Result type must be f32 for applying the pattern (currently this is
// required by the op anyway but this might change).
if (!op.product().getType().cast<TensorType>().getElementType().isF32()) {
return failure();
}
MLIRContext *context = rewriter.getContext();
llvm::SmallVector<Value, 2> operands{op.a(), op.b()};
for (Value &operand : operands) {
TensorType tensor_type = operand.getType().cast<TensorType>();
Type element_type = tensor_type.getElementType();
if (element_type.isF32()) continue;
// Element type can either be f32 or bf16 for `SparseMatMulOp` so it
// must be bf16 here.
assert(element_type.isBF16());
Type tensor_type_f32;
if (tensor_type.hasRank()) {
tensor_type_f32 = RankedTensorType::get(tensor_type.getShape(),
FloatType::getF32(context));
} else {
tensor_type_f32 = UnrankedTensorType::get(FloatType::getF32(context));
}
// Add cast to f32 to conform with element type of result.
operand =
rewriter.create<TF::CastOp>(op.getLoc(), tensor_type_f32, operand);
}
Value result = rewriter.create<TF::MatMulOp>(
op.getLoc(), op.product().getType(), operands[0], operands[1],
op.transpose_a(), op.transpose_b());
rewriter.replaceOp(op, {result});
return success();
}
};
// Emits debug information which includes the number of ops of each type which
// failed to legalize.
void EmitLegalizationErrors(Operation *op,
@ -5437,19 +5503,21 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
ConvertConv2DBackpropInputOp, ConvertConv3DBackpropInputOp,
ConvertCumsumOp, ConvertDiagPartOp, ConvertEinsumOp,
ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op,
ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV3Op,
ConvertInfeedDequeueTupleOp, ConvertInplaceUpdateOp, ConvertLinSpaceOp,
ConvertMaxOp, ConvertMinOp, ConvertAvgPoolOp, ConvertAvgPool2DGradOp,
ConvertAvgPool3DGradOp, ConvertMaxPool2DOp, ConvertMaxPool3DOp,
ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp, ConvertMeanOp,
ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertQrOp,
ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV2Op,
ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp,
ConvertInplaceUpdateOp, ConvertLinSpaceOp, ConvertMaxOp, ConvertMinOp,
ConvertAvgPoolOp, ConvertAvgPool2DGradOp, ConvertAvgPool3DGradOp,
ConvertMaxPool2DOp, ConvertMaxPool3DOp, ConvertMaxPool2DGradOp,
ConvertMaxPool3DGradOp, ConvertMeanOp, ConvertOneHotOp,
ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertQrOp,
ConvertDynamicRangeOp, ConvertRangeOp, ConvertSelectV2Op,
ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp,
ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,
ConvertUnpackOp, ConvertUnsortedSegmentMaxOp, ConvertUnsortedSegmentMinOp,
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSparseMatMulOp,
ConvertSplitOp, ConvertSplitVOp, ConvertStridedSliceOp,
ConvertStridedSliceGradOp, ConvertSumOp, ConvertTensorScatterUpdateOp,
ConvertTileOp, ConvertTopKV2Op, ConvertUnpackOp,
ConvertUnsortedSegmentMaxOp, ConvertUnsortedSegmentMinOp,
ConvertUnsortedSegmentProdOp, ConvertUnsortedSegmentSumOp,
ConvertRandomShuffleOp, ConvertXlaShardingOp,
ConvertXlaDynamicUpdateSliceOp>(op->getContext());

View File

@ -339,18 +339,6 @@ def : Pat<(TF_MatMulOp $a, $b, $transpose_a, $transpose_b),
(TF_TransposeOp $b, (TF_ConstOp (Get2DTransposePerm $transpose_b))),
/*precision_config=*/(NullArrayAttr))>;
//===----------------------------------------------------------------------===//
// SparseMatMul op patterns.
//===----------------------------------------------------------------------===//
// Ignores the sparseness hints and translates tf.SparseMatMul to tf.MatMul
// until we have an implementation that can use the information.
def SparseMatMulToMatMul : Pat<(TF_SparseMatMulOp $a, $b, $a_sparse, $b_sparse,
$transpose_a, $transpose_b),
(TF_MatMulOp $a, $b, $transpose_a,
$transpose_b)>;
//===----------------------------------------------------------------------===//
// MatrixBandPart op pattern.
//===----------------------------------------------------------------------===//

View File

@ -704,6 +704,7 @@ cc_library(
srcs = ["mlir_bridge_pass.cc"],
hdrs = ["mlir_bridge_pass.h"],
deps = [
"//tensorflow/compiler/jit:flags",
"//tensorflow/compiler/mlir:mlir_graph_optimization_pass",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/core:core_cpu",

View File

@ -56,7 +56,7 @@ Status MlirBridgeV1CompatPass::Run(const GraphOptimizationPassOptions& options,
// Skip function graphs as MlirBridgePass will be used instead.
if (options.is_function_graph) return Status::OK();
if (!options.session_options->config.experimental().enable_mlir_bridge()) {
if (!IsEnabled(options.session_options->config)) {
VLOG(0) << "Skipping MLIR TPU Bridge V1 Compat, session flag not enabled";
mlir_bridge_gauge_v1->GetCell()->Set(false);
return Status::OK();

View File

@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_
#include "llvm/ADT/StringRef.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"
namespace tensorflow {
@ -45,7 +46,8 @@ class MlirBridgeV1CompatPass : public MlirV1CompatOptimizationPass {
llvm::StringRef name() const override { return "bridge"; }
bool IsEnabled(const ConfigProto& config_proto) const override {
return config_proto.experimental().enable_mlir_bridge();
return config_proto.experimental().enable_mlir_bridge() ||
tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
}
// This should be used as a thin mapper around mlir::ModulePass::runOnModule

View File

@ -25,6 +25,8 @@ upper_tabs:
path: /xla/operation_semantics
- title: Shapes and layout
path: /xla/shapes
- title: Aliasing
path: /xla/aliasing
- title: Tiled layout
path: /xla/tiled_layout
- title: Use AOT compilation

View File

@ -0,0 +1,73 @@
# Aliasing in XLA
This document describes the aliasing API for XLA: when building an XLA program,
you can specify the desired aliasing between the input and output buffers.
## Defining aliasing at compile-time
For example, consider a trivial HLO module which simply adds `1` to its input:
```
HloModule increment
ENTRY entry {
%p = f32[] parameter(0)
%c = f32[] constant(1)
ROOT %out = f32[] add(%p, %c)
}
```
This module will allocate two 4-byte buffers: one for the input `%p`, and one
for the output `%out`.
However, it is often desirable to perform the update in-place (for example, if
the frontend that generated the expression no longer needs the input variable
after the computation, as in the increment `p++`).
To perform such an update efficiently, you can specify the input aliasing:
```
HloModule increment, input_output_alias={ {}: 0 }
ENTRY entry {
%p = f32[] parameter(0)
%c = f32[] constant(1)
ROOT %out = f32[] add(%p, %c)
}
```
The format specifies that the entire output (marked by `{}`) is aliased to the
input parameter `0`.
See the
[`XlaBuilder::SetUpAlias`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
API to specify the aliasing programmatically.
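For instance, a minimal C++ sketch of setting up the same alias programmatically
is shown below. It assumes the three-argument
`SetUpAlias(output_index, param_number, param_index)` overload declared in
`xla_builder.h`; the wrapper function `BuildIncrement` is only for illustration.
```
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Builds the `increment` computation above and aliases its entire output
// (shape index {}) to parameter 0, mirroring the HLO annotation.
xla::XlaComputation BuildIncrement() {
  xla::XlaBuilder builder("increment");
  auto p = xla::Parameter(&builder, /*parameter_number=*/0,
                          xla::ShapeUtil::MakeShape(xla::F32, {}), "p");
  auto out = xla::Add(p, xla::ConstantR0<float>(&builder, 1.0f));
  builder.SetUpAlias(/*output_index=*/{}, /*param_number=*/0,
                     /*param_index=*/{});
  return builder.Build(out).ValueOrDie();
}
```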
## Defining aliasing at run-time
The aliasing defined in the previous step is specified during _compilation_.
During execution, you can choose whether to actually donate the buffer using
the
[`LocalClient::RunAsync`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/local_client.h)
API.
Input buffers to the program are wrapped in
[`ExecutionInput`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/executable.h),
which in turn contains a tree of `MaybeOwningDeviceMemory`. If memory is
specified as _owning_ (ownership of the buffer is passed to the XLA runtime),
the buffer is actually donated, and the update is executed in-place, as
requested by the compile-time aliasing API.
If, however, a buffer that is aliased at compile time is _not_ donated at
runtime, _copy-protection_ kicks in: an extra output buffer `O` is allocated,
and the contents of the input buffer `P` that was meant to be aliased are
copied into `O` (so effectively the program can execute as if the buffer `O`
had been donated at runtime).
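The donate-versus-borrow decision is expressed through
`MaybeOwningDeviceMemory`. The sketch below assumes only the two constructors
declared in `maybe_owning_device_memory.h`; the helper names `Donate` and
`Borrow` are ours.
```
#include <utility>

#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"

namespace xla {

// Wrapping a buffer as owning donates it to the XLA runtime, so a
// compile-time alias can be honored in-place.
MaybeOwningDeviceMemory Donate(se::OwningDeviceMemory owned) {
  return MaybeOwningDeviceMemory(std::move(owned));
}

// Wrapping it as unowned keeps ownership with the caller; if that buffer was
// aliased at compile time, copy-protection allocates `O` and copies `P` into
// it.
MaybeOwningDeviceMemory Borrow(se::DeviceMemoryBase borrowed) {
  return MaybeOwningDeviceMemory(borrowed);
}

}  // namespace xla
```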
## Frontend interop
### TF/XLA
In XLA-compiled clusters of a TensorFlow program, all resource variable
updates are aliased at compile time (the aliasing at runtime depends on whether
anything else holds a reference to the resource variable tensor).

View File

@ -24,6 +24,7 @@ tf_proto_library_cc(
has_services = 1,
cc_api_version = 2,
cc_grpc_version = 1,
create_service = True,
protodeps = [
"//tensorflow/compiler/xla:xla_proto",
],

View File

@ -1602,6 +1602,17 @@ Status DynamicDimensionInference::AnalyzeDynamicDimensions() {
custom_call_handler_);
}
void DynamicDimensionInference::ReplaceAllDynamicDimensionUsesWith(
HloInstruction* replace, HloInstruction* with) {
CHECK(Shape::Equal()(replace->shape(), ShapeUtil::MakeScalarShape(S32)));
CHECK(Shape::Equal()(with->shape(), ShapeUtil::MakeScalarShape(S32)));
for (auto& kv : dynamic_mapping_) {
if (kv.second == replace) {
kv.second = with;
}
}
}
Status DynamicDimensionInference::ForwardDynamicSize(HloInstruction* inst,
HloInstruction* new_inst,
const ShapeIndex& index) {

View File

@ -68,6 +68,11 @@ class DynamicDimensionInference {
SetDynamicSize(inst, index, dim, size, DimensionConstraint(1, 1));
}
// For all tensors whose dynamic dimension is `replace`, replace them with
// `with`.
void ReplaceAllDynamicDimensionUsesWith(HloInstruction* replace,
HloInstruction* with);
friend class DynamicDimensionInferenceVisitor;
private:

View File

@ -242,7 +242,6 @@ cc_library(
deps = [
":backend_configs_cc",
":buffer_allocations",
":cudnn_batchnorm_runner",
":elemental_ir_emitter",
":gpu_constants",
":gpu_conv_runner",
@ -267,6 +266,7 @@ cc_library(
"//tensorflow/compiler/xla/service:elemental_ir_emitter",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:hlo_execution_profile",
"//tensorflow/compiler/xla/service:name_uniquer",
"//tensorflow/compiler/xla/service:pattern_matcher",
"//tensorflow/compiler/xla/service:while_loop_analysis",
@ -282,7 +282,6 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:sort_util",
"//tensorflow/compiler/xla/service/llvm_ir:tuple_ops",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",

View File

@ -31,13 +31,13 @@ limitations under the License.
namespace xla {
namespace gpu {
CholeskyThunk::CholeskyThunk(const CholeskyOptions& options,
CholeskyThunk::CholeskyThunk(ThunkInfo thunk_info,
const CholeskyOptions& options,
BufferAllocation::Slice a_buffer,
BufferAllocation::Slice workspace_buffer,
BufferAllocation::Slice info_buffer,
PrimitiveType type, int64 batch_size, int64 n,
const HloInstruction* hlo)
: Thunk(Kind::kCholesky, hlo),
PrimitiveType type, int64 batch_size, int64 n)
: Thunk(Kind::kCholesky, thunk_info),
uplo_(options.lower() ? se::blas::UpperLower::kLower
: se::blas::UpperLower::kUpper),
a_buffer_(a_buffer),
@ -45,9 +45,10 @@ CholeskyThunk::CholeskyThunk(const CholeskyOptions& options,
info_buffer_(info_buffer),
type_(type),
batch_size_(batch_size),
a_batch_stride_(n * n *
ShapeUtil::ByteSizeOfPrimitiveType(
hlo->operand(0)->shape().element_type())),
a_batch_stride_(
n * n *
ShapeUtil::ByteSizeOfPrimitiveType(
thunk_info.hlo_instruction->operand(0)->shape().element_type())),
n_(n) {}
Status CholeskyThunk::ExecuteOnStream(const ExecuteParams& params) {

View File

@ -41,12 +41,11 @@ namespace gpu {
class CholeskyThunk : public Thunk {
public:
static StatusOr<int64> ScratchBufferSize(int64 n);
CholeskyThunk(const CholeskyOptions& options,
CholeskyThunk(ThunkInfo thunk_info, const CholeskyOptions& options,
BufferAllocation::Slice a_buffer,
BufferAllocation::Slice workspace_buffer,
BufferAllocation::Slice info_buffer,
PrimitiveType type,
int64 batch_size, int64 n, const HloInstruction* hlo);
BufferAllocation::Slice info_buffer, PrimitiveType type,
int64 batch_size, int64 n);
CholeskyThunk(const CholeskyThunk&) = delete;
CholeskyThunk& operator=(const CholeskyThunk&) = delete;
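The constructor changes in this and the surrounding thunk files all follow one
pattern: the raw `const HloInstruction*` parameter becomes a `Thunk::ThunkInfo`
value. A rough, non-authoritative reconstruction of what that struct carries,
inferred only from the fields this diff uses (`hlo_instruction` and
`profile_index()`), is sketched below; the real definition lives in `thunk.h`.
```
// Hedged reconstruction for orientation only; see thunk.h for the real thing.
struct ThunkInfo {
  // May stay null, e.g. for sub-thunks that are logically "part of" a parent
  // thunk and should not be profiled separately.
  const HloInstruction* hlo_instruction = nullptr;
  // Index into the HloProfileIndexMap; unset disables per-op profiling.
  absl::optional<int64> profile_index;
};
```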

View File

@ -218,14 +218,14 @@ RefcountingHashMap<RendezvousKey, Rendezvous>& GlobalRendezvousMap() {
} // anonymous namespace
CollectivePermuteThunk::CollectivePermuteThunk(
const BufferAllocation::Slice& src, const BufferAllocation::Slice& dest,
const HloInstruction* instr)
: Thunk(kCollectivePermute, instr), src_(src), dest_(dest) {}
ThunkInfo thunk_info, const BufferAllocation::Slice& src,
const BufferAllocation::Slice& dest)
: Thunk(kCollectivePermute, thunk_info), src_(src), dest_(dest) {}
Status CollectivePermuteThunk::ExecuteOnStream(const ExecuteParams& params) {
auto* instr = Cast<HloCollectivePermuteInstruction>(hlo_instruction());
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
// Rendezvous with the threads for all other devices that are participating in
// this CollectivePermute.

View File

@ -26,9 +26,9 @@ namespace gpu {
// Thunk that implements the collective-permute HLO.
class CollectivePermuteThunk : public Thunk {
public:
CollectivePermuteThunk(const BufferAllocation::Slice& src,
const BufferAllocation::Slice& dest,
const HloInstruction* instr);
CollectivePermuteThunk(ThunkInfo thunk_info,
const BufferAllocation::Slice& src,
const BufferAllocation::Slice& dest);
Status ExecuteOnStream(const ExecuteParams& params) override;

View File

@ -24,12 +24,14 @@ namespace xla {
namespace gpu {
ConditionalThunk::ConditionalThunk(
ThunkInfo thunk_info,
const BufferAllocation::Slice& branch_index_buffer_index,
absl::Span<const BufferAllocation::Slice> branch_operand_buffer_indexes,
std::vector<ThunkSequence> branch_thunk_sequences,
const HloInstruction* hlo)
: Thunk(Kind::kConditional, hlo),
branch_index_is_bool_(hlo->operand(0)->shape().element_type() == PRED),
std::vector<ThunkSequence> branch_thunk_sequences)
: Thunk(Kind::kConditional, thunk_info),
branch_index_is_bool_(
thunk_info.hlo_instruction->operand(0)->shape().element_type() ==
PRED),
branch_index_buffer_index_(branch_index_buffer_index),
branch_operand_buffer_indexes_(branch_operand_buffer_indexes.begin(),
branch_operand_buffer_indexes.end()) {
@ -39,7 +41,7 @@ ConditionalThunk::ConditionalThunk(
branch_thunks_.reserve(branch_thunk_sequences.size());
for (auto& branch_thunk_sequence : branch_thunk_sequences) {
branch_thunks_.emplace_back(
new SequentialThunk(std::move(branch_thunk_sequence), nullptr));
new SequentialThunk(ThunkInfo(), std::move(branch_thunk_sequence)));
}
}
@ -67,7 +69,7 @@ Status ConditionalThunk::ExecuteOnStream(const ExecuteParams& params) {
auto& profiler = *params.profiler;
auto& stream = *params.stream;
auto op_profiler = profiler.MakeScopedInstructionProfiler(hlo_instruction());
auto op_profiler = profiler.MakeScopedInstructionProfiler(profile_index());
// Copy the predicate value from device.
int32 branch_index = -1;
bool pred = false;

View File

@ -43,10 +43,10 @@ namespace gpu {
class ConditionalThunk : public Thunk {
public:
ConditionalThunk(
ThunkInfo thunk_info,
const BufferAllocation::Slice& branch_index_buffer_index,
absl::Span<const BufferAllocation::Slice> branch_operand_buffer_indexes,
std::vector<ThunkSequence> branch_thunk_sequences,
const HloInstruction* hlo);
std::vector<ThunkSequence> branch_thunk_sequences);
ConditionalThunk(const ConditionalThunk&) = delete;
ConditionalThunk& operator=(const ConditionalThunk&) = delete;

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
@ -30,16 +31,16 @@ namespace xla {
namespace gpu {
ConvolutionThunk::ConvolutionThunk(
const HloCustomCallInstruction* cudnn_call,
std::vector<BufferAllocation::Slice> operand_slices,
ThunkInfo thunk_info, std::vector<BufferAllocation::Slice> operand_slices,
BufferAllocation::Slice result_slice, BufferAllocation::Slice scratch_slice,
BufferAllocation::Slice tuple_result_slice)
: Thunk(Kind::kConvolution, cudnn_call),
cudnn_call_(cudnn_call),
: Thunk(Kind::kConvolution, thunk_info),
operand_buffers_(std::move(operand_slices)),
result_buffer_(result_slice),
scratch_buffer_(scratch_slice),
tuple_result_buffer_(tuple_result_slice) {}
tuple_result_buffer_(tuple_result_slice) {
cudnn_call_ = Cast<HloCustomCallInstruction>(hlo_instruction());
}
Status ConvolutionThunk::ExecuteOnStream(const ExecuteParams& params) {
const auto& buffer_allocations = *params.buffer_allocations;
@ -56,7 +57,7 @@ Status ConvolutionThunk::ExecuteOnStream(const ExecuteParams& params) {
buffer_allocations.GetDeviceAddress(scratch_buffer_);
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
TF_RETURN_IF_ERROR(RunGpuConv(cudnn_call_, absl::MakeSpan(operand_se_buffers),
result_buffer, scratch, params.stream));

View File

@ -43,7 +43,7 @@ class ConvolutionThunk : public Thunk {
// write a tuple (result, scratch_memory) into `tuple_result_buffer`.
//
// operand_slices should be in the same order as cudnn_call->operands().
ConvolutionThunk(const HloCustomCallInstruction* cudnn_call,
ConvolutionThunk(ThunkInfo thunk_info,
std::vector<BufferAllocation::Slice> operand_slices,
BufferAllocation::Slice result_slice,
BufferAllocation::Slice scratch_slice,

View File

@ -22,10 +22,9 @@ namespace xla {
namespace gpu {
HostToDeviceCopyThunk::HostToDeviceCopyThunk(
const void* source_address,
const BufferAllocation::Slice& destination_buffer, uint64 mem_size,
const HloInstruction* hlo_instruction)
: Thunk(Kind::kCopy, hlo_instruction),
ThunkInfo thunk_info, const void* source_address,
const BufferAllocation::Slice& destination_buffer, uint64 mem_size)
: Thunk(Kind::kCopy, thunk_info),
source_address_(source_address),
destination_buffer_(destination_buffer),
mem_size_(mem_size) {}
@ -34,16 +33,15 @@ Status HostToDeviceCopyThunk::ExecuteOnStream(const ExecuteParams& params) {
se::DeviceMemoryBase destination_data =
params.buffer_allocations->GetDeviceAddress(destination_buffer_);
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
params.stream->ThenMemcpy(&destination_data, source_address_, mem_size_);
return Status::OK();
}
DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk(
const BufferAllocation::Slice& source_buffer,
const BufferAllocation::Slice& destination_buffer, uint64 mem_size,
const HloInstruction* hlo_instruction)
: Thunk(Kind::kCopy, hlo_instruction),
ThunkInfo thunk_info, const BufferAllocation::Slice& source_buffer,
const BufferAllocation::Slice& destination_buffer, uint64 mem_size)
: Thunk(Kind::kCopy, thunk_info),
source_buffer_(source_buffer),
destination_buffer_(destination_buffer),
mem_size_(mem_size) {}
@ -54,7 +52,7 @@ Status DeviceToDeviceCopyThunk::ExecuteOnStream(const ExecuteParams& params) {
se::DeviceMemoryBase source_data =
params.buffer_allocations->GetDeviceAddress(source_buffer_);
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
params.stream->ThenMemcpy(&destination_data, source_data, mem_size_);
return Status::OK();
}

View File

@ -33,9 +33,9 @@ class HostToDeviceCopyThunk : public Thunk {
// Constructs a CopyThunk that copies host data from `source_address` to the
// device buffer `destination_buffer`. `mem_size` is the size of the data in
// bytes.
HostToDeviceCopyThunk(const void* source_address,
HostToDeviceCopyThunk(ThunkInfo thunk_info, const void* source_address,
const BufferAllocation::Slice& destination_buffer,
uint64 mem_size, const HloInstruction* hlo_instruction);
uint64 mem_size);
HostToDeviceCopyThunk(const HostToDeviceCopyThunk&) = delete;
HostToDeviceCopyThunk& operator=(const HostToDeviceCopyThunk&) = delete;
@ -54,10 +54,10 @@ class DeviceToDeviceCopyThunk : public Thunk {
// Constructs a CopyThunk that copies data from the device buffer `source_buffer` to the
// device buffer `destination_buffer`. `mem_size` is the size of the data in
// bytes.
DeviceToDeviceCopyThunk(const BufferAllocation::Slice& source_buffer,
DeviceToDeviceCopyThunk(ThunkInfo thunk_info,
const BufferAllocation::Slice& source_buffer,
const BufferAllocation::Slice& destination_buffer,
uint64 mem_size,
const HloInstruction* hlo_instruction);
uint64 mem_size);
DeviceToDeviceCopyThunk(const DeviceToDeviceCopyThunk&) = delete;
DeviceToDeviceCopyThunk& operator=(const DeviceToDeviceCopyThunk&) = delete;

View File

@ -92,12 +92,12 @@ void CheckInputOutputPrimitivetypeAreValid(const HloInstruction* hlo) {
} // namespace
CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk(
const BufferAllocation::Slice& operand,
ThunkInfo thunk_info, const BufferAllocation::Slice& operand,
const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset,
const BufferAllocation::Slice& mean,
const BufferAllocation::Slice& variance, float epsilon, int64 feature_index,
const BufferAllocation::Slice& output, const HloInstruction* hlo)
: Thunk(Thunk::Kind::kCudnnBatchNormForwardInference, hlo),
const BufferAllocation::Slice& output)
: Thunk(Thunk::Kind::kCudnnBatchNormForwardInference, thunk_info),
operand_(operand),
scale_(scale),
offset_(offset),
@ -106,6 +106,7 @@ CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk(
epsilon_(epsilon),
feature_index_(feature_index),
output_(output) {
const auto* hlo = hlo_instruction();
CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall);
CHECK_EQ(hlo->custom_call_target(),
kCudnnBatchNormForwardInferenceCallTarget);
@ -118,7 +119,7 @@ Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream(
const ExecuteParams& params) {
auto& buffer_allocations = *params.buffer_allocations;
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
se::DeviceMemoryBase output_base =
buffer_allocations.GetDeviceAddress(output_);
se::DeviceMemoryBase operand = buffer_allocations.GetDeviceAddress(operand_);
@ -139,14 +140,14 @@ Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream(
}
CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk(
const BufferAllocation::Slice& operand,
ThunkInfo thunk_info, const BufferAllocation::Slice& operand,
const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset,
float epsilon, int64 feature_index,
const BufferAllocation::Slice& output_data,
const BufferAllocation::Slice& output_mean,
const BufferAllocation::Slice& output_inv_stddev,
const BufferAllocation::Slice& output_tuple, const HloInstruction* hlo)
: Thunk(Thunk::Kind::kCudnnBatchNormForwardTraining, hlo),
const BufferAllocation::Slice& output_tuple)
: Thunk(Thunk::Kind::kCudnnBatchNormForwardTraining, thunk_info),
operand_(operand),
scale_(scale),
offset_(offset),
@ -156,6 +157,7 @@ CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk(
output_mean_(output_mean),
output_inv_stddev_(output_inv_stddev),
output_tuple_(output_tuple) {
const auto* hlo = hlo_instruction();
CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall);
CHECK_EQ(hlo->custom_call_target(), kCudnnBatchNormForwardTrainingCallTarget);
CHECK_EQ(hlo->shape().tuple_shapes_size(), 3);
@ -178,7 +180,7 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
se::DeviceMemory<float> null_device_ptr(nullptr);
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
auto& stream = *params.stream;
TF_RETURN_IF_ERROR(RunCudnnBatchNormForwardTraining(
hlo_instruction(), operand, output_data, output_mean, output_inv_stddev,
@ -203,15 +205,15 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
}
CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk(
const BufferAllocation::Slice& operand,
ThunkInfo thunk_info, const BufferAllocation::Slice& operand,
const BufferAllocation::Slice& scale, const BufferAllocation::Slice& mean,
const BufferAllocation::Slice& inv_stddev,
const BufferAllocation::Slice& grad_output, float epsilon,
int64 feature_index, const BufferAllocation::Slice& output_grad_data,
const BufferAllocation::Slice& output_grad_scale,
const BufferAllocation::Slice& output_grad_offset,
const BufferAllocation::Slice& output_tuple, const HloInstruction* hlo)
: Thunk(Thunk::Kind::kCudnnBatchNormBackward, hlo),
const BufferAllocation::Slice& output_tuple)
: Thunk(Thunk::Kind::kCudnnBatchNormBackward, thunk_info),
operand_(operand),
scale_(scale),
mean_(mean),
@ -223,6 +225,7 @@ CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk(
output_grad_scale_(output_grad_scale),
output_grad_offset_(output_grad_offset),
output_tuple_(output_tuple) {
const auto* hlo = hlo_instruction();
CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall);
CHECK_EQ(hlo->custom_call_target(), kCudnnBatchNormBackwardCallTarget);
CHECK_EQ(hlo->shape().tuple_shapes_size(), 3);
@ -247,7 +250,7 @@ Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
buffer_allocations.GetDeviceAddress(output_grad_offset_));
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
se::Stream* stream = params.stream;
TF_RETURN_IF_ERROR(RunCudnnBatchNormBackward(
hlo_instruction(), operand, output_grad_data, grad_output,

View File

@ -46,14 +46,14 @@ namespace gpu {
class CudnnBatchNormForwardInferenceThunk : public Thunk {
public:
CudnnBatchNormForwardInferenceThunk(const BufferAllocation::Slice& operand,
CudnnBatchNormForwardInferenceThunk(ThunkInfo thunk_info,
const BufferAllocation::Slice& operand,
const BufferAllocation::Slice& scale,
const BufferAllocation::Slice& offset,
const BufferAllocation::Slice& mean,
const BufferAllocation::Slice& variance,
float epsilon, int64 feature_index,
const BufferAllocation::Slice& output,
const HloInstruction* hlo);
const BufferAllocation::Slice& output);
CudnnBatchNormForwardInferenceThunk(
const CudnnBatchNormForwardInferenceThunk&) = delete;
@ -76,13 +76,13 @@ class CudnnBatchNormForwardInferenceThunk : public Thunk {
class CudnnBatchNormForwardTrainingThunk : public Thunk {
public:
CudnnBatchNormForwardTrainingThunk(
const BufferAllocation::Slice& operand,
ThunkInfo thunk_info, const BufferAllocation::Slice& operand,
const BufferAllocation::Slice& scale,
const BufferAllocation::Slice& offset, float epsilon, int64 feature_index,
const BufferAllocation::Slice& output_data,
const BufferAllocation::Slice& output_mean,
const BufferAllocation::Slice& output_inv_stddev,
const BufferAllocation::Slice& output_tuple, const HloInstruction* hlo);
const BufferAllocation::Slice& output_tuple);
CudnnBatchNormForwardTrainingThunk(
const CudnnBatchNormForwardTrainingThunk&) = delete;
@ -105,7 +105,8 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk {
class CudnnBatchNormBackwardThunk : public Thunk {
public:
CudnnBatchNormBackwardThunk(const BufferAllocation::Slice& operand,
CudnnBatchNormBackwardThunk(ThunkInfo thunk_info,
const BufferAllocation::Slice& operand,
const BufferAllocation::Slice& scale,
const BufferAllocation::Slice& mean,
const BufferAllocation::Slice& inv_stddev,
@ -114,8 +115,7 @@ class CudnnBatchNormBackwardThunk : public Thunk {
const BufferAllocation::Slice& output_grad_data,
const BufferAllocation::Slice& output_grad_scale,
const BufferAllocation::Slice& output_grad_offset,
const BufferAllocation::Slice& output_tuple,
const HloInstruction* hlo);
const BufferAllocation::Slice& output_tuple);
CudnnBatchNormBackwardThunk(const CudnnBatchNormBackwardThunk&) = delete;
CudnnBatchNormBackwardThunk& operator=(const CudnnBatchNormBackwardThunk&) =

View File

@ -22,15 +22,15 @@ namespace xla {
namespace gpu {
CustomCallThunk::CustomCallThunk(
void* call_target,
ThunkInfo thunk_info, void* call_target,
std::vector<ShapeTree<BufferAllocation::Slice>> operand_slices,
ShapeTree<BufferAllocation::Slice> result_slices, std::string opaque,
const HloInstruction* instr)
: Thunk(Thunk::kCustomCall, instr),
ShapeTree<BufferAllocation::Slice> result_slices, std::string opaque)
: Thunk(Thunk::kCustomCall, thunk_info),
call_target_(call_target),
operand_slices_(std::move(operand_slices)),
result_slices_(std::move(result_slices)),
opaque_(std::move(opaque)) {
const HloInstruction* instr = hlo_instruction();
CHECK_EQ(instr->operand_count(), operand_slices_.size());
for (int64 i = 0; i < instr->operand_count(); ++i) {
const auto& s1 = operand_slices_[i].shape();

View File

@ -39,10 +39,9 @@ namespace gpu {
class CustomCallThunk : public Thunk {
public:
CustomCallThunk(
void* call_target,
ThunkInfo thunk_info, void* call_target,
std::vector<ShapeTree<BufferAllocation::Slice>> operand_slices,
ShapeTree<BufferAllocation::Slice> result_slices, std::string opaque,
const HloInstruction* instr);
ShapeTree<BufferAllocation::Slice> result_slices, std::string opaque);
Status ExecuteOnStream(const ExecuteParams& params) override;

View File

@ -42,9 +42,9 @@ NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
struct NcclAllReduceThunk::AuxData {};
NcclAllReduceThunk::NcclAllReduceThunk(
int64 replica_count, std::vector<NcclAllReduceThunk::Buffer> buffers,
const HloInstruction* all_reduce)
: Thunk(Thunk::kNcclAllReduce, all_reduce),
ThunkInfo thunk_info, int64 replica_count,
std::vector<NcclAllReduceThunk::Buffer> buffers)
: Thunk(Thunk::kNcclAllReduce, thunk_info),
replica_count_(replica_count),
buffers_(std::move(buffers)) {}

View File

@ -98,12 +98,12 @@ string FftTypeToString(se::fft::Type type) {
} // namespace
FftThunk::FftThunk(FftType fft_type, absl::Span<const int64> fft_length,
FftThunk::FftThunk(ThunkInfo thunk_info, FftType fft_type,
absl::Span<const int64> fft_length,
const BufferAllocation::Slice& input_buffer,
const BufferAllocation::Slice& output_buffer,
const Shape& input_shape, const Shape& output_shape,
const HloInstruction* hlo)
: Thunk(Kind::kFft, hlo),
const Shape& input_shape, const Shape& output_shape)
: Thunk(Kind::kFft, thunk_info),
fft_type_(
FftTypeToSeType(fft_type, input_shape.element_type() == F64 ||
input_shape.element_type() == C128)),
@ -127,7 +127,7 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
buffer_allocations.memory_allocator());
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
if (fft_plan_ == nullptr) {
const int64 fft_rank = fft_length_.size();
CHECK_LE(fft_rank, 3);

View File

@ -62,11 +62,11 @@ class FftThunk : public Thunk {
public:
// Constructs a thunk for launching an FFT on a stream.
// Semantics of the thunk_info argument are as in Thunk.
FftThunk(FftType fft_type, absl::Span<const int64> fft_length,
FftThunk(ThunkInfo thunk_info, FftType fft_type,
absl::Span<const int64> fft_length,
const BufferAllocation::Slice& input_buffer,
const BufferAllocation::Slice& output_buffer,
const Shape& input_shape, const Shape& output_shape,
const HloInstruction* hlo);
const Shape& input_shape, const Shape& output_shape);
FftThunk(const FftThunk&) = delete; // Cannot share fft_plan_
FftThunk& operator=(const FftThunk&) = delete; // Cannot share fft_plan_

View File

@ -23,16 +23,15 @@ limitations under the License.
namespace xla {
namespace gpu {
ForThunk::ForThunk(const int64 loop_limit,
std::unique_ptr<ThunkSequence> body_thunk_sequence,
const HloInstruction* hlo)
: Thunk(Kind::kWhile, hlo),
ForThunk::ForThunk(ThunkInfo thunk_info, const int64 loop_limit,
std::unique_ptr<ThunkSequence> body_thunk_sequence)
: Thunk(Kind::kWhile, thunk_info),
loop_limit_(loop_limit),
body_thunk_sequence_(absl::make_unique<SequentialThunk>(
// Pass a default ThunkInfo() to the body_thunk_sequence_
// constructor because this SequentialThunk is logically "part of"
// this ForThunk, and shouldn't be profiled separately from it.
std::move(*body_thunk_sequence), nullptr)) {}
ThunkInfo(), std::move(*body_thunk_sequence))) {}
void ForThunk::ComputeAnnotations() {
Thunk::ComputeAnnotations();
@ -49,7 +48,7 @@ Status ForThunk::ExecuteOnStream(const ExecuteParams& params) {
VLOG(2) << "Executing ForThunk with " << loop_limit_ << " iters for "
<< (hlo_instruction() ? hlo_instruction()->ToString() : "<null>");
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
for (int64 i = 0; i < loop_limit_; ++i) {
params.profiler->StartHloComputation();
// Invoke loop body thunk sequence.

View File

@ -31,9 +31,8 @@ namespace gpu {
// ForThunk executes 'loop_limit' invocations of 'body_thunk_sequence'.
class ForThunk : public Thunk {
public:
ForThunk(const int64 loop_limit,
std::unique_ptr<ThunkSequence> body_thunk_sequence,
const HloInstruction* hlo);
ForThunk(ThunkInfo thunk_info, const int64 loop_limit,
std::unique_ptr<ThunkSequence> body_thunk_sequence);
ForThunk(const ForThunk&) = delete;
ForThunk& operator=(const ForThunk&) = delete;

View File

@ -132,6 +132,7 @@ static StatusOr<absl::optional<se::blas::AlgorithmType>> DoUncachedGemmAutotune(
CHECK(RunGemm(gemm, backend_config, lhs_buffer, rhs_buffer, output_buffer,
stream,
/*implements_whole_instruction=*/true,
/*profile_index=*/-1,
/*profiler=*/nullptr,
/*profile_result=*/&profile_result, algorithm)
.ok());

View File

@ -33,13 +33,13 @@ limitations under the License.
namespace xla {
namespace gpu {
GemmThunk::GemmThunk(const BufferAllocation::Slice &lhs_buffer,
GemmThunk::GemmThunk(ThunkInfo thunk_info,
const BufferAllocation::Slice &lhs_buffer,
const BufferAllocation::Slice &rhs_buffer,
const BufferAllocation::Slice &output_buffer,
bool implements_whole_instruction,
const HloInstruction *hlo_instruction,
const GemmBackendConfig &backend_config)
: Thunk(Kind::kGemm, hlo_instruction),
: Thunk(Kind::kGemm, thunk_info),
lhs_buffer_(lhs_buffer),
rhs_buffer_(rhs_buffer),
output_buffer_(output_buffer),
@ -57,7 +57,7 @@ Status GemmThunk::ExecuteOnStream(const ExecuteParams &params) {
se::DeviceMemoryBase output_data = get_device_address(output_buffer_);
return RunGemm(hlo_instruction(), backend_config_, lhs_data, rhs_data,
output_data, params.stream, implements_whole_instruction_,
params.profiler);
profile_index(), params.profiler);
}
// This struct contains the metadata of a matrix, e.g., its base address and
@ -160,6 +160,7 @@ Status RunGemm(const HloInstruction *gemm,
se::DeviceMemoryBase lhs_buffer, se::DeviceMemoryBase rhs_buffer,
se::DeviceMemoryBase output_buffer, se::Stream *stream,
bool implements_whole_instruction,
absl::optional<int64> profile_index,
HloExecutionProfiler *profiler,
se::blas::ProfileResult *profile_result,
absl::optional<se::blas::AlgorithmType> algorithm) {
@ -240,7 +241,7 @@ Status RunGemm(const HloInstruction *gemm,
rhs_buffer, rhs_shape, dim_nums.rhs_contracting_dimensions(0) == col_dim);
std::unique_ptr<ScopedInstructionProfiler> op_profiler =
profiler ? profiler->MakeScopedInstructionProfiler(
implements_whole_instruction ? gemm : nullptr)
implements_whole_instruction ? profile_index : -1)
: nullptr;
if (LayoutUtil::Minor(output_shape.layout(), row_dim) != 0) {

View File

@ -39,11 +39,10 @@ class GemmThunk : public Thunk {
public:
// Constructs a thunk that computes "output = (lhs <dot> rhs) * alpha" using
// BLAS gemm (alpha is stored in the instruction GemmBackendConfig).
GemmThunk(const BufferAllocation::Slice& lhs_buffer,
GemmThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& lhs_buffer,
const BufferAllocation::Slice& rhs_buffer,
const BufferAllocation::Slice& output_buffer,
bool implements_whole_instruction,
const HloInstruction* hlo_instruction,
const GemmBackendConfig& backend_config);
GemmThunk(const GemmThunk&) = delete;
@ -72,7 +71,8 @@ Status RunGemm(
const HloInstruction* gemm, const GemmBackendConfig& backend_config,
se::DeviceMemoryBase lhs_buffer, se::DeviceMemoryBase rhs_buffer,
se::DeviceMemoryBase output_buffer, se::Stream* stream,
bool implements_whole_instruction, HloExecutionProfiler* profiler = nullptr,
bool implements_whole_instruction, absl::optional<int64> profile_index,
HloExecutionProfiler* profiler = nullptr,
se::blas::ProfileResult* profile_result = nullptr,
absl::optional<se::blas::AlgorithmType> algorithm = absl::nullopt);

View File

@ -472,7 +472,8 @@ static Status CompileModuleToLlvmIrImpl(
const std::string& platform_name, GpuDeviceInfo gpu_device_info,
absl::optional<CudaComputeCapability> cuda_compute_capability,
const HloDataflowAnalysis::CanShareBuffer& can_share_buffer_function,
int pointer_size, std::unique_ptr<llvm::Module>* llvm_module,
int pointer_size, const HloProfileIndexMap* profile_index_map,
std::unique_ptr<llvm::Module>* llvm_module,
std::unique_ptr<BufferAssignment>* buffer_assignment,
std::unique_ptr<ThunkSchedule>* thunk_schedule) {
*llvm_module = absl::make_unique<llvm::Module>("", *llvm_context);
@ -509,7 +510,7 @@ static Status CompileModuleToLlvmIrImpl(
IrEmitterContext ir_emitter_context(
hlo_module, buffer_assignment->get(), platform_name, gpu_device_info,
cuda_compute_capability, llvm_module->get());
cuda_compute_capability, profile_index_map, llvm_module->get());
HloComputation* entry_computation = hlo_module->entry_computation();
IrEmitterUnnested ir_emitter(hlo_module->config(), entry_computation,
@ -532,10 +533,14 @@ static Status CompileModuleToLlvmIrImpl(
// not all explicitly checked, but at least we can document them here:
// * The entry HloComputation shall not have dead code (all reachable from
// ROOT).
// * For each visit of HloInstruction, either none or one Thunk will be
// returned.
// * The visited instructions are all instructions in the entry
// computation.
// * For each visit of these HloInstructions, either none or one Thunk
// will be returned.
// * If there is a thunk returned, thunk->hlo_instruction() equals the
// input HloInstruction*.
// * A returned thunk may contain other sub-thunks. A sub-thunk may or may
// not have an associated hlo_instruction().
TF_RET_CHECK(thunks->size() <= 1) << instruction->ToString();
if (!thunks->empty()) {
auto thunk = std::move(thunks->front());
@ -603,6 +608,25 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
return cuda_compute_capability;
}();
std::unique_ptr<HloProfileIndexMap> profile_index_map;
std::unique_ptr<HloProfilePrinterData> profile_printer;
if (module->config().hlo_profiling_enabled() || VLOG_IS_ON(1)) {
HloCostAnalysis cost_analysis(ShapeSizeBytesFunction());
cost_analysis.set_bytes_per_second(
stream_exec->GetDeviceDescription().memory_bandwidth());
TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis));
VLOG(1) << "HLO memory read+written: "
<< tensorflow::strings::HumanReadableNumBytes(
cost_analysis.bytes_accessed());
if (module->config().hlo_profiling_enabled()) {
profile_index_map = absl::make_unique<HloProfileIndexMap>(*module);
profile_printer =
CreateHloProfilePrinterData(*profile_index_map, cost_analysis,
module->entry_computation()->name());
}
}
std::unique_ptr<llvm::Module> llvm_module;
std::unique_ptr<BufferAssignment> buffer_assignment;
std::unique_ptr<ThunkSchedule> thunk_schedule;
@ -610,8 +634,8 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl(
module.get(), &llvm_context, target_triple_, data_layout_,
stream_exec->platform()->Name(), gpu_device_info, cuda_compute_capability,
GetCanShareBuffer(), pointer_size_, &llvm_module, &buffer_assignment,
&thunk_schedule));
GetCanShareBuffer(), pointer_size_, profile_index_map.get(), &llvm_module,
&buffer_assignment, &thunk_schedule));
if (user_pre_optimization_hook_) {
user_pre_optimization_hook_(*llvm_module);
@ -653,25 +677,6 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
thunk_schedule->ToString());
}
std::unique_ptr<HloProfileIndexMap> profile_index_map;
std::unique_ptr<HloProfilePrinterData> profile_printer;
if (module->config().hlo_profiling_enabled() || VLOG_IS_ON(1)) {
HloCostAnalysis cost_analysis(ShapeSizeBytesFunction());
cost_analysis.set_bytes_per_second(
stream_exec->GetDeviceDescription().memory_bandwidth());
TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis));
VLOG(1) << "HLO memory read+written: "
<< tensorflow::strings::HumanReadableNumBytes(
cost_analysis.bytes_accessed());
if (module->config().hlo_profiling_enabled()) {
profile_index_map = absl::make_unique<HloProfileIndexMap>(*module);
profile_printer =
CreateHloProfilePrinterData(*profile_index_map, cost_analysis,
module->entry_computation()->name());
}
}
auto* gpu_executable = new GpuExecutable(
backend_result.first, backend_result.second, gpu_version,
std::move(thunk_schedule), std::move(module),
@ -709,7 +714,8 @@ StatusOr<std::unique_ptr<llvm::Module>> CompileModuleToLlvmIr(
TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl(
hlo_module, llvm_context, target_triple, data_layout, platform_name,
gpu_device_info, cuda_compute_capability, DummyCanShareBufferFunction,
pointer_size, &llvm_module, &buffer_assignment, &thunk_schedule));
pointer_size, /*profile_index_map=*/nullptr, &llvm_module,
&buffer_assignment, &thunk_schedule));
return llvm_module;
}
} // namespace gpu

View File

@ -23,7 +23,6 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@ -97,26 +96,24 @@ void HloExecutionProfiler::StartHloInstruction() {
}
}
void HloExecutionProfiler::FinishHloInstruction(
const HloInstruction* hlo_instruction) {
void HloExecutionProfiler::FinishHloInstruction(size_t index) {
if (do_profile_) {
hlo_instructions_.erase(hlo_instruction);
profile_->SetCyclesTakenBy(
hlo_instruction,
GetCyclesTaken(&timers_, sub_streams_, stream_, clock_rate_ghz_));
indices_.erase(index);
profile_->SetCyclesTakenBy(index, GetCyclesTaken(&timers_, sub_streams_,
stream_, clock_rate_ghz_));
}
}
std::unique_ptr<ScopedInstructionProfiler>
HloExecutionProfiler::MakeScopedInstructionProfiler(
const HloInstruction* hlo_instruction) {
if (do_profile_ && hlo_instruction != nullptr) {
absl::optional<int64> index) {
if (do_profile_ && index.has_value()) {
// Make sure that we are not already measuring the time for the same
// 'hlo_instruction'.
CHECK(hlo_instructions_.insert(hlo_instruction).second)
<< hlo_instruction->name();
// instruction.
// TODO(timshen): provide more useful printout.
CHECK(indices_.insert(*index).second) << *index;
}
return absl::make_unique<ScopedInstructionProfiler>(this, hlo_instruction);
return absl::make_unique<ScopedInstructionProfiler>(this, index);
}
} // namespace gpu

View File

@ -23,7 +23,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@ -58,14 +57,17 @@ class HloExecutionProfiler {
void StartHloInstruction();
// If profiling is enabled, stops the per-operation timer and records the time
// that the hlo_instruction took to execute in the profile.
void FinishHloInstruction(const HloInstruction* hlo_instruction);
// taken at `profile_index`. Profile indices can be looked up from
// HloProfileIndexMap.
void FinishHloInstruction(size_t profile_index);
// Returns a ScopedInstructionProfiler and triggers a call to
// StartHloInstruction(). Once the returned ScopedInstructionProfiler goes
// out of scope, it triggers a call to FinishHloInstruction().
//
// If profile_index < 0, it results in a no-op.
std::unique_ptr<ScopedInstructionProfiler> MakeScopedInstructionProfiler(
const HloInstruction* hlo_instruction);
absl::optional<int64> profile_index);
private:
const bool do_profile_;
@ -77,7 +79,7 @@ class HloExecutionProfiler {
std::stack<std::unique_ptr<se::Timer>> timers_;
// Contains the HLO instructions for which we are currently measuring the
// time.
std::unordered_set<const HloInstruction*> hlo_instructions_;
std::unordered_set<size_t> indices_;
bool finished_execution_ = false;
};
@ -87,21 +89,21 @@ class HloExecutionProfiler {
class ScopedInstructionProfiler {
public:
ScopedInstructionProfiler(HloExecutionProfiler* profiler,
const HloInstruction* hlo_instruction)
: profiler_(profiler), hlo_instruction_(hlo_instruction) {
if (hlo_instruction != nullptr) {
absl::optional<int64> index)
: profiler_(profiler), index_(index) {
if (index_.has_value()) {
profiler->StartHloInstruction();
}
}
~ScopedInstructionProfiler() {
if (hlo_instruction_ != nullptr) {
profiler_->FinishHloInstruction(hlo_instruction_);
if (index_.has_value()) {
profiler_->FinishHloInstruction(*index_);
}
}
private:
HloExecutionProfiler* profiler_;
const HloInstruction* hlo_instruction_;
absl::optional<int64> index_;
};
} // namespace gpu
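For reference, below is a hedged sketch of how a thunk drives the index-based
profiler after this change; `MyThunk` is a made-up subclass, and the calls
simply mirror the hunks above (assuming `thunk.h` from this revision).
```
// Hypothetical thunk subclass, shown only to illustrate the profiling contract.
class MyThunk : public Thunk {
 public:
  explicit MyThunk(ThunkInfo thunk_info) : Thunk(Kind::kKernel, thunk_info) {}

  Status ExecuteOnStream(const ExecuteParams& params) override {
    // profile_index() comes from the ThunkInfo passed at construction; an
    // unset index (e.g. a sub-thunk built with a default ThunkInfo()) makes
    // the scoped profiler a no-op.
    auto op_profiler =
        params.profiler->MakeScopedInstructionProfiler(profile_index());
    // ... enqueue the actual work on params.stream ...
    return Status::OK();
  }
};
```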

View File

@ -23,9 +23,9 @@ namespace xla {
namespace gpu {
InfeedThunk::InfeedThunk(
const ShapeTree<BufferAllocation::Slice>& infeed_slices,
const HloInstruction* hlo_instruction)
: Thunk(Kind::kInfeed, hlo_instruction), infeed_slices_(infeed_slices) {}
ThunkInfo thunk_info,
const ShapeTree<BufferAllocation::Slice>& infeed_slices)
: Thunk(Kind::kInfeed, thunk_info), infeed_slices_(infeed_slices) {}
Status InfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
auto& stream = *params.stream;
@ -34,7 +34,7 @@ Status InfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
VLOG(2) << "Infeeding to GPU: " << hlo_instruction()->ToString();
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
ShapeTree<InfeedBuffer> infeed_buffers =
GetOrCreateInfeedManager()->BlockingGetNextDestination();

View File

@ -34,8 +34,8 @@ class InfeedThunk : public Thunk {
public:
// Constructs an InfeedThunk that copies data from the on-device
// infeed queue into the buffers in the given shape tree.
InfeedThunk(const ShapeTree<BufferAllocation::Slice>& infeed_slices,
const HloInstruction* hlo_instruction);
InfeedThunk(ThunkInfo thunk_info,
const ShapeTree<BufferAllocation::Slice>& infeed_slices);
InfeedThunk(const InfeedThunk&) = delete;
InfeedThunk& operator=(const InfeedThunk&) = delete;

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
namespace xla {
@ -33,12 +34,13 @@ class IrEmitterContext {
const HloModule* hlo_module, const BufferAssignment* buffer_assignment,
std::string platform_name, GpuDeviceInfo gpu_device_info,
absl::optional<CudaComputeCapability> cuda_compute_capability,
llvm::Module* llvm_module)
const HloProfileIndexMap* profile_index_map, llvm::Module* llvm_module)
: hlo_module_(hlo_module),
buffer_assignment_(buffer_assignment),
platform_name_(std::move(platform_name)),
gpu_device_info_(gpu_device_info),
cuda_compute_capability_(cuda_compute_capability),
profile_index_map_(profile_index_map),
llvm_module_(llvm_module) {}
// Disallow copy and assign.
IrEmitterContext(const IrEmitterContext&) = delete;
@ -54,6 +56,7 @@ class IrEmitterContext {
absl::optional<CudaComputeCapability> cuda_compute_capability() const {
return cuda_compute_capability_;
}
const HloProfileIndexMap* profile_index_map() { return profile_index_map_; }
llvm::Module* llvm_module() { return llvm_module_; }
NameUniquer* name_uniquer() { return &name_uniquer_; }
@ -63,6 +66,7 @@ class IrEmitterContext {
std::string platform_name_;
GpuDeviceInfo gpu_device_info_;
absl::optional<CudaComputeCapability> cuda_compute_capability_;
const HloProfileIndexMap* profile_index_map_;
llvm::Module* llvm_module_;
NameUniquer name_uniquer_;
};

View File

@ -652,8 +652,8 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
/*updates_gen=*/
scatter_fused_emitter.GetGenerator(root->operand(2))));
}
AddThunkToThunkSequence(
absl::make_unique<SequentialThunk>(std::move(thunks), fusion));
AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
GetThunkInfo(fusion), std::move(thunks)));
return Status::OK();
}
// In the case of root tuple, it can be either reduce or slice input
@ -739,10 +739,11 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
auto destination_buffer = GetAllocationSlice(*copy);
if (operand_buffer != destination_buffer) {
AddThunkToThunkSequence(absl::make_unique<DeviceToDeviceCopyThunk>(
GetThunkInfo(copy),
/*source_address=*/operand_buffer,
/*destination_buffer=*/destination_buffer,
/*mem_size=*/
ByteSizeOf(copy->operand(0)->shape()), copy));
ByteSizeOf(copy->operand(0)->shape())));
}
return Status::OK();
}
@ -816,7 +817,8 @@ Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) {
tuple_element_buffers.push_back(GetAllocationSlice(*tuple_element));
}
AddThunkToThunkSequence(absl::make_unique<TupleThunk>(
tuple_element_buffers, GetAllocationSlice(*tuple), tuple));
GetThunkInfo(tuple), tuple_element_buffers,
GetAllocationSlice(*tuple)));
return Status::OK();
}
AddThunkToThunkSequence(
@ -848,7 +850,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
thunks.push_back(BuildKernelThunk(select_and_scatter,
/*implements_whole_instruction=*/false));
std::unique_ptr<SequentialThunk> select_and_scatter_thunk =
absl::make_unique<SequentialThunk>(std::move(thunks), select_and_scatter);
absl::make_unique<SequentialThunk>(GetThunkInfo(select_and_scatter),
std::move(thunks));
// TODO(b/31410564): Implement dilation rate for select-and-scatter.
if (window_util::HasDilation(window)) {
@ -1082,10 +1085,10 @@ Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) {
auto destination_buffer = GetAllocationSlice(*scatter);
if (operand_buffer != destination_buffer) {
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
Thunk::ThunkInfo(),
/*source_address=*/operand_buffer,
/*destination_buffer=*/destination_buffer,
/*mem_size=*/ShapeUtil::ByteSizeOf(operand->shape()),
/*hlo_instruction=*/nullptr));
/*mem_size=*/ShapeUtil::ByteSizeOf(operand->shape())));
}
thunks.push_back(
@ -1109,8 +1112,8 @@ Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) {
if (thunks.size() == 1) {
AddThunkToThunkSequence(std::move(thunks[0]));
} else {
AddThunkToThunkSequence(
absl::make_unique<SequentialThunk>(std::move(thunks), scatter));
AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
GetThunkInfo(scatter), std::move(thunks)));
}
return Status::OK();
@ -1282,10 +1285,10 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
// TODO(b/26783907): Figure out why we never seem to share buffers for
// key/value sort.
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
Thunk::ThunkInfo(),
/*source_address=*/source_address,
/*destination_buffer=*/destination_buffer,
/*mem_size=*/ShapeUtil::ByteSizeOf(sort->operand(i)->shape()),
nullptr));
/*mem_size=*/ShapeUtil::ByteSizeOf(sort->operand(i)->shape())));
}
}
@ -1419,8 +1422,8 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
TF_RETURN_IF_ERROR(emit_kernel(xor_masks));
}
AddThunkToThunkSequence(
absl::make_unique<SequentialThunk>(std::move(thunks), sort));
AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
GetThunkInfo(sort), std::move(thunks)));
if (sort->operand_count() > 1) {
// Emit the tuple as part of the last stage of sorting.
// We are currently in the block sorted.in_bounds.after.
@ -1438,14 +1441,15 @@ Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) {
}
Status IrEmitterUnnested::HandleReplicaId(HloInstruction* hlo) {
AddThunkToThunkSequence(
absl::make_unique<ReplicaIdThunk>(GetAllocationSlice(*hlo), hlo));
AddThunkToThunkSequence(absl::make_unique<ReplicaIdThunk>(
GetThunkInfo(hlo), GetAllocationSlice(*hlo)));
return Status::OK();
}
Status IrEmitterUnnested::HandleCollectivePermute(HloInstruction* hlo) {
AddThunkToThunkSequence(absl::make_unique<CollectivePermuteThunk>(
GetAllocationSlice(*hlo->operand(0)), GetAllocationSlice(*hlo), hlo));
GetThunkInfo(hlo), GetAllocationSlice(*hlo->operand(0)),
GetAllocationSlice(*hlo)));
return Status::OK();
}
@ -1478,15 +1482,16 @@ Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) {
tuple_element_buffers.push_back(buffers[i].destination_buffer);
}
auto all_reduce_thunk = absl::make_unique<NcclAllReduceThunk>(
GetThunkInfo(crs),
/*replica_count=*/hlo_module_config_.replica_count(),
/*buffers=*/std::move(buffers), crs);
/*buffers=*/std::move(buffers));
if (crs->shape().IsTuple()) {
std::vector<std::unique_ptr<Thunk>> thunks;
thunks.push_back(std::move(all_reduce_thunk));
thunks.push_back(absl::make_unique<TupleThunk>(
tuple_element_buffers, GetAllocationSlice(*crs), nullptr));
AddThunkToThunkSequence(
absl::make_unique<SequentialThunk>(std::move(thunks), crs));
Thunk::ThunkInfo(), tuple_element_buffers, GetAllocationSlice(*crs)));
AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
GetThunkInfo(crs), std::move(thunks)));
} else {
AddThunkToThunkSequence(std::move(all_reduce_thunk));
}
@ -1520,9 +1525,10 @@ Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) {
CHECK(crs->operand(0)->shape().IsArray())
<< "Operands to all-reduce must be arrays: " << crs->ToString();
AddThunkToThunkSequence(absl::make_unique<DeviceToDeviceCopyThunk>(
GetThunkInfo(crs),
/*source_address=*/GetAllocationSlice(*crs->operand(0)),
/*destination_buffer=*/GetAllocationSlice(*crs),
/*mem_size=*/ShapeUtil::ByteSizeOf(crs->shape()), crs));
/*mem_size=*/ShapeUtil::ByteSizeOf(crs->shape())));
return Status::OK();
}
@ -1535,16 +1541,17 @@ Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) {
.GetUniqueSlice(crs, {i})
.ValueOrDie());
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
Thunk::ThunkInfo(),
/*source_address=*/GetAllocationSlice(*crs->operand(i)),
/*destination_buffer=*/tuple_element_buffers.back(),
/*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), nullptr));
/*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape())));
}
// Output a tuple of the buffers above.
thunks.push_back(absl::make_unique<TupleThunk>(
tuple_element_buffers, GetAllocationSlice(*crs), nullptr));
Thunk::ThunkInfo(), tuple_element_buffers, GetAllocationSlice(*crs)));
AddThunkToThunkSequence(
absl::make_unique<SequentialThunk>(std::move(thunks), crs));
absl::make_unique<SequentialThunk>(GetThunkInfo(crs), std::move(thunks)));
return Status::OK();
}
@ -1787,8 +1794,8 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
}
return absl::make_unique<KernelThunk>(
non_constant_buffers, std::string(kernel->getName()),
implements_whole_instruction ? inst : nullptr);
implements_whole_instruction ? GetThunkInfo(inst) : Thunk::ThunkInfo(),
non_constant_buffers, std::string(kernel->getName()));
}
StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
@ -1838,8 +1845,8 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
absl::Span<const uint8> literal_bytes(
reinterpret_cast<const uint8*>(literal.untyped_data()), num_bytes);
if (absl::c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) {
return {absl::make_unique<MemzeroThunk>(GetAllocationSlice(*hlo, index),
nullptr)};
return {absl::make_unique<MemzeroThunk>(Thunk::ThunkInfo(),
GetAllocationSlice(*hlo, index))};
}
// If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by
@ -1857,7 +1864,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
}
uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16);
return {absl::make_unique<Memset32BitValueThunk>(
pattern32, GetAllocationSlice(*hlo, index), nullptr)};
Thunk::ThunkInfo(), pattern32, GetAllocationSlice(*hlo, index))};
}
// If the literal is an even multiple of 32 bits wide, we can emit a 32-bit
@ -1868,7 +1875,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
uint32 word;
memcpy(&word, literal_bytes.data(), sizeof(word));
return {absl::make_unique<Memset32BitValueThunk>(
word, GetAllocationSlice(*hlo, index), nullptr)};
Thunk::ThunkInfo(), word, GetAllocationSlice(*hlo, index))};
}
}
@ -2014,9 +2021,10 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildWhileThunk(
TF_CHECK_OK(body->Accept(&ir_emitter_body));
return absl::make_unique<WhileThunk>(
GetThunkInfo(hlo),
GetAllocationSlice(*condition->root_instruction()), // cond result
ir_emitter_condition.ConsumeThunkSequence(),
ir_emitter_body.ConsumeThunkSequence(), hlo);
ir_emitter_body.ConsumeThunkSequence());
}
std::unique_ptr<Thunk> IrEmitterUnnested::BuildForThunk(
@ -2031,8 +2039,8 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildForThunk(
ir_emitter_context_);
TF_CHECK_OK(body->Accept(&ir_emitter_body));
return absl::make_unique<ForThunk>(
loop_limit, ir_emitter_body.ConsumeThunkSequence(), hlo);
return absl::make_unique<ForThunk>(GetThunkInfo(hlo), loop_limit,
ir_emitter_body.ConsumeThunkSequence());
}
std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk(
@ -2054,8 +2062,8 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk(
}
return absl::make_unique<ConditionalThunk>(
GetAllocationSlice(*hlo->operand(0)), branch_operands,
std::move(branch_thunks), hlo);
GetThunkInfo(hlo), GetAllocationSlice(*hlo->operand(0)), branch_operands,
std::move(branch_thunks));
}
Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
@ -3589,8 +3597,8 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
ir_emitter_context_->llvm_module());
thunks.push_back(std::move(kernel_thunk));
auto sequential_thunk =
absl::make_unique<SequentialThunk>(std::move(thunks), unnested_hlo);
auto sequential_thunk = absl::make_unique<SequentialThunk>(
GetThunkInfo(unnested_hlo), std::move(thunks));
AddThunkToThunkSequence(std::move(sequential_thunk));
return Status::OK();
@ -3757,5 +3765,15 @@ Status IrEmitterUnnested::EmitInputFusibleNonStridedSlices(
return emit_status;
}
Thunk::ThunkInfo IrEmitterUnnested::GetThunkInfo(
const HloInstruction* hlo) const {
auto info = ThunkEmitter::EmissionContext::GetThunkInfo(hlo);
if (const auto* index_map = ir_emitter_context_->profile_index_map()) {
info.profile_index.emplace(
static_cast<int64>(index_map->GetProfileIndexFor(*hlo)));
}
return info;
}
} // namespace gpu
} // namespace xla
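
Throughout the ir_emitter_unnested.cc changes above, the same split recurs: the thunk that implements a whole instruction is constructed with GetThunkInfo(hlo), while helper thunks emitted as internal steps (staging copies, tuple writes inside a SequentialThunk) receive a default Thunk::ThunkInfo() and therefore carry no profile index. The following standalone sketch illustrates that split; every name in it (SketchThunkInfo, SketchThunk, the "scatter.1" label, index 7) is an invented stand-in, not XLA code.

// Minimal sketch of the "real info for the whole-instruction thunk,
// default info for internal helper thunks" pattern. Stand-in types only.
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

struct SketchThunkInfo {
  std::string hlo_name;                  // stands in for ThunkInfo::hlo_instruction
  std::optional<int64_t> profile_index;  // as in ThunkInfo::profile_index
};

class SketchThunk {
 public:
  explicit SketchThunk(SketchThunkInfo info) : info_(std::move(info)) {}
  const SketchThunkInfo& info() const { return info_; }

 private:
  SketchThunkInfo info_;
};

int main() {
  std::vector<std::unique_ptr<SketchThunk>> thunks;
  // Internal helper step (e.g. the staging copy before a scatter):
  // default-constructed info, mirroring Thunk::ThunkInfo() in the diff.
  thunks.push_back(std::make_unique<SketchThunk>(SketchThunkInfo{}));
  // The thunk standing for the whole instruction: populated info,
  // mirroring GetThunkInfo(scatter) in the diff.
  thunks.push_back(std::make_unique<SketchThunk>(
      SketchThunkInfo{"scatter.1", /*profile_index=*/7}));
  return 0;
}
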

View File

@ -548,6 +548,8 @@ class IrEmitterUnnested : public IrEmitter,
// Returns the last generated thunk.
Thunk* LastThunk() const { return thunk_sequence_.back().get(); }
Thunk::ThunkInfo GetThunkInfo(const HloInstruction* hlo) const override;
// The thunk sequence this IrEmitter generates for the input computation.
ThunkSequence thunk_sequence_;

View File

@ -33,10 +33,10 @@ limitations under the License.
namespace xla {
namespace gpu {
KernelThunk::KernelThunk(absl::Span<const BufferAllocation* const> args,
const string& kernel_name,
const HloInstruction* hlo_instruction)
: Thunk(Kind::kKernel, hlo_instruction),
KernelThunk::KernelThunk(ThunkInfo thunk_info,
absl::Span<const BufferAllocation* const> args,
const string& kernel_name)
: Thunk(Kind::kKernel, thunk_info),
args_(args.begin(), args.end()),
kernel_name_(kernel_name) {}
@ -114,7 +114,7 @@ Status KernelThunk::ExecuteOnStream(const ExecuteParams& params) {
}
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
return ExecuteKernelOnStream(*kernel, buffer_args,
launch_dimensions.threads_per_block(),
launch_dimensions.block_count(), params.stream);

View File

@ -47,8 +47,9 @@ class KernelThunk : public Thunk {
// Constructs a thunk for the given kernel.
//
// `thunk_info` is as in Thunk. Other arguments are as the class members.
KernelThunk(absl::Span<const BufferAllocation* const> args,
const string& kernel_name, const HloInstruction* hlo_instruction);
KernelThunk(ThunkInfo thunk_info,
absl::Span<const BufferAllocation* const> args,
const string& kernel_name);
KernelThunk(const KernelThunk&) = delete;
KernelThunk& operator=(const KernelThunk&) = delete;
~KernelThunk() override = default;

View File

@ -25,7 +25,7 @@ Status MemzeroThunk::ExecuteOnStream(const ExecuteParams& params) {
se::DeviceMemoryBase dest_data =
params.buffer_allocations->GetDeviceAddress(dest_);
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
params.stream->ThenMemZero(&dest_data, dest_data.size());
return Status::OK();
}
@ -34,7 +34,7 @@ Status Memset32BitValueThunk::ExecuteOnStream(const ExecuteParams& params) {
se::DeviceMemoryBase dest_data =
params.buffer_allocations->GetDeviceAddress(dest_);
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
params.stream->ThenMemset32(&dest_data, value_, dest_data.size());
return Status::OK();
}

View File

@ -32,9 +32,9 @@ namespace gpu {
// Thunk that zeroes out a given chunk of memory.
class MemzeroThunk : public Thunk {
public:
explicit MemzeroThunk(const BufferAllocation::Slice& dest,
const HloInstruction* hlo)
: Thunk(Kind::kMemzero, hlo), dest_(dest) {}
explicit MemzeroThunk(ThunkInfo thunk_info,
const BufferAllocation::Slice& dest)
: Thunk(Kind::kMemzero, thunk_info), dest_(dest) {}
Status ExecuteOnStream(const ExecuteParams& params) override;
@ -46,10 +46,11 @@ class MemzeroThunk : public Thunk {
// destination chunk must have size divisible by 32 bits.
class Memset32BitValueThunk : public Thunk {
public:
explicit Memset32BitValueThunk(uint32 value,
const BufferAllocation::Slice& dest,
const HloInstruction* hlo)
: Thunk(Kind::kMemset32BitValue, hlo), value_(value), dest_(dest) {}
explicit Memset32BitValueThunk(ThunkInfo thunk_info, uint32 value,
const BufferAllocation::Slice& dest)
: Thunk(Kind::kMemset32BitValue, thunk_info),
value_(value),
dest_(dest) {}
Status ExecuteOnStream(const ExecuteParams& params) override;

View File

@ -541,9 +541,9 @@ NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
}
NcclAllReduceThunk::NcclAllReduceThunk(
int64 replica_count, std::vector<NcclAllReduceThunk::Buffer> buffers,
const HloInstruction* all_reduce)
: Thunk(Thunk::kNcclAllReduce, all_reduce),
ThunkInfo thunk_info, int64 replica_count,
std::vector<NcclAllReduceThunk::Buffer> buffers)
: Thunk(Thunk::kNcclAllReduce, thunk_info),
replica_count_(replica_count),
buffers_(std::move(buffers)),
aux_data_(absl::make_unique<AuxData>()) {
@ -555,7 +555,7 @@ NcclAllReduceThunk::NcclAllReduceThunk(
Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
VLOG(1) << "Starting NcclAllReduceThunk.";
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
auto* instr = Cast<HloAllReduceInstruction>(hlo_instruction());
int64 local_device_ordinal = params.stream->parent()->device_ordinal();

View File

@ -56,8 +56,8 @@ class NcclAllReduceThunk : public Thunk {
BufferAllocation::Slice source_buffer;
BufferAllocation::Slice destination_buffer;
};
NcclAllReduceThunk(int64 replica_count, std::vector<Buffer> buffers,
const HloInstruction* all_reduce);
NcclAllReduceThunk(ThunkInfo thunk_info, int64 replica_count,
std::vector<Buffer> buffers);
~NcclAllReduceThunk() override;
Status ExecuteOnStream(const ExecuteParams& params) override;

View File

@ -23,9 +23,9 @@ limitations under the License.
namespace xla {
namespace gpu {
OutfeedThunk::OutfeedThunk(ShapeTree<BufferAllocation::Slice> outfeed_slices,
const HloInstruction* hlo_instruction)
: Thunk(Kind::kOutfeed, hlo_instruction),
OutfeedThunk::OutfeedThunk(ThunkInfo thunk_info,
ShapeTree<BufferAllocation::Slice> outfeed_slices)
: Thunk(Kind::kOutfeed, thunk_info),
outfeed_slices_(std::move(outfeed_slices)) {}
Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
@ -35,7 +35,7 @@ Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
VLOG(2) << "Outfeeding from GPU: " << hlo_instruction()->ToString();
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
OutfeedManager* outfeed_manager = GetOrCreateOutfeedManager();
ShapeTree<std::unique_ptr<OutfeedBuffer>>* outfeed_buffers =
outfeed_manager->BlockingGetNextDestination();

View File

@ -32,8 +32,8 @@ class OutfeedThunk : public Thunk {
public:
// Constructs an OutfeedThunk that copies data to the host-side
// outfeed queue from the buffers in the given shape tree.
OutfeedThunk(ShapeTree<BufferAllocation::Slice> outfeed_slices,
const HloInstruction* hlo_instruction);
OutfeedThunk(ThunkInfo thunk_info,
ShapeTree<BufferAllocation::Slice> outfeed_slices);
OutfeedThunk(const OutfeedThunk&) = delete;
OutfeedThunk& operator=(const OutfeedThunk&) = delete;

View File

@ -18,13 +18,13 @@ limitations under the License.
namespace xla {
namespace gpu {
ReplicaIdThunk::ReplicaIdThunk(const BufferAllocation::Slice& dest,
const HloInstruction* instr)
: Thunk(Kind::kReplicaId, instr), dest_(dest) {}
ReplicaIdThunk::ReplicaIdThunk(ThunkInfo thunk_info,
const BufferAllocation::Slice& dest)
: Thunk(Kind::kReplicaId, thunk_info), dest_(dest) {}
Status ReplicaIdThunk::ExecuteOnStream(const ExecuteParams& params) {
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
auto dest_addr = params.buffer_allocations->GetDeviceAddress(dest_);
TF_ASSIGN_OR_RETURN(int replica_id,

View File

@ -26,8 +26,7 @@ namespace gpu {
// Thunk that implements the ReplicaId HLO.
class ReplicaIdThunk : public Thunk {
public:
ReplicaIdThunk(const BufferAllocation::Slice& dest,
const HloInstruction* instr);
ReplicaIdThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& dest);
Status ExecuteOnStream(const ExecuteParams& params) override;

View File

@ -24,9 +24,9 @@ namespace gpu {
using ::tensorflow::profiler::ScopedAnnotation;
SequentialThunk::SequentialThunk(std::vector<std::unique_ptr<Thunk>> thunks,
const HloInstruction* hlo)
: Thunk(Kind::kSequential, hlo), thunks_(std::move(thunks)) {}
SequentialThunk::SequentialThunk(ThunkInfo thunk_info,
std::vector<std::unique_ptr<Thunk>> thunks)
: Thunk(Kind::kSequential, thunk_info), thunks_(std::move(thunks)) {}
void SequentialThunk::ComputeAnnotations() {
for (const auto& thunk : thunks_) {
@ -44,7 +44,7 @@ Status SequentialThunk::Initialize(const GpuExecutable& executable,
Status SequentialThunk::ExecuteOnStream(const ExecuteParams& params) {
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
for (const auto& thunk : thunks_) {
ScopedAnnotation annotation([&] { return thunk->profile_annotation(); });
TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(params));

View File

@ -32,8 +32,8 @@ namespace gpu {
// require multiple kernel launches or library calls.
class SequentialThunk : public Thunk {
public:
SequentialThunk(std::vector<std::unique_ptr<Thunk>> thunks,
const HloInstruction* hlo);
SequentialThunk(ThunkInfo thunk_info,
std::vector<std::unique_ptr<Thunk>> thunks);
SequentialThunk(const SequentialThunk&) = delete;
SequentialThunk& operator=(const SequentialThunk&) = delete;

View File

@ -68,13 +68,21 @@ class Thunk {
kWhile,
};
struct ThunkInfo {
const HloInstruction* hlo_instruction = nullptr;
absl::optional<int64> profile_index;
// TODO(timshen): Remove hlo_instruction and add name(),
// profile_annotation() here.
};
// The hlo_instruction in thunk_info is meant to be the instruction this thunk
// was generated from, but Thunk never uses it other than to save it to
// Thunk::hlo_instruction_, so it can be null.
explicit Thunk(Kind kind, const HloInstruction* hlo_instruction)
explicit Thunk(Kind kind, ThunkInfo thunk_info)
: kind_(kind),
hlo_instruction_(hlo_instruction),
name_(hlo_instruction_ ? hlo_instruction_->name() : "") {}
hlo_instruction_(thunk_info.hlo_instruction),
name_(hlo_instruction_ ? hlo_instruction_->name() : ""),
profile_index_(thunk_info.profile_index) {}
virtual ~Thunk() {}
Thunk(const Thunk&) = delete;
Thunk& operator=(const Thunk&) = delete;
@ -128,6 +136,8 @@ class Thunk {
protected:
const HloInstruction* hlo_instruction() const { return hlo_instruction_; }
absl::optional<int64> profile_index() const { return profile_index_; }
const HloModuleConfig& GetModuleConfig() const {
return hlo_instruction()->GetModule()->config();
}
@ -146,8 +156,12 @@ class Thunk {
private:
Kind kind_;
// Will be removed in the future, as Thunk is migrating away from the
// monolithic HloInstruction.
const HloInstruction* hlo_instruction_;
std::string name_;
absl::optional<int64> profile_index_;
string profile_annotation_;
};
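
The thunk.h change above is the core of the refactor: Thunk now receives a ThunkInfo at construction and exposes profile_index(), so ExecuteOnStream implementations call MakeScopedInstructionProfiler(profile_index()) instead of handing over an HloInstruction*. Below is a minimal standalone sketch of that contract; ToyThunkInfo, ToyThunk, and ToyProfiler are invented stand-ins (with std::optional/int64_t in place of absl::optional/int64), not the real XLA interface.

// Minimal sketch of a ThunkInfo-style constructor plus profiling by index.
// Stand-in types only; not XLA code.
#include <cstdint>
#include <iostream>
#include <optional>
#include <string>
#include <utility>

struct ToyThunkInfo {
  std::string name;                      // derived from the HLO, if any
  std::optional<int64_t> profile_index;  // set only when profiling is enabled
};

class ToyProfiler {
 public:
  // Stand-in for MakeScopedInstructionProfiler(profile_index()): the thunk
  // now hands over an index instead of an HloInstruction pointer.
  void RecordInstruction(std::optional<int64_t> profile_index) {
    if (profile_index) {
      std::cout << "profiling slot " << *profile_index << "\n";
    }
  }
};

class ToyThunk {
 public:
  explicit ToyThunk(ToyThunkInfo info)
      : name_(info.name), profile_index_(info.profile_index) {}
  virtual ~ToyThunk() = default;

  const std::string& name() const { return name_; }

  void ExecuteOnStream(ToyProfiler* profiler) {
    profiler->RecordInstruction(profile_index());
    // ... launch the actual work here ...
  }

 protected:
  std::optional<int64_t> profile_index() const { return profile_index_; }

 private:
  std::string name_;
  std::optional<int64_t> profile_index_;
};

int main() {
  ToyProfiler profiler;
  ToyThunk with_profile(ToyThunkInfo{"fusion.3", 42});
  ToyThunk without_profile(ToyThunkInfo{"internal copy", std::nullopt});
  with_profile.ExecuteOnStream(&profiler);     // prints "profiling slot 42"
  without_profile.ExecuteOnStream(&profiler);  // prints nothing
  return 0;
}
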

View File

@ -40,11 +40,11 @@ namespace gpu {
std::unique_ptr<Thunk> ThunkEmitter::BuildFftThunk(const HloInstruction* inst) {
const HloInstruction* operand = inst->operand(0);
return absl::make_unique<FftThunk>(
inst->fft_type(), inst->fft_length(),
context_->GetThunkInfo(inst), inst->fft_type(), inst->fft_length(),
/*input_buffer=*/GetAllocationSlice(*operand),
/*output_buffer=*/GetAllocationSlice(*inst),
/*input_shape=*/operand->shape(),
/*output_shape=*/inst->shape(), inst);
/*output_shape=*/inst->shape());
}
std::unique_ptr<Thunk> ThunkEmitter::BuildTriangularSolveThunk(
@ -63,11 +63,11 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildTriangularSolveThunk(
: n * n * elem_size;
int64 b_batch_stride = m * n * elem_size;
return absl::make_unique<TriangularSolveThunk>(
inst->triangular_solve_options(),
context_->GetThunkInfo(inst), inst->triangular_solve_options(),
/*a_input_buffer=*/GetAllocationSlice(*a),
/*b_input_buffer=*/GetAllocationSlice(*inst),
inst->shape().element_type(), batch_size, m, n, a_batch_stride,
b_batch_stride, inst);
b_batch_stride);
}
std::unique_ptr<Thunk> ThunkEmitter::BuildGemmThunk(
@ -86,24 +86,27 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildGemmThunk(
if (GetAllocationSlice(*bias) != GetAllocationSlice(*inst)) {
std::vector<std::unique_ptr<Thunk>> thunks;
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
Thunk::ThunkInfo(),
/*source_buffer=*/GetAllocationSlice(*bias),
/*destination_buffer=*/GetAllocationSlice(*inst),
/*mem_size=*/ShapeUtil::ByteSizeOf(inst->shape()), nullptr));
/*mem_size=*/ShapeUtil::ByteSizeOf(inst->shape())));
thunks.push_back(absl::make_unique<GemmThunk>(
context_->GetThunkInfo(inst),
GetAllocationSlice(*lhs), // The buffer assigned to LHS.
GetAllocationSlice(*rhs), // The buffer assigned to RHS.
GetAllocationSlice(*inst), // The output buffer.
/*implements_whole_instruction=*/false, inst,
std::move(gemm_config)));
return absl::make_unique<SequentialThunk>(std::move(thunks), inst);
/*implements_whole_instruction=*/false, std::move(gemm_config)));
return absl::make_unique<SequentialThunk>(context_->GetThunkInfo(inst),
std::move(thunks));
}
}
return absl::make_unique<GemmThunk>(
context_->GetThunkInfo(inst),
GetAllocationSlice(*lhs), // The buffer assigned to LHS.
GetAllocationSlice(*rhs), // The buffer assigned to RHS.
GetAllocationSlice(*inst), // The output buffer.
/*implements_whole_instruction=*/true, inst, std::move(gemm_config));
/*implements_whole_instruction=*/true, std::move(gemm_config));
}
std::unique_ptr<Thunk> ThunkEmitter::BuildInfeedThunk(
@ -115,7 +118,7 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildInfeedThunk(
[&](const ShapeIndex& index, BufferAllocation::Slice* slice) {
*slice = GetAllocationSlice(*inst, index);
});
return absl::make_unique<InfeedThunk>(slices, inst);
return absl::make_unique<InfeedThunk>(context_->GetThunkInfo(inst), slices);
}
std::unique_ptr<Thunk> ThunkEmitter::BuildOutfeedThunk(
@ -130,7 +133,8 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildOutfeedThunk(
*slice = status_or_slice.ValueOrDie();
}
});
return absl::make_unique<OutfeedThunk>(std::move(slices), inst);
return absl::make_unique<OutfeedThunk>(context_->GetThunkInfo(inst),
std::move(slices));
}
Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
@ -152,6 +156,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
AddThunkToThunkSequence(
absl::make_unique<CudnnBatchNormForwardInferenceThunk>(
context_->GetThunkInfo(custom_call),
/*operand=*/GetAllocationSlice(*custom_call->operand(0)),
/*scale=*/GetAllocationSlice(*custom_call->operand(1)),
/*offset=*/GetAllocationSlice(*custom_call->operand(2)),
@ -159,8 +164,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
/*variance=*/GetAllocationSlice(*custom_call->operand(4)),
/*epsilon=*/epsilon_value,
/*feature_index=*/feature_index_value,
/*output=*/GetAllocationSlice(*custom_call),
/*hlo=*/custom_call));
/*output=*/GetAllocationSlice(*custom_call)));
return Status::OK();
}
@ -181,6 +185,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
auto output_inv_stddev = GetAllocationSlice(*custom_call, {2});
AddThunkToThunkSequence(
absl::make_unique<CudnnBatchNormForwardTrainingThunk>(
context_->GetThunkInfo(custom_call),
/*operand=*/GetAllocationSlice(*custom_call->operand(0)),
/*scale=*/GetAllocationSlice(*custom_call->operand(1)),
/*offset=*/GetAllocationSlice(*custom_call->operand(2)),
@ -189,8 +194,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
/*output_data=*/output_data,
/*output_mean=*/output_mean,
/*output_inv_stddev=*/output_inv_stddev,
/*output_tuple=*/GetAllocationSlice(*custom_call),
/*hlo=*/custom_call));
/*output_tuple=*/GetAllocationSlice(*custom_call)));
return Status::OK();
}
@ -209,6 +213,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
auto output_grad_scale = GetAllocationSlice(*custom_call, {1});
auto output_grad_offset = GetAllocationSlice(*custom_call, {2});
AddThunkToThunkSequence(absl::make_unique<CudnnBatchNormBackwardThunk>(
context_->GetThunkInfo(custom_call),
/*operand=*/GetAllocationSlice(*custom_call->operand(0)),
/*scale=*/GetAllocationSlice(*custom_call->operand(1)),
/*mean=*/GetAllocationSlice(*custom_call->operand(2)),
@ -219,8 +224,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
/*output_grad_data=*/output_grad_data,
/*output_grad_scale=*/output_grad_scale,
/*output_grad_offset=*/output_grad_offset,
/*output_tuple=*/GetAllocationSlice(*custom_call),
/*hlo=*/custom_call));
/*output_tuple=*/GetAllocationSlice(*custom_call)));
return Status::OK();
}
@ -235,7 +239,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
auto scratch_slice = GetAllocationSlice(*custom_call, {1});
AddThunkToThunkSequence(absl::make_unique<ConvolutionThunk>(
Cast<HloCustomCallInstruction>(custom_call), std::move(operand_slices),
context_->GetThunkInfo(custom_call), std::move(operand_slices),
conv_result_slice, scratch_slice, tuple_result_slice));
return Status::OK();
}
@ -269,22 +273,23 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
if (operand_buffer != a_buffer) {
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
context_->GetThunkInfo(custom_call),
/*source_address=*/operand_buffer,
/*destination_buffer=*/a_buffer,
/*mem_size=*/ShapeUtil::ByteSizeOf(shape), custom_call));
/*mem_size=*/ShapeUtil::ByteSizeOf(shape)));
}
thunks.push_back(absl::make_unique<CholeskyThunk>(
options, a_buffer, workspace_buffer, info_buffer,
custom_call->operand(0)->shape().element_type(), batch_size, n,
custom_call));
context_->GetThunkInfo(custom_call), options, a_buffer,
workspace_buffer, info_buffer,
custom_call->operand(0)->shape().element_type(), batch_size, n));
// Elide the sequential thunk if there's no copy.
if (thunks.size() == 1) {
AddThunkToThunkSequence(std::move(thunks[0]));
} else {
AddThunkToThunkSequence(
absl::make_unique<SequentialThunk>(std::move(thunks), custom_call));
AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
context_->GetThunkInfo(custom_call), std::move(thunks)));
}
return Status::OK();
@ -311,8 +316,9 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
ShapeTree<BufferAllocation::Slice> result_slices =
get_slices_for_instr(custom_call);
AddThunkToThunkSequence(absl::make_unique<CustomCallThunk>(
call_target, std::move(operand_slices), std::move(result_slices),
Cast<HloCustomCallInstruction>(custom_call)->opaque(), custom_call));
context_->GetThunkInfo(custom_call), call_target,
std::move(operand_slices), std::move(result_slices),
Cast<HloCustomCallInstruction>(custom_call)->opaque()));
return Status::OK();
}
#endif
@ -347,9 +353,10 @@ Status ThunkEmitter::HandleTriangularSolve(HloInstruction* hlo) {
auto destination_buffer = GetAllocationSlice(*hlo);
if (operand_buffer != destination_buffer) {
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
context_->GetThunkInfo(hlo),
/*source_address=*/operand_buffer,
/*destination_buffer=*/destination_buffer,
/*mem_size=*/ShapeUtil::ByteSizeOf(hlo->operand(1)->shape()), hlo));
/*mem_size=*/ShapeUtil::ByteSizeOf(hlo->operand(1)->shape())));
}
thunks.push_back(BuildTriangularSolveThunk(hlo));
@ -358,8 +365,8 @@ Status ThunkEmitter::HandleTriangularSolve(HloInstruction* hlo) {
if (thunks.size() == 1) {
AddThunkToThunkSequence(std::move(thunks[0]));
} else {
AddThunkToThunkSequence(
absl::make_unique<SequentialThunk>(std::move(thunks), hlo));
AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
context_->GetThunkInfo(hlo), std::move(thunks)));
}
return Status::OK();
}
@ -374,5 +381,12 @@ Status ThunkEmitter::HandleOutfeed(HloInstruction* outfeed) {
return Status::OK();
}
Thunk::ThunkInfo ThunkEmitter::EmissionContext::GetThunkInfo(
const HloInstruction* hlo) const {
CHECK(hlo);
Thunk::ThunkInfo info;
info.hlo_instruction = hlo;
return info;
}
} // namespace gpu
} // namespace xla
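
The two GetThunkInfo implementations above split the work: ThunkEmitter::EmissionContext::GetThunkInfo records only the originating instruction, while IrEmitterUnnested::GetThunkInfo (earlier in this diff) layers a profile index on top when a profile index map is available. A standalone sketch of that override relationship follows; FakeContext, ProfilingContext, and the string-keyed map are invented stand-ins for the real emission context and profile index map.

// Minimal sketch of a base GetThunkInfo hook and a profiling override.
// Stand-in types only; not XLA code.
#include <cstdint>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>

struct FakeThunkInfo {
  std::string hlo_name;                  // stands in for the HloInstruction*
  std::optional<int64_t> profile_index;
};

class FakeContext {
 public:
  virtual ~FakeContext() = default;
  // Base behavior: record which instruction the thunk came from.
  virtual FakeThunkInfo GetThunkInfo(const std::string& hlo_name) const {
    FakeThunkInfo info;
    info.hlo_name = hlo_name;
    return info;
  }
};

class ProfilingContext : public FakeContext {
 public:
  explicit ProfilingContext(std::unordered_map<std::string, int64_t> index_map)
      : profile_index_map_(std::move(index_map)) {}

  // Derived behavior: start from the base info, then attach the profile
  // index if this instruction has a slot in the (stand-in) index map.
  FakeThunkInfo GetThunkInfo(const std::string& hlo_name) const override {
    FakeThunkInfo info = FakeContext::GetThunkInfo(hlo_name);
    auto it = profile_index_map_.find(hlo_name);
    if (it != profile_index_map_.end()) {
      info.profile_index.emplace(it->second);
    }
    return info;
  }

 private:
  std::unordered_map<std::string, int64_t> profile_index_map_;
};

int main() {
  ProfilingContext context({{"sort.0", 3}});
  FakeThunkInfo info = context.GetThunkInfo("sort.0");  // profile_index == 3
  FakeThunkInfo bare = context.GetThunkInfo("copy.1");  // profile_index unset
  (void)info;
  (void)bare;
  return 0;
}
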

View File

@ -36,6 +36,7 @@ class ThunkEmitter {
const HloInstruction& hlo, const ShapeIndex& index) const = 0;
virtual int64 ByteSizeOf(const Shape& shape) const = 0;
virtual absl::string_view platform_name() const = 0;
virtual Thunk::ThunkInfo GetThunkInfo(const HloInstruction* hlo) const;
virtual ~EmissionContext() = default;
};

View File

@ -32,12 +32,12 @@ namespace xla {
namespace gpu {
TriangularSolveThunk::TriangularSolveThunk(
const TriangularSolveOptions& options,
ThunkInfo thunk_info, const TriangularSolveOptions& options,
const BufferAllocation::Slice& a_buffer,
const BufferAllocation::Slice& b_buffer, PrimitiveType type,
int64 batch_size, int64 m, int64 n, int64 a_batch_stride,
int64 b_batch_stride, const HloInstruction* hlo)
: Thunk(Kind::kTriangularSolve, hlo),
int64 b_batch_stride)
: Thunk(Kind::kTriangularSolve, thunk_info),
uplo_(options.lower() ? se::blas::UpperLower::kLower
: se::blas::UpperLower::kUpper),
side_(options.left_side() ? se::blas::Side::kLeft

View File

@ -38,12 +38,12 @@ namespace gpu {
// Thread-compatible.
class TriangularSolveThunk : public Thunk {
public:
TriangularSolveThunk(const TriangularSolveOptions& options,
TriangularSolveThunk(ThunkInfo thunk_info,
const TriangularSolveOptions& options,
const BufferAllocation::Slice& a_buffer,
const BufferAllocation::Slice& b_buffer,
PrimitiveType type, int64 batch_size, int64 m, int64 n,
int64 a_batch_stride, int64 b_batch_stride,
const HloInstruction* hlo);
int64 a_batch_stride, int64 b_batch_stride);
TriangularSolveThunk(const TriangularSolveThunk&) = delete;
TriangularSolveThunk& operator=(const TriangularSolveThunk&) = delete;

Some files were not shown because too many files have changed in this diff.