Merge pull request #41832 from lissyx:issue41630
PiperOrigin-RevId: 325454474 Change-Id: I5fb9f4cfa4d9836056f3ba14c12f2a04d1a09a55
This commit is contained in:
commit
0fb80172ba
@ -500,6 +500,9 @@ struct CudnnRnnModelShapes {
|
|||||||
int max_seq_length;
|
int max_seq_length;
|
||||||
int batch_size;
|
int batch_size;
|
||||||
int cell_num_units = 0;
|
int cell_num_units = 0;
|
||||||
|
// If you add new field to this structure, please take care of
|
||||||
|
// updating IsCompatibleWith() below as well as the hash function in
|
||||||
|
// CudnnRnnConfigHasher.
|
||||||
TensorShape input_shape;
|
TensorShape input_shape;
|
||||||
TensorShape output_shape;
|
TensorShape output_shape;
|
||||||
TensorShape hidden_state_shape;
|
TensorShape hidden_state_shape;
|
||||||
@ -508,7 +511,8 @@ struct CudnnRnnModelShapes {
|
|||||||
bool IsCompatibleWith(const CudnnRnnModelShapes& rhs) const {
|
bool IsCompatibleWith(const CudnnRnnModelShapes& rhs) const {
|
||||||
return num_layers == rhs.num_layers && input_size == rhs.input_size &&
|
return num_layers == rhs.num_layers && input_size == rhs.input_size &&
|
||||||
num_units == rhs.num_units && dir_count == rhs.dir_count &&
|
num_units == rhs.num_units && dir_count == rhs.dir_count &&
|
||||||
cell_num_units == rhs.cell_num_units;
|
cell_num_units == rhs.cell_num_units &&
|
||||||
|
max_seq_length == rhs.max_seq_length;
|
||||||
}
|
}
|
||||||
string DebugString() const {
|
string DebugString() const {
|
||||||
return strings::Printf(
|
return strings::Printf(
|
||||||
@ -530,7 +534,7 @@ struct CudnnRnnConfigHasher {
|
|||||||
|
|
||||||
uint64 hash =
|
uint64 hash =
|
||||||
HashList({shapes.num_layers, shapes.input_size, shapes.num_units,
|
HashList({shapes.num_layers, shapes.input_size, shapes.num_units,
|
||||||
shapes.dir_count, shapes.batch_size});
|
shapes.dir_count, shapes.max_seq_length, shapes.batch_size});
|
||||||
if (algo_desc.has_value()) {
|
if (algo_desc.has_value()) {
|
||||||
hash = Hash64Combine(hash, algo_desc->hash());
|
hash = Hash64Combine(hash, algo_desc->hash());
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user