[XLA] Implement S8,S16,U16 support for Literal::GetIntegralAsS64

PiperOrigin-RevId: 325314732
Change-Id: Ia89c4153d2a70564f46c880f25112c3b74a44b2d
This commit is contained in:
David Majnemer 2020-08-06 14:43:38 -07:00 committed by TensorFlower Gardener
parent 8296bf5a55
commit a0aee5ed2c
2 changed files with 11 additions and 5 deletions
tensorflow/compiler/xla

View File

@ -1004,14 +1004,20 @@ absl::optional<int64> LiteralBase::GetIntegralAsS64(
switch (shape().element_type()) {
case PRED:
return Get<bool>(multi_index);
case S8:
return Get<int8>(multi_index);
case U8:
return Get<uint8>(multi_index);
case S16:
return Get<int16>(multi_index);
case U16:
return Get<uint16>(multi_index);
case S32:
return Get<int32>(multi_index);
case S64:
return Get<int64>(multi_index);
case U32:
return Get<uint32>(multi_index);
case S64:
return Get<int64>(multi_index);
case U64:
return Get<uint64>(multi_index);
default:

View File

@ -1573,9 +1573,9 @@ class OutputBatchIndexToInputIndex {
int64 index_vector_dim = dim_numbers_.index_vector_dim();
for (int64 i = 0, e = index_vector_.size(); i < e; i++) {
index_vector_index_[index_vector_dim] = i;
// TODO(george): OK what should happen here?
// seems OK to crash though.
index_vector_[i] = *start_indices_.GetIntegralAsS64(index_vector_index_);
auto start_index = start_indices_.GetIntegralAsS64(index_vector_index_);
TF_RET_CHECK(start_index.has_value());
index_vector_[i] = *start_index;
}
return Status::OK();
}