[XLA] Implement S8,S16,U16 support for Literal::GetIntegralAsS64
PiperOrigin-RevId: 325314732 Change-Id: Ia89c4153d2a70564f46c880f25112c3b74a44b2d
This commit is contained in:
parent
8296bf5a55
commit
a0aee5ed2c
tensorflow/compiler/xla
@ -1004,14 +1004,20 @@ absl::optional<int64> LiteralBase::GetIntegralAsS64(
|
||||
switch (shape().element_type()) {
|
||||
case PRED:
|
||||
return Get<bool>(multi_index);
|
||||
case S8:
|
||||
return Get<int8>(multi_index);
|
||||
case U8:
|
||||
return Get<uint8>(multi_index);
|
||||
case S16:
|
||||
return Get<int16>(multi_index);
|
||||
case U16:
|
||||
return Get<uint16>(multi_index);
|
||||
case S32:
|
||||
return Get<int32>(multi_index);
|
||||
case S64:
|
||||
return Get<int64>(multi_index);
|
||||
case U32:
|
||||
return Get<uint32>(multi_index);
|
||||
case S64:
|
||||
return Get<int64>(multi_index);
|
||||
case U64:
|
||||
return Get<uint64>(multi_index);
|
||||
default:
|
||||
|
@ -1573,9 +1573,9 @@ class OutputBatchIndexToInputIndex {
|
||||
int64 index_vector_dim = dim_numbers_.index_vector_dim();
|
||||
for (int64 i = 0, e = index_vector_.size(); i < e; i++) {
|
||||
index_vector_index_[index_vector_dim] = i;
|
||||
// TODO(george): OK what should happen here?
|
||||
// seems OK to crash though.
|
||||
index_vector_[i] = *start_indices_.GetIntegralAsS64(index_vector_index_);
|
||||
auto start_index = start_indices_.GetIntegralAsS64(index_vector_index_);
|
||||
TF_RET_CHECK(start_index.has_value());
|
||||
index_vector_[i] = *start_index;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user