Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Mddct committed Nov 8, 2024
1 parent 4cc33da commit 8b90594
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions wenet/bin/export_onnx_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,10 @@ def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache,
r_att_cache = []
r_cnn_cache = []
for i, layer in enumerate(self.encoder.encoders):
i_kv_cache = att_cache[i:i + 1]
i_kv_cache = att_cache[i]
size = att_cache.size(-1) // 2
kv_cache = (i_kv_cache[:, :, :, :size], i_kv_cache[:, :, :, size:])
xs, _, new_att_cache, new_cnn_cache = layer(
xs, _, new_kv_cache, new_cnn_cache = layer(
xs,
masks,
pos_emb,
Expand All @@ -180,6 +180,7 @@ def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache,
)
# shape(new_att_cache) is (B, head, attention_key_size, d_k * 2),
# shape(new_cnn_cache) is (B, hidden-dim, cache_t2)
new_att_cache = torch.cat(new_kv_cache, dim=-1)
r_att_cache.append(
new_att_cache[:, :, next_cache_start:, :].unsqueeze(1))
if not self.transformer:
Expand Down

0 comments on commit 8b90594

Please sign in to comment.