diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index fdd36affc2b..afaae576e88 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -216,6 +216,7 @@ cc_library(
     deps = [
         ":backend_configs_cc",
         ":buffer_allocations",
+        ":cudnn_batchnorm_runner",
         ":gpu_constants",
         ":gpu_conv_runner",
         ":gpu_executable",
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.cc
index adf6b68096d..6b01151b48a 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.cc
@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
 
 namespace xla {
 namespace gpu {
@@ -109,26 +110,23 @@ DnnBatchDescriptors MakeBatchNormDescriptors(const Shape& shape,
   return batch_descs;
 }
 
-void AssignCommonParams(const HloInstruction* batchnorm,
+void AssignCommonParams(const CudnnBatchNormConfig& config,
                         CudnnBatchNormParamsCommon* params,
                         const se::DeviceMemoryBase& operand,
-                        const se::DeviceMemory<float>& scale, float epsilon,
-                        int64 feature_index) {
+                        const se::DeviceMemory<float>& scale) {
   // The BatchNormTraining HLO outputs a tuple of three elements: output data,
   // batch mean, and batch variance. We want to make our descriptors based on
   // the shape of the output data. Batchnorm backward call outputs a tuple of
   // three elements: grad data, grad offset, and grad scale. We want to make
   // our descriptors based on the shape of the grad data.
-  const Shape& shape = batchnorm->shape().IsTuple()
-                           ? batchnorm->shape().tuple_shapes(0)
-                           : batchnorm->shape();
+  const Shape& shape = config.output_shape;
   DnnBatchDescriptors batch_descs =
-      MakeBatchNormDescriptors(shape, feature_index);
+      MakeBatchNormDescriptors(shape, config.feature_index);
   params->operand_desc = batch_descs.input_desc;
   params->scale_offset_desc = batch_descs.scale_offset_desc;
   params->operand = operand;
   params->scale = scale;
-  params->epsilon = epsilon;
+  params->epsilon = config.epsilon;
 }
 
 template <typename ElemType>
@@ -211,22 +209,33 @@ void RunCudnnBatchNormBackwardImpl(CudnnBatchNormBackwardParams* params,
 
 }  // namespace
 
+CudnnBatchNormConfig GetCudnnBatchNormConfig(const HloInstruction* instr,
+                                             float epsilon,
+                                             int64 feature_index) {
+  CudnnBatchNormConfig config;
+
+  config.output_shape = instr->shape().IsTuple()
+                            ? instr->shape().tuple_shapes(0)
+                            : instr->shape();
+  config.output_type = config.output_shape.element_type();
+  config.epsilon = epsilon;
+  config.feature_index = feature_index;
+  return config;
+}
+
 Status RunCudnnBatchNormForwardInference(
-    const HloInstruction* batchnorm, se::DeviceMemoryBase operand,
+    const CudnnBatchNormConfig& config, se::DeviceMemoryBase operand,
     se::DeviceMemoryBase output, se::DeviceMemory<float> scale,
     se::DeviceMemory<float> offset, se::DeviceMemory<float> mean,
-    se::DeviceMemory<float> variance, float epsilon, int64 feature_index,
-    se::Stream* stream) {
+    se::DeviceMemory<float> variance, se::Stream* stream) {
   CudnnBatchNormForwardInferenceParams inference_params;
-  AssignCommonParams(batchnorm, &inference_params.common, operand, scale,
-                     epsilon, feature_index);
+  AssignCommonParams(config, &inference_params.common, operand, scale);
   inference_params.offset = offset;
   inference_params.mean = mean;
   inference_params.variance = variance;
   inference_params.output = output;
 
-  PrimitiveType output_primitive_type = batchnorm->shape().element_type();
-  switch (output_primitive_type) {
+  switch (config.output_type) {
     case F16:
       RunCudnnBatchNormForwardInferenceImpl<Eigen::half>(&inference_params,
                                                          stream);
@@ -235,29 +244,27 @@ Status RunCudnnBatchNormForwardInference(
       RunCudnnBatchNormForwardInferenceImpl<float>(&inference_params, stream);
       break;
     default:
-      return Unimplemented("Primitive type not implemented for \"%s\" ",
-                           batchnorm->ToString());
+      return Unimplemented(
+          "Primitive type %s not implemented for batchnorm forward inference",
+          primitive_util::LowercasePrimitiveTypeName(config.output_type)
+              .c_str());
   }
   return Status::OK();
 }
 
 Status RunCudnnBatchNormForwardTraining(
-    const HloInstruction* batchnorm, se::DeviceMemoryBase operand,
+    const CudnnBatchNormConfig& config, se::DeviceMemoryBase operand,
     se::DeviceMemoryBase output_data, se::DeviceMemory<float> output_mean,
     se::DeviceMemory<float> output_inv_stddev, se::DeviceMemory<float> scale,
-    se::DeviceMemory<float> offset, float epsilon, int64 feature_index,
-    se::Stream* stream) {
+    se::DeviceMemory<float> offset, se::Stream* stream) {
   CudnnBatchNormForwardTrainingParams forward_params;
-  AssignCommonParams(batchnorm, &forward_params.common, operand, scale, epsilon,
-                     feature_index);
+  AssignCommonParams(config, &forward_params.common, operand, scale);
   forward_params.offset = offset;
   forward_params.output_data = output_data;
   forward_params.output_mean = output_mean;
   forward_params.output_inv_stddev = output_inv_stddev;
-  PrimitiveType output_primitive_type =
-      batchnorm->shape().tuple_shapes(0).element_type();
-  switch (output_primitive_type) {
+  switch (config.output_type) {
     case F16:
       RunCudnnBatchNormForwardTrainingImpl<Eigen::half>(&forward_params,
                                                         stream);
@@ -266,22 +273,23 @@ Status RunCudnnBatchNormForwardTraining(
       RunCudnnBatchNormForwardTrainingImpl<float>(&forward_params, stream);
       break;
     default:
-      return Unimplemented("Primitive type not implemented for \"%s\" ",
-                           batchnorm->ToString());
+      return Unimplemented(
+          "Primitive type %s not implemented for batchnorm forward training",
+          primitive_util::LowercasePrimitiveTypeName(config.output_type)
+              .c_str());
   }
   return Status::OK();
 }
 
 Status RunCudnnBatchNormBackward(
-    const HloInstruction* batchnorm, se::DeviceMemoryBase operand,
+    const CudnnBatchNormConfig& config, se::DeviceMemoryBase operand,
     se::DeviceMemoryBase output_grad_data, se::DeviceMemoryBase grad_output,
     se::DeviceMemory<float> output_grad_scale,
     se::DeviceMemory<float> output_grad_offset, se::DeviceMemory<float> scale,
     se::DeviceMemory<float> mean, se::DeviceMemory<float> inv_stddev,
-    float epsilon, int64 feature_index, se::Stream* stream) {
+    se::Stream* stream) {
   CudnnBatchNormBackwardParams backward_params;
-  AssignCommonParams(batchnorm, &backward_params.common, operand, scale,
-                     epsilon, feature_index);
+  AssignCommonParams(config, &backward_params.common, operand, scale);
   backward_params.output_grad_data = output_grad_data;
   backward_params.grad_output = grad_output;
   backward_params.output_grad_scale = output_grad_scale;
@@ -289,9 +297,7 @@ Status RunCudnnBatchNormBackward(
   backward_params.mean = mean;
   backward_params.inv_stddev = inv_stddev;
 
-  PrimitiveType output_primitive_type =
-      batchnorm->shape().tuple_shapes(0).element_type();
-  switch (output_primitive_type) {
+  switch (config.output_type) {
     case F16:
       RunCudnnBatchNormBackwardImpl<Eigen::half>(&backward_params, stream);
       break;
@@ -299,8 +305,10 @@ Status RunCudnnBatchNormBackward(
       RunCudnnBatchNormBackwardImpl<float>(&backward_params, stream);
       break;
     default:
-      return Unimplemented("Primitive type not implemented for \"%s\" ",
-                           batchnorm->ToString());
+      return Unimplemented(
+          "Primitive type %s not implemented for batchnorm backward",
+          primitive_util::LowercasePrimitiveTypeName(config.output_type)
+              .c_str());
   }
   return Status::OK();
 }
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.h
index 9a630d013f7..b0791b01868 100755
--- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.h
@@ -28,27 +28,36 @@ limitations under the License.
 namespace xla {
 namespace gpu {
 
+struct CudnnBatchNormConfig {
+  Shape output_shape;
+  PrimitiveType output_type;
+  float epsilon;
+  int64 feature_index;
+};
+
+CudnnBatchNormConfig GetCudnnBatchNormConfig(const HloInstruction *instr,
+                                             float epsilon,
+                                             int64 feature_index);
+
 Status RunCudnnBatchNormForwardInference(
-    const HloInstruction* batchnorm, se::DeviceMemoryBase operand,
+    const CudnnBatchNormConfig &config, se::DeviceMemoryBase operand,
     se::DeviceMemoryBase output, se::DeviceMemory<float> scale,
     se::DeviceMemory<float> offset, se::DeviceMemory<float> mean,
-    se::DeviceMemory<float> variance, float epsilon, int64 feature_index,
-    se::Stream* stream);
+    se::DeviceMemory<float> variance, se::Stream *stream);
 
 Status RunCudnnBatchNormForwardTraining(
-    const HloInstruction* batchnorm, se::DeviceMemoryBase operand,
+    const CudnnBatchNormConfig &config, se::DeviceMemoryBase operand,
     se::DeviceMemoryBase output_data, se::DeviceMemory<float> output_mean,
     se::DeviceMemory<float> output_inv_stddev, se::DeviceMemory<float> scale,
-    se::DeviceMemory<float> offset, float epsilon, int64 feature_index,
-    se::Stream* stream);
+    se::DeviceMemory<float> offset, se::Stream *stream);
 
 Status RunCudnnBatchNormBackward(
-    const HloInstruction* batchnorm, se::DeviceMemoryBase operand,
+    const CudnnBatchNormConfig &config, se::DeviceMemoryBase operand,
     se::DeviceMemoryBase output_grad_data, se::DeviceMemoryBase grad_output,
     se::DeviceMemory<float> output_grad_scale,
     se::DeviceMemory<float> output_grad_offset, se::DeviceMemory<float> scale,
     se::DeviceMemory<float> mean, se::DeviceMemory<float> inv_stddev,
-    float epsilon, int64 feature_index, se::Stream* stream);
+    se::Stream *stream);
 
 }  // namespace gpu
 }  // namespace xla
 
 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_BATCHNORM_RUNNER_H_
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc
index e91b2c4d0d2..dae490e0d18 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc
@@ -31,90 +31,21 @@ namespace gpu {
 
 namespace dnn = se::dnn;
 
-namespace {
-void CheckInputOutputPrimitivetypeAreValid(const HloInstruction* hlo) {
-  // All input and output statistics variables must be F32. Also, the last
-  // operand for CudnnBatchNormForwardInference, CudnnBatchNormForwardTraining,
-  // and CudnnBatchNormBackward is the feature_index which must be S64.
-  // The allowed types for non-statistics variables are as follows:
-  // CudnnBatchNormForwardInference:
-  //   operand[0]: {half, float}
-  //   out[0]: {half, float}
-  // CudnnBatchNormForwardTraining:
-  //   operand[0]: {half, float}
-  //   out[0]: {half, float}
-  // CudnnBatchNormBackward:
-  //   operand[0]: {half, float}
-  //   operand[4]: {half, float}
-  //   out[0]: {half, float}
-  // Note non-statistics inputs and outputs mentioned above should be of the
-  // same type.
-
-  // Check Inputs.
-  int64 num_operands = hlo->operand_count();
-  PrimitiveType operand_primitive_type =
-      hlo->operand(0)->shape().element_type();
-  CHECK(operand_primitive_type == F16 || operand_primitive_type == F32)
-      << "Not yet implemented";
-
-  for (int i = 1; i < num_operands - 2; i++) {
-    if (hlo->custom_call_target() == kCudnnBatchNormBackwardCallTarget &&
-        i == 4) {
-      // The first operand to batchnorm grad is the input and the 4th operand is
-      // the grad_output, both of which can be Eigen::half.
-      CHECK_EQ(hlo->operand(i)->shape().element_type(), operand_primitive_type)
-          << "Invalid datatype";
-      continue;
-    }
-    CHECK_EQ(hlo->operand(i)->shape().element_type(), F32)
-        << "Not yet implemented";
-  }
-
-  // The last operand is the feature index which must be int64.
-  CHECK_EQ(hlo->operand(num_operands - 1)->shape().element_type(), S64)
-      << "Not yet implemented";
-
-  // Check Outputs.
-  if (hlo->shape().IsTuple()) {
-    CHECK_EQ(hlo->shape().tuple_shapes(0).element_type(),
-             operand_primitive_type)
-        << "Invalid datatype";
-
-    for (int j = 1; j < hlo->shape().tuple_shapes_size(); j++) {
-      CHECK_EQ(hlo->shape().tuple_shapes(j).element_type(), F32)
-          << "Not yet implemented";
-    }
-  } else {
-    CHECK_EQ(hlo->shape().element_type(), operand_primitive_type)
-        << "Invalid datatype";
-  }
-}
-}  // namespace
-
 CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk(
-    ThunkInfo thunk_info, const BufferAllocation::Slice& operand,
+    ThunkInfo thunk_info, CudnnBatchNormConfig&& config,
+    const BufferAllocation::Slice& operand,
     const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset,
     const BufferAllocation::Slice& mean,
-    const BufferAllocation::Slice& variance, float epsilon, int64 feature_index,
+    const BufferAllocation::Slice& variance,
     const BufferAllocation::Slice& output)
     : Thunk(Thunk::Kind::kCudnnBatchNormForwardInference, thunk_info),
-      hlo_instruction_(thunk_info.hlo_instruction),
+      config_(std::move(config)),
       operand_(operand),
       scale_(scale),
       offset_(offset),
       mean_(mean),
       variance_(variance),
-      epsilon_(epsilon),
-      feature_index_(feature_index),
-      output_(output) {
-  const auto* hlo = hlo_instruction_;
-  CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall);
-  CHECK_EQ(hlo->custom_call_target(),
-           kCudnnBatchNormForwardInferenceCallTarget);
-  CHECK(
-      LayoutUtil::LayoutsInShapesEqual(hlo->shape(), hlo->operand(0)->shape()));
-  CheckInputOutputPrimitivetypeAreValid(hlo);
-}
+      output_(output) {}
 
 Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream(
     const ExecuteParams& params) {
@@ -131,8 +62,7 @@ Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream(
       buffer_allocations.GetDeviceAddress(variance_));
   auto& stream = *params.stream;
   TF_RETURN_IF_ERROR(RunCudnnBatchNormForwardInference(
-      hlo_instruction_, operand, output_base, scale, offset, mean, variance,
-      epsilon_, feature_index_, &stream));
+      config_, operand, output_base, scale, offset, mean, variance, &stream));
 
   if (!stream.ok()) {
     return InternalError("BatchNormalizationForward call failed.");
@@ -141,32 +71,22 @@
 }
 
 CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk(
-    ThunkInfo thunk_info, const BufferAllocation::Slice& operand,
+    ThunkInfo thunk_info, CudnnBatchNormConfig&& config,
+    const BufferAllocation::Slice& operand,
     const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset,
-    float epsilon, int64 feature_index,
     const BufferAllocation::Slice& output_data,
     const BufferAllocation::Slice& output_mean,
     const BufferAllocation::Slice& output_inv_stddev,
     const BufferAllocation::Slice& output_tuple)
     : Thunk(Thunk::Kind::kCudnnBatchNormForwardTraining, thunk_info),
-      hlo_instruction_(thunk_info.hlo_instruction),
+      config_(std::move(config)),
       operand_(operand),
       scale_(scale),
      offset_(offset),
-      epsilon_(epsilon),
-      feature_index_(feature_index),
       output_data_(output_data),
       output_mean_(output_mean),
       output_inv_stddev_(output_inv_stddev),
-      output_tuple_(output_tuple) {
-  const auto* hlo = hlo_instruction_;
-  CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall);
-  CHECK_EQ(hlo->custom_call_target(), kCudnnBatchNormForwardTrainingCallTarget);
-  CHECK_EQ(hlo->shape().tuple_shapes_size(), 3);
-  CHECK(LayoutUtil::LayoutsInShapesEqual(hlo->shape().tuple_shapes(0),
-                                         hlo->operand(0)->shape()));
-  CheckInputOutputPrimitivetypeAreValid(hlo);
-}
+      output_tuple_(output_tuple) {}
 
 Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
     const ExecuteParams& params) {
@@ -185,10 +105,10 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
       params.profiler->MakeScopedInstructionProfiler(profile_index());
   auto& stream = *params.stream;
   TF_RETURN_IF_ERROR(RunCudnnBatchNormForwardTraining(
-      hlo_instruction_, operand, output_data, output_mean, output_inv_stddev,
+      config_, operand, output_data, output_mean, output_inv_stddev,
       se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(scale_)),
       se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(offset_)),
-      epsilon_, feature_index_, &stream));
+      &stream));
 
   // Write the output tuple.
   const int kNumOutputs = 3;
@@ -207,37 +127,26 @@
 }
 
 CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk(
-    ThunkInfo thunk_info, const BufferAllocation::Slice& operand,
+    ThunkInfo thunk_info, CudnnBatchNormConfig&& config,
+    const BufferAllocation::Slice& operand,
     const BufferAllocation::Slice& scale, const BufferAllocation::Slice& mean,
     const BufferAllocation::Slice& inv_stddev,
-    const BufferAllocation::Slice& grad_output, float epsilon,
-    int64 feature_index, const BufferAllocation::Slice& output_grad_data,
+    const BufferAllocation::Slice& grad_output,
+    const BufferAllocation::Slice& output_grad_data,
     const BufferAllocation::Slice& output_grad_scale,
     const BufferAllocation::Slice& output_grad_offset,
     const BufferAllocation::Slice& output_tuple)
     : Thunk(Thunk::Kind::kCudnnBatchNormBackward, thunk_info),
-      hlo_instruction_(thunk_info.hlo_instruction),
+      config_(std::move(config)),
       operand_(operand),
       scale_(scale),
       mean_(mean),
       inv_stddev_(inv_stddev),
       grad_output_(grad_output),
-      epsilon_(epsilon),
-      feature_index_(feature_index),
      output_grad_data_(output_grad_data),
       output_grad_scale_(output_grad_scale),
       output_grad_offset_(output_grad_offset),
-      output_tuple_(output_tuple) {
-  const auto* hlo = hlo_instruction_;
-  CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall);
-  CHECK_EQ(hlo->custom_call_target(), kCudnnBatchNormBackwardCallTarget);
-  CHECK_EQ(hlo->shape().tuple_shapes_size(), 3);
-  CHECK(LayoutUtil::LayoutsInShapesEqual(hlo->shape().tuple_shapes(0),
-                                         hlo->operand(0)->shape()));
-  CHECK(LayoutUtil::LayoutsInShapesEqual(hlo->shape().tuple_shapes(0),
-                                         hlo->operand(4)->shape()));
-  CheckInputOutputPrimitivetypeAreValid(hlo);
-}
+      output_tuple_(output_tuple) {}
 
 Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
     const ExecuteParams& params) {
@@ -256,12 +165,12 @@ Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
       params.profiler->MakeScopedInstructionProfiler(profile_index());
   se::Stream* stream = params.stream;
   TF_RETURN_IF_ERROR(RunCudnnBatchNormBackward(
-      hlo_instruction_, operand, output_grad_data, grad_output,
-      output_grad_scale, output_grad_offset,
+      config_, operand, output_grad_data, grad_output, output_grad_scale,
+      output_grad_offset,
       se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(scale_)),
       se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(mean_)),
      se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(inv_stddev_)),
-      epsilon_, feature_index_, stream));
+      stream));
 
   // Write the output tuple.
   const int kNumOutputs = 3;
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h
index bb46017b8fb..79b915b59a7 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.h"
 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -47,12 +48,12 @@ namespace gpu {
 class CudnnBatchNormForwardInferenceThunk : public Thunk {
  public:
   CudnnBatchNormForwardInferenceThunk(ThunkInfo thunk_info,
+                                      CudnnBatchNormConfig&& config,
                                       const BufferAllocation::Slice& operand,
                                       const BufferAllocation::Slice& scale,
                                       const BufferAllocation::Slice& offset,
                                       const BufferAllocation::Slice& mean,
                                       const BufferAllocation::Slice& variance,
-                                      float epsilon, int64 feature_index,
                                       const BufferAllocation::Slice& output);
 
   CudnnBatchNormForwardInferenceThunk(
@@ -63,23 +64,22 @@ class CudnnBatchNormForwardInferenceThunk : public Thunk {
   Status ExecuteOnStream(const ExecuteParams& params) override;
 
  private:
-  const HloInstruction* hlo_instruction_;
+  CudnnBatchNormConfig config_;
   BufferAllocation::Slice operand_;
   BufferAllocation::Slice scale_;
   BufferAllocation::Slice offset_;
   BufferAllocation::Slice mean_;
   BufferAllocation::Slice variance_;
-  float epsilon_;
-  int64 feature_index_;
   BufferAllocation::Slice output_;
 };
 
 class CudnnBatchNormForwardTrainingThunk : public Thunk {
  public:
   CudnnBatchNormForwardTrainingThunk(
-      ThunkInfo thunk_info, const BufferAllocation::Slice& operand,
+      ThunkInfo thunk_info, CudnnBatchNormConfig&& config,
+      const BufferAllocation::Slice& operand,
       const BufferAllocation::Slice& scale,
-      const BufferAllocation::Slice& offset, float epsilon, int64 feature_index,
+      const BufferAllocation::Slice& offset,
       const BufferAllocation::Slice& output_data,
       const BufferAllocation::Slice& output_mean,
       const BufferAllocation::Slice& output_inv_stddev,
@@ -93,12 +93,10 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk {
   Status ExecuteOnStream(const ExecuteParams& params) override;
 
  private:
-  const HloInstruction* hlo_instruction_;
+  CudnnBatchNormConfig config_;
   BufferAllocation::Slice operand_;
   BufferAllocation::Slice scale_;
   BufferAllocation::Slice offset_;
-  float epsilon_;
-  int64 feature_index_;
   BufferAllocation::Slice output_data_;
   BufferAllocation::Slice output_mean_;
   BufferAllocation::Slice output_inv_stddev_;
@@ -108,12 +106,12 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk {
 class CudnnBatchNormBackwardThunk : public Thunk {
  public:
   CudnnBatchNormBackwardThunk(ThunkInfo thunk_info,
+                              CudnnBatchNormConfig&& config,
                               const BufferAllocation::Slice& operand,
                               const BufferAllocation::Slice& scale,
                               const BufferAllocation::Slice& mean,
                               const BufferAllocation::Slice& inv_stddev,
                               const BufferAllocation::Slice& grad_output,
-                              float epsilon, int64 feature_index,
                               const BufferAllocation::Slice& output_grad_data,
                               const BufferAllocation::Slice& output_grad_scale,
                               const BufferAllocation::Slice& output_grad_offset,
                               const BufferAllocation::Slice& output_tuple);
 
@@ -126,14 +124,12 @@ class CudnnBatchNormBackwardThunk : public Thunk {
   Status ExecuteOnStream(const ExecuteParams& params) override;
 
  private:
-  const HloInstruction* hlo_instruction_;
+  CudnnBatchNormConfig config_;
   BufferAllocation::Slice operand_;
   BufferAllocation::Slice scale_;
   BufferAllocation::Slice mean_;
   BufferAllocation::Slice inv_stddev_;
   BufferAllocation::Slice grad_output_;
-  float epsilon_;
-  int64 feature_index_;
   BufferAllocation::Slice output_grad_data_;
   BufferAllocation::Slice output_grad_scale_;
   BufferAllocation::Slice output_grad_offset_;
diff --git a/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc b/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc
index 48923d342e0..f20b8d9ccf3 100644
--- a/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc
@@ -19,6 +19,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
 #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.h"
 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/fft_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"
@@ -38,6 +39,65 @@ limitations under the License.
 namespace xla {
 namespace gpu {
 
+namespace {
+void CheckBatchNormInputOutputPrimitivetypeAreValid(
+    const HloInstruction* hlo) {
+  // All input and output statistics variables must be F32. Also, the last
+  // operand for CudnnBatchNormForwardInference, CudnnBatchNormForwardTraining,
+  // and CudnnBatchNormBackward is the feature_index which must be S64.
+  // The allowed types for non-statistics variables are as follows:
+  // CudnnBatchNormForwardInference:
+  //   operand[0]: {half, float}
+  //   out[0]: {half, float}
+  // CudnnBatchNormForwardTraining:
+  //   operand[0]: {half, float}
+  //   out[0]: {half, float}
+  // CudnnBatchNormBackward:
+  //   operand[0]: {half, float}
+  //   operand[4]: {half, float}
+  //   out[0]: {half, float}
+  // Note non-statistics inputs and outputs mentioned above should be of the
+  // same type.
+
+  // Check Inputs.
+  int64 num_operands = hlo->operand_count();
+  PrimitiveType operand_primitive_type =
+      hlo->operand(0)->shape().element_type();
+  CHECK(operand_primitive_type == F16 || operand_primitive_type == F32)
+      << "Not yet implemented";
+
+  for (int i = 1; i < num_operands - 2; i++) {
+    if (hlo->custom_call_target() == kCudnnBatchNormBackwardCallTarget &&
+        i == 4) {
+      // The first operand to batchnorm grad is the input and the 4th operand is
+      // the grad_output, both of which can be Eigen::half.
+      CHECK_EQ(hlo->operand(i)->shape().element_type(), operand_primitive_type)
+          << "Invalid datatype";
+      continue;
+    }
+    CHECK_EQ(hlo->operand(i)->shape().element_type(), F32)
+        << "Not yet implemented";
+  }
+
+  // The last operand is the feature index which must be int64.
+  CHECK_EQ(hlo->operand(num_operands - 1)->shape().element_type(), S64)
+      << "Not yet implemented";
+
+  // Check Outputs.
+  if (hlo->shape().IsTuple()) {
+    CHECK_EQ(hlo->shape().tuple_shapes(0).element_type(),
+             operand_primitive_type)
+        << "Invalid datatype";
+
+    for (int j = 1; j < hlo->shape().tuple_shapes_size(); j++) {
+      CHECK_EQ(hlo->shape().tuple_shapes(j).element_type(), F32)
+          << "Not yet implemented";
+    }
+  } else {
+    CHECK_EQ(hlo->shape().element_type(), operand_primitive_type)
+        << "Invalid datatype";
+  }
+}
+}  // namespace
+
 std::unique_ptr<Thunk> ThunkEmitter::BuildFftThunk(const HloInstruction* inst) {
   const HloInstruction* operand = inst->operand(0);
   return absl::make_unique<FftThunk>(
@@ -154,16 +214,19 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
     CHECK(feature_index->IsConstant());
     int64 feature_index_value = feature_index->literal().Get<int64>({});
 
+    CHECK(LayoutUtil::LayoutsInShapesEqual(custom_call->shape(),
+                                           custom_call->operand(0)->shape()));
+    CheckBatchNormInputOutputPrimitivetypeAreValid(custom_call);
+    CudnnBatchNormConfig config = GetCudnnBatchNormConfig(
+        custom_call, epsilon_value, feature_index_value);
     AddThunkToThunkSequence(
         absl::make_unique<CudnnBatchNormForwardInferenceThunk>(
-            context_->GetThunkInfo(custom_call),
+            context_->GetThunkInfo(custom_call), std::move(config),
             /*operand=*/GetAllocationSlice(*custom_call->operand(0)),
             /*scale=*/GetAllocationSlice(*custom_call->operand(1)),
             /*offset=*/GetAllocationSlice(*custom_call->operand(2)),
             /*mean=*/GetAllocationSlice(*custom_call->operand(3)),
             /*variance=*/GetAllocationSlice(*custom_call->operand(4)),
-            /*epsilon=*/epsilon_value,
-            /*feature_index=*/feature_index_value,
             /*output=*/GetAllocationSlice(*custom_call)));
     return Status::OK();
   }
@@ -183,14 +246,18 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
     auto output_data = GetAllocationSlice(*custom_call, {0});
     auto output_mean = GetAllocationSlice(*custom_call, {1});
     auto output_inv_stddev = GetAllocationSlice(*custom_call, {2});
+    CHECK_EQ(custom_call->shape().tuple_shapes_size(), 3);
+    CHECK(LayoutUtil::LayoutsInShapesEqual(custom_call->shape().tuple_shapes(0),
+                                           custom_call->operand(0)->shape()));
+    CheckBatchNormInputOutputPrimitivetypeAreValid(custom_call);
+    CudnnBatchNormConfig config = GetCudnnBatchNormConfig(
+        custom_call, epsilon_value, feature_index_value);
     AddThunkToThunkSequence(
         absl::make_unique<CudnnBatchNormForwardTrainingThunk>(
-            context_->GetThunkInfo(custom_call),
+            context_->GetThunkInfo(custom_call), std::move(config),
             /*operand=*/GetAllocationSlice(*custom_call->operand(0)),
             /*scale=*/GetAllocationSlice(*custom_call->operand(1)),
             /*offset=*/GetAllocationSlice(*custom_call->operand(2)),
-            /*epsilon=*/epsilon_value,
-            /*feature_index=*/feature_index_value,
             /*output_data=*/output_data,
             /*output_mean=*/output_mean,
             /*output_inv_stddev=*/output_inv_stddev,
@@ -212,15 +279,22 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
     auto output_grad_data = GetAllocationSlice(*custom_call, {0});
     auto output_grad_scale = GetAllocationSlice(*custom_call, {1});
     auto output_grad_offset = GetAllocationSlice(*custom_call, {2});
+    CHECK_EQ(custom_call->shape().tuple_shapes_size(), 3);
+    CHECK(LayoutUtil::LayoutsInShapesEqual(custom_call->shape().tuple_shapes(0),
+                                           custom_call->operand(0)->shape()));
+    CHECK(LayoutUtil::LayoutsInShapesEqual(custom_call->shape().tuple_shapes(0),
+                                           custom_call->operand(4)->shape()));
+    CheckBatchNormInputOutputPrimitivetypeAreValid(custom_call);
+
+    CudnnBatchNormConfig config = GetCudnnBatchNormConfig(
+        custom_call, epsilon_value, feature_index_value);
     AddThunkToThunkSequence(absl::make_unique<CudnnBatchNormBackwardThunk>(
-        context_->GetThunkInfo(custom_call),
+        context_->GetThunkInfo(custom_call), std::move(config),
        /*operand=*/GetAllocationSlice(*custom_call->operand(0)),
        /*scale=*/GetAllocationSlice(*custom_call->operand(1)),
        /*mean=*/GetAllocationSlice(*custom_call->operand(2)),
        /*inv_stddev=*/GetAllocationSlice(*custom_call->operand(3)),
        /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)),
-        /*epsilon=*/epsilon_value,
-        /*feature_index=*/feature_index_value,
        /*output_grad_data=*/output_grad_data,
        /*output_grad_scale=*/output_grad_scale,
        /*output_grad_offset=*/output_grad_offset,
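
Reviewer note (not part of the patch): after this change the HLO is consulted exactly once, at thunk-emission time, and everything the thunk needs at execution time travels in CudnnBatchNormConfig. The sketch below illustrates that two-phase flow using only the declarations introduced in this diff; EmitBatchNormForwardInference, its parameters, and the Thunk::ThunkInfo spelling are hypothetical stand-ins for the ThunkEmitter plumbing, not code in this patch.

// Hypothetical helper, for illustration only. GetCudnnBatchNormConfig and
// CudnnBatchNormForwardInferenceThunk are used as declared in this diff.
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"

namespace xla {
namespace gpu {

std::unique_ptr<Thunk> EmitBatchNormForwardInference(
    Thunk::ThunkInfo thunk_info, const HloInstruction* custom_call,
    float epsilon, int64 feature_index,
    const BufferAllocation::Slice& operand,
    const BufferAllocation::Slice& scale,
    const BufferAllocation::Slice& offset,
    const BufferAllocation::Slice& mean,
    const BufferAllocation::Slice& variance,
    const BufferAllocation::Slice& output) {
  // Phase 1 (compile time): distill shape, element type, epsilon, and
  // feature_index out of the HLO once. The thunk stores this config instead
  // of an HloInstruction*, so it no longer keeps the HLO graph alive.
  CudnnBatchNormConfig config =
      GetCudnnBatchNormConfig(custom_call, epsilon, feature_index);
  // Phase 2 (run time): ExecuteOnStream reads only config_ and the buffer
  // slices, forwarding them to RunCudnnBatchNormForwardInference(config_, ...).
  return absl::make_unique<CudnnBatchNormForwardInferenceThunk>(
      thunk_info, std::move(config), operand, scale, offset, mean, variance,
      output);
}

}  // namespace gpu
}  // namespace xla

The training and backward thunks follow the same pattern: the removed hlo_instruction_, epsilon_, and feature_index_ members are all subsumed by the single config_ member.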