Multiple blocks of code in StreamExecutor can be simplified using the
TF_ASSIGN_OR_RETURN macro, and StatusOr class is already defined in StreamExecutor. PiperOrigin-RevId: 249491049
This commit is contained in:
parent
02cd8bdcfa
commit
259b05ecad
@ -187,28 +187,4 @@ class StatusAdaptorForMacros {
|
||||
.with_log_stack_trace() \
|
||||
.add_ret_check_failure(#condition)
|
||||
|
||||
#define TF_ASSERT_OK_AND_ASSIGN(lhs, rexpr) \
|
||||
TF_ASSERT_OK_AND_ASSIGN_IMPL( \
|
||||
TF_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, \
|
||||
rexpr);
|
||||
|
||||
#define TF_ASSERT_OK_AND_ASSIGN_IMPL(statusor, lhs, rexpr) \
|
||||
auto statusor = (rexpr); \
|
||||
ASSERT_TRUE(statusor.status().ok()) << statusor.status(); \
|
||||
lhs = std::move(statusor.ValueOrDie())
|
||||
|
||||
#define TF_STATUS_MACROS_CONCAT_NAME(x, y) TF_STATUS_MACROS_CONCAT_IMPL(x, y)
|
||||
#define TF_STATUS_MACROS_CONCAT_IMPL(x, y) x##y
|
||||
|
||||
#define TF_ASSIGN_OR_RETURN(lhs, rexpr) \
|
||||
TF_ASSIGN_OR_RETURN_IMPL( \
|
||||
TF_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, rexpr)
|
||||
|
||||
#define TF_ASSIGN_OR_RETURN_IMPL(statusor, lhs, rexpr) \
|
||||
auto statusor = (rexpr); \
|
||||
if (TF_PREDICT_FALSE(!statusor.ok())) { \
|
||||
return statusor.status(); \
|
||||
} \
|
||||
lhs = std::move(statusor.ValueOrDie())
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_STATUS_MACROS_H_
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/subprocess.h"
|
||||
#include "tensorflow/stream_executor/cuda/cuda_driver.h"
|
||||
#include "tensorflow/stream_executor/gpu/gpu_helpers.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
namespace stream_executor {
|
||||
namespace cuda {
|
||||
@ -120,11 +121,9 @@ port::StatusOr<absl::Span<const uint8>> CompilePtxOrGetCached(
|
||||
compilation_options.ToTuple()};
|
||||
auto it = ptx_cache.find(cache_key);
|
||||
if (it == ptx_cache.end()) {
|
||||
auto compiled_or = CompilePtx(device_ordinal, ptx, compilation_options);
|
||||
TF_RETURN_IF_ERROR(compiled_or.status());
|
||||
std::vector<uint8> compiled = std::move(compiled_or.ValueOrDie());
|
||||
it =
|
||||
ptx_cache.emplace(cache_key, std::move(compiled_or.ValueOrDie())).first;
|
||||
TF_ASSIGN_OR_RETURN(std::vector<uint8> compiled,
|
||||
CompilePtx(device_ordinal, ptx, compilation_options));
|
||||
it = ptx_cache.emplace(cache_key, std::move(compiled)).first;
|
||||
}
|
||||
|
||||
CHECK(it != ptx_cache.end());
|
||||
|
@ -306,6 +306,31 @@ void StatusOr<T>::IgnoreError() const {
|
||||
}
|
||||
|
||||
} // namespace port
|
||||
|
||||
#define TF_ASSERT_OK_AND_ASSIGN(lhs, rexpr) \
|
||||
TF_ASSERT_OK_AND_ASSIGN_IMPL( \
|
||||
TF_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, \
|
||||
rexpr);
|
||||
|
||||
#define TF_ASSERT_OK_AND_ASSIGN_IMPL(statusor, lhs, rexpr) \
|
||||
auto statusor = (rexpr); \
|
||||
ASSERT_TRUE(statusor.status().ok()) << statusor.status(); \
|
||||
lhs = std::move(statusor.ValueOrDie())
|
||||
|
||||
#define TF_STATUS_MACROS_CONCAT_NAME(x, y) TF_STATUS_MACROS_CONCAT_IMPL(x, y)
|
||||
#define TF_STATUS_MACROS_CONCAT_IMPL(x, y) x##y
|
||||
|
||||
#define TF_ASSIGN_OR_RETURN(lhs, rexpr) \
|
||||
TF_ASSIGN_OR_RETURN_IMPL( \
|
||||
TF_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, rexpr)
|
||||
|
||||
#define TF_ASSIGN_OR_RETURN_IMPL(statusor, lhs, rexpr) \
|
||||
auto statusor = (rexpr); \
|
||||
if (TF_PREDICT_FALSE(!statusor.ok())) { \
|
||||
return statusor.status(); \
|
||||
} \
|
||||
lhs = std::move(statusor.ValueOrDie())
|
||||
|
||||
} // namespace stream_executor
|
||||
|
||||
#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_STATUSOR_H_
|
||||
|
@ -33,6 +33,7 @@ limitations under the License.
|
||||
#include "tensorflow/stream_executor/lib/env.h"
|
||||
#include "tensorflow/stream_executor/lib/error.h"
|
||||
#include "tensorflow/stream_executor/lib/stacktrace.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
#include "tensorflow/stream_executor/lib/threadpool.h"
|
||||
#include "tensorflow/stream_executor/platform/port.h"
|
||||
#include "tensorflow/stream_executor/rng.h"
|
||||
@ -872,11 +873,9 @@ StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator(
|
||||
|
||||
port::StatusOr<OwningDeviceMemory> StreamExecutorMemoryAllocator::Allocate(
|
||||
int device_ordinal, uint64 size, bool retry_on_failure) {
|
||||
port::StatusOr<StreamExecutor *> stream_executor_or =
|
||||
GetStreamExecutor(device_ordinal);
|
||||
TF_RETURN_IF_ERROR(stream_executor_or.status());
|
||||
DeviceMemoryBase result =
|
||||
stream_executor_or.ValueOrDie()->AllocateArray<uint8>(size);
|
||||
TF_ASSIGN_OR_RETURN(StreamExecutor * executor,
|
||||
GetStreamExecutor(device_ordinal));
|
||||
DeviceMemoryBase result = executor->AllocateArray<uint8>(size);
|
||||
if (size > 0 && result == nullptr) {
|
||||
return tensorflow::errors::ResourceExhausted(absl::StrFormat(
|
||||
"Failed to allocate request for %s (%uB) on device ordinal %d",
|
||||
@ -893,12 +892,11 @@ port::StatusOr<OwningDeviceMemory> StreamExecutorMemoryAllocator::Allocate(
|
||||
port::Status StreamExecutorMemoryAllocator::Deallocate(int device_ordinal,
|
||||
DeviceMemoryBase mem) {
|
||||
if (!mem.is_null()) {
|
||||
port::StatusOr<StreamExecutor *> stream_executor_or =
|
||||
GetStreamExecutor(device_ordinal);
|
||||
TF_RETURN_IF_ERROR(stream_executor_or.status());
|
||||
TF_ASSIGN_OR_RETURN(StreamExecutor * executor,
|
||||
GetStreamExecutor(device_ordinal));
|
||||
VLOG(3) << absl::StreamFormat("Freeing %p on device ordinal %d",
|
||||
mem.opaque(), device_ordinal);
|
||||
stream_executor_or.ValueOrDie()->Deallocate(&mem);
|
||||
executor->Deallocate(&mem);
|
||||
}
|
||||
return port::Status::OK();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user