/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/bfloat16_support.h"

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
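
// BFloat16Support describes a backend's degree of BF16 support. Precision
// passes such as bfloat16 propagation and normalization query these
// predicates to decide where F32<->BF16 converts are required.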
namespace xla {
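
// Returns whether the backend accepts a BF16-typed operand at
// `operand_index` of `hlo`. By default only opcodes that pass data through
// unchanged accept BF16, plus kConvert when its input is already BF16.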
bool BFloat16Support::SupportsBF16Operand(const HloInstruction& hlo,
                                          int64 operand_index) const {
  switch (hlo.opcode()) {
    case HloOpcode::kCall:
    case HloOpcode::kConditional:
    case HloOpcode::kCustomCall:
    case HloOpcode::kDomain:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kTuple:
    case HloOpcode::kWhile:
      return true;
    case HloOpcode::kConvert:
      // Convert has exactly one operand, and it supports a BF16 input
      // whenever that input is already BF16-typed.
      CHECK_EQ(operand_index, 0);
      return hlo.operand(0)->shape().element_type() == BF16;
    default:
      break;
  }
  return false;
}
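
// Returns whether the backend accepts a BF16-typed output for `hlo`. Mirrors
// SupportsBF16Operand: data-movement opcodes qualify, as does a kConvert
// that produces BF16.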
bool BFloat16Support::SupportsBF16Output(const HloInstruction& hlo) const {
  switch (hlo.opcode()) {
    case HloOpcode::kCall:
    case HloOpcode::kConditional:
    case HloOpcode::kCustomCall:
    case HloOpcode::kDomain:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kTuple:
    case HloOpcode::kWhile:
      return true;
    case HloOpcode::kConvert:
      return hlo.shape().element_type() == BF16;
    default:
      break;
  }
  return false;
}
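
// Returns whether `hlo` may mix BF16 and F32 among its operands and output.
// kConvert is inherently mixed-precision; the other opcodes listed only
// repackage or route data. Note that kDomain, accepted by the two predicates
// above, is absent here.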
bool BFloat16Support::SupportsMixedPrecisions(const HloInstruction& hlo) const {
  switch (hlo.opcode()) {
    case HloOpcode::kCall:
    case HloOpcode::kConditional:
    case HloOpcode::kConvert:
    case HloOpcode::kCustomCall:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kTuple:
    case HloOpcode::kWhile:
      return true;
    default:
      break;
  }
  return false;
}
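
// Returns whether the output of `hlo` carries the effective precision of the
// operand at `operand_index`: even if the output type is F32, output
// elements that depend on a BF16 operand retain only BF16 precision. For
// kReduce/kReduceWindow this holds only if every non-parameter instruction
// in the attached reduce computation also preserves operand precision.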
/* static */
bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision(
    const HloInstruction& hlo, int64 operand_index) {
  switch (hlo.opcode()) {
    case HloOpcode::kAbs:
    case HloOpcode::kAllGather:
    case HloOpcode::kAllToAll:
    case HloOpcode::kBroadcast:
    case HloOpcode::kClamp:
    case HloOpcode::kCollectivePermute:
    case HloOpcode::kConcatenate:
    case HloOpcode::kConvert:
    case HloOpcode::kCopy:
    case HloOpcode::kDomain:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kPad:
    case HloOpcode::kReshape:
    case HloOpcode::kReverse:
    case HloOpcode::kSlice:
    case HloOpcode::kSort:
    case HloOpcode::kTranspose:
    case HloOpcode::kTuple:
      return true;
    case HloOpcode::kBitcast:
      // A bitcast preserves precision only when it does not reinterpret the
      // element type.
      return hlo.shape().element_type() ==
             hlo.operand(0)->shape().element_type();
    case HloOpcode::kDynamicSlice:
      // Only the sliced data (operand 0) flows to the output; the start
      // indices do not.
      return operand_index == 0;
    case HloOpcode::kDynamicUpdateSlice:
      // Both the base (operand 0) and the update (operand 1) contribute
      // elements to the output; the start indices do not.
      return operand_index == 0 || operand_index == 1;
    case HloOpcode::kGather:
      return operand_index == 0;
    case HloOpcode::kSelect:
    case HloOpcode::kTupleSelect:
      // The predicate (operand 0) only chooses between operands 1 and 2,
      // whose values appear verbatim in the output.
      return operand_index == 1 || operand_index == 2;
    case HloOpcode::kReduce:
    case HloOpcode::kReduceWindow: {
      // The reduction preserves operand precision only if every
      // non-parameter instruction in the reduce computation does.
      HloComputation* reduce_comp = hlo.called_computations()[0];
      for (HloInstruction* inst : reduce_comp->instructions()) {
        if (inst->opcode() == HloOpcode::kParameter) {
          continue;
        }
        for (int64 i = 0; i < inst->operand_count(); ++i) {
          if (!EffectiveOperandPrecisionIsOutputPrecision(*inst, i)) {
            return false;
          }
        }
      }
      return true;
    }
    default:
      break;
  }
  return false;
}
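
// Returns whether the operand at `operand_index` is computed with effective
// BF16 precision even when its static type is F32. The base implementation
// never claims this; a backend whose hardware internally rounds certain F32
// inputs to BF16 can override it to say so.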
bool BFloat16Support::EffectiveOperandPrecisionIsBF16(
    const HloInstruction& hlo, int64 operand_index) const {
  return false;
}

}  // namespace xla
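
// Illustrative sketch (not part of the original file): a hypothetical
// backend with native BF16 arithmetic might subclass BFloat16Support to
// widen the defaults. The class name and the opcode choices below are
// assumptions for illustration, not an actual backend:
//
//   class HypotheticalBF16Support : public BFloat16Support {
//    public:
//     bool SupportsBF16Operand(const HloInstruction& hlo,
//                              int64 operand_index) const override {
//       switch (hlo.opcode()) {
//         // Additionally accept BF16 inputs for elementwise add/multiply.
//         case HloOpcode::kAdd:
//         case HloOpcode::kMultiply:
//           return true;
//         default:
//           return BFloat16Support::SupportsBF16Operand(hlo, operand_index);
//       }
//     }
//     bool SupportsBF16Output(const HloInstruction& hlo) const override {
//       switch (hlo.opcode()) {
//         case HloOpcode::kAdd:
//         case HloOpcode::kMultiply:
//           return true;
//         default:
//           return BFloat16Support::SupportsBF16Output(hlo);
//       }
//     }
//   };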