[NFC] Eliminate references to HLO insts from CudnnBatchNorm Thunks.

- Eliminate HLO inst references from these thunks and replace them with a
  CudnnBatchNorm configuration object.
- Also move some of the verification that was happening the Thunk constructors to
  the thunk emitter

PiperOrigin-RevId: 335652233
Change-Id: I13adbac58ebaf6a45aae9d4b9f3201c3fceb8bb8
This commit is contained in:
Rahul Joshi 2020-10-06 09:08:55 -07:00 committed by TensorFlower Gardener
parent ec66071ad7
commit e014c2f458
6 changed files with 172 additions and 178 deletions

View File

@ -216,6 +216,7 @@ cc_library(
deps = [
":backend_configs_cc",
":buffer_allocations",
":cudnn_batchnorm_runner",
":gpu_constants",
":gpu_conv_runner",
":gpu_executable",

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
namespace gpu {
@ -109,26 +110,23 @@ DnnBatchDescriptors MakeBatchNormDescriptors(const Shape& shape,
return batch_descs;
}
void AssignCommonParams(const HloInstruction* batchnorm,
void AssignCommonParams(const CudnnBatchNormConfig& config,
CudnnBatchNormParamsCommon* params,
const se::DeviceMemoryBase& operand,
const se::DeviceMemory<float>& scale, float epsilon,
int64 feature_index) {
const se::DeviceMemory<float>& scale) {
// The BatchNormTraining HLO outputs a tuple of three elements: output data,
// batch mean, and batch variance. We want to make our descriptors based on
// the shape of the output data. Batchnorm backward call outputs a tuple of
// three elements: grad data, grad offset, and grad scale. We want to make
// our descriptors based on the shape of the grad data.
const Shape& shape = batchnorm->shape().IsTuple()
? batchnorm->shape().tuple_shapes(0)
: batchnorm->shape();
const Shape& shape = config.output_shape;
DnnBatchDescriptors batch_descs =
MakeBatchNormDescriptors(shape, feature_index);
MakeBatchNormDescriptors(shape, config.feature_index);
params->operand_desc = batch_descs.input_desc;
params->scale_offset_desc = batch_descs.scale_offset_desc;
params->operand = operand;
params->scale = scale;
params->epsilon = epsilon;
params->epsilon = config.epsilon;
}
template <typename ElemType>
@ -211,22 +209,33 @@ void RunCudnnBatchNormBackwardImpl(CudnnBatchNormBackwardParams* params,
} // namespace
CudnnBatchNormConfig GetCudnnBatchNormConfig(const HloInstruction* instr,
float epsilon,
int64 feature_index) {
CudnnBatchNormConfig config;
config.output_shape = instr->shape().IsTuple()
? instr->shape().tuple_shapes(0)
: instr->shape();
config.output_type = config.output_shape.element_type();
config.epsilon = epsilon;
config.feature_index = feature_index;
return config;
}
Status RunCudnnBatchNormForwardInference(
const HloInstruction* batchnorm, se::DeviceMemoryBase operand,
const CudnnBatchNormConfig& config, se::DeviceMemoryBase operand,
se::DeviceMemoryBase output, se::DeviceMemory<float> scale,
se::DeviceMemory<float> offset, se::DeviceMemory<float> mean,
se::DeviceMemory<float> variance, float epsilon, int64 feature_index,
se::Stream* stream) {
se::DeviceMemory<float> variance, se::Stream* stream) {
CudnnBatchNormForwardInferenceParams inference_params;
AssignCommonParams(batchnorm, &inference_params.common, operand, scale,
epsilon, feature_index);
AssignCommonParams(config, &inference_params.common, operand, scale);
inference_params.offset = offset;
inference_params.mean = mean;
inference_params.variance = variance;
inference_params.output = output;
PrimitiveType output_primitive_type = batchnorm->shape().element_type();
switch (output_primitive_type) {
switch (config.output_type) {
case F16:
RunCudnnBatchNormForwardInferenceImpl<Eigen::half>(&inference_params,
stream);
@ -235,29 +244,27 @@ Status RunCudnnBatchNormForwardInference(
RunCudnnBatchNormForwardInferenceImpl<float>(&inference_params, stream);
break;
default:
return Unimplemented("Primitive type not implemented for \"%s\" ",
batchnorm->ToString());
return Unimplemented(
"Primitive type %s not implemented for batchnorm forward inference",
primitive_util::LowercasePrimitiveTypeName(config.output_type)
.c_str());
}
return Status::OK();
}
Status RunCudnnBatchNormForwardTraining(
const HloInstruction* batchnorm, se::DeviceMemoryBase operand,
const CudnnBatchNormConfig& config, se::DeviceMemoryBase operand,
se::DeviceMemoryBase output_data, se::DeviceMemory<float> output_mean,
se::DeviceMemory<float> output_inv_stddev, se::DeviceMemory<float> scale,
se::DeviceMemory<float> offset, float epsilon, int64 feature_index,
se::Stream* stream) {
se::DeviceMemory<float> offset, se::Stream* stream) {
CudnnBatchNormForwardTrainingParams forward_params;
AssignCommonParams(batchnorm, &forward_params.common, operand, scale, epsilon,
feature_index);
AssignCommonParams(config, &forward_params.common, operand, scale);
forward_params.offset = offset;
forward_params.output_data = output_data;
forward_params.output_mean = output_mean;
forward_params.output_inv_stddev = output_inv_stddev;
PrimitiveType output_primitive_type =
batchnorm->shape().tuple_shapes(0).element_type();
switch (output_primitive_type) {
switch (config.output_type) {
case F16:
RunCudnnBatchNormForwardTrainingImpl<Eigen::half>(&forward_params,
stream);
@ -266,22 +273,23 @@ Status RunCudnnBatchNormForwardTraining(
RunCudnnBatchNormForwardTrainingImpl<float>(&forward_params, stream);
break;
default:
return Unimplemented("Primitive type not implemented for \"%s\" ",
batchnorm->ToString());
return Unimplemented(
"Primitive type %s not implemented for batchnorm forward training",
primitive_util::LowercasePrimitiveTypeName(config.output_type)
.c_str());
}
return Status::OK();
}
Status RunCudnnBatchNormBackward(
const HloInstruction* batchnorm, se::DeviceMemoryBase operand,
const CudnnBatchNormConfig& config, se::DeviceMemoryBase operand,
se::DeviceMemoryBase output_grad_data, se::DeviceMemoryBase grad_output,
se::DeviceMemory<float> output_grad_scale,
se::DeviceMemory<float> output_grad_offset, se::DeviceMemory<float> scale,
se::DeviceMemory<float> mean, se::DeviceMemory<float> inv_stddev,
float epsilon, int64 feature_index, se::Stream* stream) {
se::Stream* stream) {
CudnnBatchNormBackwardParams backward_params;
AssignCommonParams(batchnorm, &backward_params.common, operand, scale,
epsilon, feature_index);
AssignCommonParams(config, &backward_params.common, operand, scale);
backward_params.output_grad_data = output_grad_data;
backward_params.grad_output = grad_output;
backward_params.output_grad_scale = output_grad_scale;
@ -289,9 +297,7 @@ Status RunCudnnBatchNormBackward(
backward_params.mean = mean;
backward_params.inv_stddev = inv_stddev;
PrimitiveType output_primitive_type =
batchnorm->shape().tuple_shapes(0).element_type();
switch (output_primitive_type) {
switch (config.output_type) {
case F16:
RunCudnnBatchNormBackwardImpl<Eigen::half>(&backward_params, stream);
break;
@ -299,8 +305,10 @@ Status RunCudnnBatchNormBackward(
RunCudnnBatchNormBackwardImpl<float>(&backward_params, stream);
break;
default:
return Unimplemented("Primitive type not implemented for \"%s\" ",
batchnorm->ToString());
return Unimplemented(
"Primitive type %s not implemented for batchnorm backward",
primitive_util::LowercasePrimitiveTypeName(config.output_type)
.c_str());
}
return Status::OK();
}

View File

@ -28,27 +28,36 @@ limitations under the License.
namespace xla {
namespace gpu {
struct CudnnBatchNormConfig {
Shape output_shape;
PrimitiveType output_type;
float epsilon;
int64 feature_index;
};
CudnnBatchNormConfig GetCudnnBatchNormConfig(const HloInstruction *instr,
float epsilon,
int64 feature_index);
Status RunCudnnBatchNormForwardInference(
const HloInstruction* batchnorm, se::DeviceMemoryBase operand,
const CudnnBatchNormConfig &config, se::DeviceMemoryBase operand,
se::DeviceMemoryBase output, se::DeviceMemory<float> scale,
se::DeviceMemory<float> offset, se::DeviceMemory<float> mean,
se::DeviceMemory<float> variance, float epsilon, int64 feature_index,
se::Stream* stream);
se::DeviceMemory<float> variance, se::Stream *stream);
Status RunCudnnBatchNormForwardTraining(
const HloInstruction* batchnorm, se::DeviceMemoryBase operand,
const CudnnBatchNormConfig &config, se::DeviceMemoryBase operand,
se::DeviceMemoryBase output_data, se::DeviceMemory<float> output_mean,
se::DeviceMemory<float> output_inv_stddev, se::DeviceMemory<float> scale,
se::DeviceMemory<float> offset, float epsilon, int64 feature_index,
se::Stream* stream);
se::DeviceMemory<float> offset, se::Stream *stream);
Status RunCudnnBatchNormBackward(
const HloInstruction* batchnorm, se::DeviceMemoryBase operand,
const CudnnBatchNormConfig &config, se::DeviceMemoryBase operand,
se::DeviceMemoryBase output_grad_data, se::DeviceMemoryBase grad_output,
se::DeviceMemory<float> output_grad_scale,
se::DeviceMemory<float> output_grad_offset, se::DeviceMemory<float> scale,
se::DeviceMemory<float> mean, se::DeviceMemory<float> inv_stddev,
float epsilon, int64 feature_index, se::Stream* stream);
se::Stream *stream);
} // namespace gpu
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_BATCHNORM_RUNNER_H_

View File

@ -31,90 +31,21 @@ namespace gpu {
namespace dnn = se::dnn;
namespace {
void CheckInputOutputPrimitivetypeAreValid(const HloInstruction* hlo) {
// All input and output statistics variables must be F32. Also, the last
// operand for CudnnBatchNormForwardInference, CudnnBatchNormForwardTraining,
// and CudnnBatchNormBackward is the feature_index which must be S64.
// The allowed types for non-statistics variables are as follows:
// CudnnBatchNormForwardInference:
// operand[0]: {half, float}
// out[0]: {half, float}
// CudnnBatchNormForwardTraining:
// operand[0]: {half, float}
// out[0]: {half, float}
// CudnnBatchNormBackward:
// operand[0]: {half, float}
// operand[4]: {half, float}
// out[0]: {half, float}
// Note non-statistics inputs and outputs mentioned above should be of the
// same type.
// Check Inputs.
int64 num_operands = hlo->operand_count();
PrimitiveType operand_primitive_type =
hlo->operand(0)->shape().element_type();
CHECK(operand_primitive_type == F16 || operand_primitive_type == F32)
<< "Not yet implemented";
for (int i = 1; i < num_operands - 2; i++) {
if (hlo->custom_call_target() == kCudnnBatchNormBackwardCallTarget &&
i == 4) {
// The first operand to batchnorm grad is the input and the 4th operand is
// the grad_output, both of which can be Eigen::half.
CHECK_EQ(hlo->operand(i)->shape().element_type(), operand_primitive_type)
<< "Invalid datatype";
continue;
}
CHECK_EQ(hlo->operand(i)->shape().element_type(), F32)
<< "Not yet implemented";
}
// The last operand is the feature index which must be int64.
CHECK_EQ(hlo->operand(num_operands - 1)->shape().element_type(), S64)
<< "Not yet implemented";
// Check Outputs.
if (hlo->shape().IsTuple()) {
CHECK_EQ(hlo->shape().tuple_shapes(0).element_type(),
operand_primitive_type)
<< "Invalid datatype";
for (int j = 1; j < hlo->shape().tuple_shapes_size(); j++) {
CHECK_EQ(hlo->shape().tuple_shapes(j).element_type(), F32)
<< "Not yet implemented";
}
} else {
CHECK_EQ(hlo->shape().element_type(), operand_primitive_type)
<< "Invalid datatype";
}
}
} // namespace
CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk(
ThunkInfo thunk_info, const BufferAllocation::Slice& operand,
ThunkInfo thunk_info, CudnnBatchNormConfig&& config,
const BufferAllocation::Slice& operand,
const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset,
const BufferAllocation::Slice& mean,
const BufferAllocation::Slice& variance, float epsilon, int64 feature_index,
const BufferAllocation::Slice& variance,
const BufferAllocation::Slice& output)
: Thunk(Thunk::Kind::kCudnnBatchNormForwardInference, thunk_info),
hlo_instruction_(thunk_info.hlo_instruction),
config_(std::move(config)),
operand_(operand),
scale_(scale),
offset_(offset),
mean_(mean),
variance_(variance),
epsilon_(epsilon),
feature_index_(feature_index),
output_(output) {
const auto* hlo = hlo_instruction_;
CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall);
CHECK_EQ(hlo->custom_call_target(),
kCudnnBatchNormForwardInferenceCallTarget);
CHECK(
LayoutUtil::LayoutsInShapesEqual(hlo->shape(), hlo->operand(0)->shape()));
CheckInputOutputPrimitivetypeAreValid(hlo);
}
output_(output) {}
Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream(
const ExecuteParams& params) {
@ -131,8 +62,7 @@ Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream(
buffer_allocations.GetDeviceAddress(variance_));
auto& stream = *params.stream;
TF_RETURN_IF_ERROR(RunCudnnBatchNormForwardInference(
hlo_instruction_, operand, output_base, scale, offset, mean, variance,
epsilon_, feature_index_, &stream));
config_, operand, output_base, scale, offset, mean, variance, &stream));
if (!stream.ok()) {
return InternalError("BatchNormalizationForward call failed.");
@ -141,32 +71,22 @@ Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream(
}
CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk(
ThunkInfo thunk_info, const BufferAllocation::Slice& operand,
ThunkInfo thunk_info, CudnnBatchNormConfig&& config,
const BufferAllocation::Slice& operand,
const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset,
float epsilon, int64 feature_index,
const BufferAllocation::Slice& output_data,
const BufferAllocation::Slice& output_mean,
const BufferAllocation::Slice& output_inv_stddev,
const BufferAllocation::Slice& output_tuple)
: Thunk(Thunk::Kind::kCudnnBatchNormForwardTraining, thunk_info),
hlo_instruction_(thunk_info.hlo_instruction),
config_(std::move(config)),
operand_(operand),
scale_(scale),
offset_(offset),
epsilon_(epsilon),
feature_index_(feature_index),
output_data_(output_data),
output_mean_(output_mean),
output_inv_stddev_(output_inv_stddev),
output_tuple_(output_tuple) {
const auto* hlo = hlo_instruction_;
CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall);
CHECK_EQ(hlo->custom_call_target(), kCudnnBatchNormForwardTrainingCallTarget);
CHECK_EQ(hlo->shape().tuple_shapes_size(), 3);
CHECK(LayoutUtil::LayoutsInShapesEqual(hlo->shape().tuple_shapes(0),
hlo->operand(0)->shape()));
CheckInputOutputPrimitivetypeAreValid(hlo);
}
output_tuple_(output_tuple) {}
Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
const ExecuteParams& params) {
@ -185,10 +105,10 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
params.profiler->MakeScopedInstructionProfiler(profile_index());
auto& stream = *params.stream;
TF_RETURN_IF_ERROR(RunCudnnBatchNormForwardTraining(
hlo_instruction_, operand, output_data, output_mean, output_inv_stddev,
config_, operand, output_data, output_mean, output_inv_stddev,
se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(scale_)),
se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(offset_)),
epsilon_, feature_index_, &stream));
&stream));
// Write the output tuple.
const int kNumOutputs = 3;
@ -207,37 +127,26 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
}
CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk(
ThunkInfo thunk_info, const BufferAllocation::Slice& operand,
ThunkInfo thunk_info, CudnnBatchNormConfig&& config,
const BufferAllocation::Slice& operand,
const BufferAllocation::Slice& scale, const BufferAllocation::Slice& mean,
const BufferAllocation::Slice& inv_stddev,
const BufferAllocation::Slice& grad_output, float epsilon,
int64 feature_index, const BufferAllocation::Slice& output_grad_data,
const BufferAllocation::Slice& grad_output,
const BufferAllocation::Slice& output_grad_data,
const BufferAllocation::Slice& output_grad_scale,
const BufferAllocation::Slice& output_grad_offset,
const BufferAllocation::Slice& output_tuple)
: Thunk(Thunk::Kind::kCudnnBatchNormBackward, thunk_info),
hlo_instruction_(thunk_info.hlo_instruction),
config_(std::move(config)),
operand_(operand),
scale_(scale),
mean_(mean),
inv_stddev_(inv_stddev),
grad_output_(grad_output),
epsilon_(epsilon),
feature_index_(feature_index),
output_grad_data_(output_grad_data),
output_grad_scale_(output_grad_scale),
output_grad_offset_(output_grad_offset),
output_tuple_(output_tuple) {
const auto* hlo = hlo_instruction_;
CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall);
CHECK_EQ(hlo->custom_call_target(), kCudnnBatchNormBackwardCallTarget);
CHECK_EQ(hlo->shape().tuple_shapes_size(), 3);
CHECK(LayoutUtil::LayoutsInShapesEqual(hlo->shape().tuple_shapes(0),
hlo->operand(0)->shape()));
CHECK(LayoutUtil::LayoutsInShapesEqual(hlo->shape().tuple_shapes(0),
hlo->operand(4)->shape()));
CheckInputOutputPrimitivetypeAreValid(hlo);
}
output_tuple_(output_tuple) {}
Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
const ExecuteParams& params) {
@ -256,12 +165,12 @@ Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
params.profiler->MakeScopedInstructionProfiler(profile_index());
se::Stream* stream = params.stream;
TF_RETURN_IF_ERROR(RunCudnnBatchNormBackward(
hlo_instruction_, operand, output_grad_data, grad_output,
output_grad_scale, output_grad_offset,
config_, operand, output_grad_data, grad_output, output_grad_scale,
output_grad_offset,
se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(scale_)),
se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(mean_)),
se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(inv_stddev_)),
epsilon_, feature_index_, stream));
stream));
// Write the output tuple.
const int kNumOutputs = 3;

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@ -47,12 +48,12 @@ namespace gpu {
class CudnnBatchNormForwardInferenceThunk : public Thunk {
public:
CudnnBatchNormForwardInferenceThunk(ThunkInfo thunk_info,
CudnnBatchNormConfig&& config,
const BufferAllocation::Slice& operand,
const BufferAllocation::Slice& scale,
const BufferAllocation::Slice& offset,
const BufferAllocation::Slice& mean,
const BufferAllocation::Slice& variance,
float epsilon, int64 feature_index,
const BufferAllocation::Slice& output);
CudnnBatchNormForwardInferenceThunk(
@ -63,23 +64,22 @@ class CudnnBatchNormForwardInferenceThunk : public Thunk {
Status ExecuteOnStream(const ExecuteParams& params) override;
private:
const HloInstruction* hlo_instruction_;
CudnnBatchNormConfig config_;
BufferAllocation::Slice operand_;
BufferAllocation::Slice scale_;
BufferAllocation::Slice offset_;
BufferAllocation::Slice mean_;
BufferAllocation::Slice variance_;
float epsilon_;
int64 feature_index_;
BufferAllocation::Slice output_;
};
class CudnnBatchNormForwardTrainingThunk : public Thunk {
public:
CudnnBatchNormForwardTrainingThunk(
ThunkInfo thunk_info, const BufferAllocation::Slice& operand,
ThunkInfo thunk_info, CudnnBatchNormConfig&& config,
const BufferAllocation::Slice& operand,
const BufferAllocation::Slice& scale,
const BufferAllocation::Slice& offset, float epsilon, int64 feature_index,
const BufferAllocation::Slice& offset,
const BufferAllocation::Slice& output_data,
const BufferAllocation::Slice& output_mean,
const BufferAllocation::Slice& output_inv_stddev,
@ -93,12 +93,10 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk {
Status ExecuteOnStream(const ExecuteParams& params) override;
private:
const HloInstruction* hlo_instruction_;
CudnnBatchNormConfig config_;
BufferAllocation::Slice operand_;
BufferAllocation::Slice scale_;
BufferAllocation::Slice offset_;
float epsilon_;
int64 feature_index_;
BufferAllocation::Slice output_data_;
BufferAllocation::Slice output_mean_;
BufferAllocation::Slice output_inv_stddev_;
@ -108,12 +106,12 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk {
class CudnnBatchNormBackwardThunk : public Thunk {
public:
CudnnBatchNormBackwardThunk(ThunkInfo thunk_info,
CudnnBatchNormConfig&& config,
const BufferAllocation::Slice& operand,
const BufferAllocation::Slice& scale,
const BufferAllocation::Slice& mean,
const BufferAllocation::Slice& inv_stddev,
const BufferAllocation::Slice& grad_output,
float epsilon, int64 feature_index,
const BufferAllocation::Slice& output_grad_data,
const BufferAllocation::Slice& output_grad_scale,
const BufferAllocation::Slice& output_grad_offset,
@ -126,14 +124,12 @@ class CudnnBatchNormBackwardThunk : public Thunk {
Status ExecuteOnStream(const ExecuteParams& params) override;
private:
const HloInstruction* hlo_instruction_;
CudnnBatchNormConfig config_;
BufferAllocation::Slice operand_;
BufferAllocation::Slice scale_;
BufferAllocation::Slice mean_;
BufferAllocation::Slice inv_stddev_;
BufferAllocation::Slice grad_output_;
float epsilon_;
int64 feature_index_;
BufferAllocation::Slice output_grad_data_;
BufferAllocation::Slice output_grad_scale_;
BufferAllocation::Slice output_grad_offset_;

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/fft_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"
@ -38,6 +39,65 @@ limitations under the License.
namespace xla {
namespace gpu {
namespace {
void CheckBatchNormInputOutputPrimitivetypeAreValid(const HloInstruction* hlo) {
// All input and output statistics variables must be F32. Also, the last
// operand for CudnnBatchNormForwardInference, CudnnBatchNormForwardTraining,
// and CudnnBatchNormBackward is the feature_index which must be S64.
// The allowed types for non-statistics variables are as follows:
// CudnnBatchNormForwardInference:
// operand[0]: {half, float}
// out[0]: {half, float}
// CudnnBatchNormForwardTraining:
// operand[0]: {half, float}
// out[0]: {half, float}
// CudnnBatchNormBackward:
// operand[0]: {half, float}
// operand[4]: {half, float}
// out[0]: {half, float}
// Note non-statistics inputs and outputs mentioned above should be of the
// same type.
// Check Inputs.
int64 num_operands = hlo->operand_count();
PrimitiveType operand_primitive_type =
hlo->operand(0)->shape().element_type();
CHECK(operand_primitive_type == F16 || operand_primitive_type == F32)
<< "Not yet implemented";
for (int i = 1; i < num_operands - 2; i++) {
if (hlo->custom_call_target() == kCudnnBatchNormBackwardCallTarget &&
i == 4) {
// The first operand to batchnorm grad is the input and the 4th operand is
// the grad_output, both of which can be Eigen::half.
CHECK_EQ(hlo->operand(i)->shape().element_type(), operand_primitive_type)
<< "Invalid datatype";
continue;
}
CHECK_EQ(hlo->operand(i)->shape().element_type(), F32)
<< "Not yet implemented";
}
// The last operand is the feature index which must be int64.
CHECK_EQ(hlo->operand(num_operands - 1)->shape().element_type(), S64)
<< "Not yet implemented";
// Check Outputs.
if (hlo->shape().IsTuple()) {
CHECK_EQ(hlo->shape().tuple_shapes(0).element_type(),
operand_primitive_type)
<< "Invalid datatype";
for (int j = 1; j < hlo->shape().tuple_shapes_size(); j++) {
CHECK_EQ(hlo->shape().tuple_shapes(j).element_type(), F32)
<< "Not yet implemented";
}
} else {
CHECK_EQ(hlo->shape().element_type(), operand_primitive_type)
<< "Invalid datatype";
}
}
} // namespace
std::unique_ptr<Thunk> ThunkEmitter::BuildFftThunk(const HloInstruction* inst) {
const HloInstruction* operand = inst->operand(0);
return absl::make_unique<FftThunk>(
@ -154,16 +214,20 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
CHECK(feature_index->IsConstant());
int64 feature_index_value = feature_index->literal().Get<int64>({});
CHECK_EQ(custom_call->shape().tuple_shapes_size(), 3);
CHECK(LayoutUtil::LayoutsInShapesEqual(custom_call->shape().tuple_shapes(0),
custom_call->operand(0)->shape()));
CheckBatchNormInputOutputPrimitivetypeAreValid(custom_call);
CudnnBatchNormConfig config = GetCudnnBatchNormConfig(
custom_call, epsilon_value, feature_index_value);
AddThunkToThunkSequence(
absl::make_unique<CudnnBatchNormForwardInferenceThunk>(
context_->GetThunkInfo(custom_call),
context_->GetThunkInfo(custom_call), std::move(config),
/*operand=*/GetAllocationSlice(*custom_call->operand(0)),
/*scale=*/GetAllocationSlice(*custom_call->operand(1)),
/*offset=*/GetAllocationSlice(*custom_call->operand(2)),
/*mean=*/GetAllocationSlice(*custom_call->operand(3)),
/*variance=*/GetAllocationSlice(*custom_call->operand(4)),
/*epsilon=*/epsilon_value,
/*feature_index=*/feature_index_value,
/*output=*/GetAllocationSlice(*custom_call)));
return Status::OK();
}
@ -183,14 +247,14 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
auto output_data = GetAllocationSlice(*custom_call, {0});
auto output_mean = GetAllocationSlice(*custom_call, {1});
auto output_inv_stddev = GetAllocationSlice(*custom_call, {2});
CudnnBatchNormConfig config = GetCudnnBatchNormConfig(
custom_call, epsilon_value, feature_index_value);
AddThunkToThunkSequence(
absl::make_unique<CudnnBatchNormForwardTrainingThunk>(
context_->GetThunkInfo(custom_call),
context_->GetThunkInfo(custom_call), std::move(config),
/*operand=*/GetAllocationSlice(*custom_call->operand(0)),
/*scale=*/GetAllocationSlice(*custom_call->operand(1)),
/*offset=*/GetAllocationSlice(*custom_call->operand(2)),
/*epsilon=*/epsilon_value,
/*feature_index=*/feature_index_value,
/*output_data=*/output_data,
/*output_mean=*/output_mean,
/*output_inv_stddev=*/output_inv_stddev,
@ -212,15 +276,22 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
auto output_grad_data = GetAllocationSlice(*custom_call, {0});
auto output_grad_scale = GetAllocationSlice(*custom_call, {1});
auto output_grad_offset = GetAllocationSlice(*custom_call, {2});
CHECK_EQ(custom_call->shape().tuple_shapes_size(), 3);
CHECK(LayoutUtil::LayoutsInShapesEqual(custom_call->shape().tuple_shapes(0),
custom_call->operand(0)->shape()));
CHECK(LayoutUtil::LayoutsInShapesEqual(custom_call->shape().tuple_shapes(0),
custom_call->operand(4)->shape()));
CheckBatchNormInputOutputPrimitivetypeAreValid(custom_call);
CudnnBatchNormConfig config = GetCudnnBatchNormConfig(
custom_call, epsilon_value, feature_index_value);
AddThunkToThunkSequence(absl::make_unique<CudnnBatchNormBackwardThunk>(
context_->GetThunkInfo(custom_call),
context_->GetThunkInfo(custom_call), std::move(config),
/*operand=*/GetAllocationSlice(*custom_call->operand(0)),
/*scale=*/GetAllocationSlice(*custom_call->operand(1)),
/*mean=*/GetAllocationSlice(*custom_call->operand(2)),
/*inv_stddev=*/GetAllocationSlice(*custom_call->operand(3)),
/*grad_output=*/GetAllocationSlice(*custom_call->operand(4)),
/*epsilon=*/epsilon_value,
/*feature_index=*/feature_index_value,
/*output_grad_data=*/output_grad_data,
/*output_grad_scale=*/output_grad_scale,
/*output_grad_offset=*/output_grad_offset,