
Commit d0636f1

refactor cache behaviour in training mode (reduce compute cost and memory) (#2473)
Mddct authored Apr 14, 2024
1 parent 8bad166 commit d0636f1
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions wenet/transformer/attention.py
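
The hunks below all apply the same change: reading and rebuilding the key/value cache is now guarded by `not self.training`, so the extra torch.cat calls (and the tensors they allocate) are skipped during training, when the cache is never consumed. A minimal standalone sketch of that pattern follows; the helper name and signature are illustrative only, not part of the wenet API, and the shapes follow the cache layout documented in attention.py.

import torch

def update_kv_cache(k: torch.Tensor, v: torch.Tensor,
                    cache: torch.Tensor, training: bool):
    # k, v:   (batch, head, time, d_k) projections of the current chunk.
    # cache:  (batch, head, cache_t, d_k * 2), cached keys and values of
    #         previous chunks concatenated on the last dim; empty in training.
    if cache.size(0) > 0 and not training:
        # Streaming inference: prepend the cached keys/values to the new chunk.
        key_cache, value_cache = torch.split(cache, cache.size(-1) // 2, dim=-1)
        k = torch.cat([key_cache, k], dim=2)
        v = torch.cat([value_cache, v], dim=2)
    # In training the cache is never read downstream, so it is returned
    # untouched instead of materialising a fresh (k, v) concatenation.
    new_cache = torch.cat((k, v), dim=-1) if not training else cache
    return k, v, new_cache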
@@ -226,15 +226,15 @@ def forward(
         # >>> torch.equal(b, c)  # True
         # >>> d = torch.split(a, 2, dim=-1)
         # >>> torch.equal(d[0], d[1])  # True
-        if cache.size(0) > 0:
+        if cache.size(0) > 0 and not self.training:
             key_cache, value_cache = torch.split(cache,
                                                  cache.size(-1) // 2,
                                                  dim=-1)
             k = torch.cat([key_cache, k], dim=2)
             v = torch.cat([value_cache, v], dim=2)
         # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
         #   non-trivial to calculate `next_cache_start` here.
-        new_cache = torch.cat((k, v), dim=-1)
+        new_cache = torch.cat((k, v), dim=-1) if not self.training else cache

         # for multi query or multi group attention
         if self.h_kv != self.h:
@@ -370,7 +370,7 @@ def forward(
         # >>> torch.equal(b, c)  # True
         # >>> d = torch.split(a, 2, dim=-1)
         # >>> torch.equal(d[0], d[1])  # True
-        if cache.size(0) > 0:
+        if cache.size(0) > 0 and not self.training:
             key_cache, value_cache = torch.split(cache,
                                                  cache.size(-1) // 2,
                                                  dim=-1)
@@ -379,7 +379,7 @@ def forward(

         # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
         #   non-trivial to calculate `next_cache_start` here.
-        new_cache = torch.cat((k, v), dim=-1)
+        new_cache = torch.cat((k, v), dim=-1) if not self.training else cache

         # for multi query or multi groups attention
         if self.h_kv != self.h:
@@ -472,7 +472,7 @@ def forward(

         else:
             q, k, v = self.forward_qkv(query, key, value)
-        new_cache = torch.cat((k, v), dim=-1)
+        new_cache = torch.cat((k, v), dim=-1) if not self.training else cache

         # for multi query or multi groups attention
         if self.h_kv != self.h:
@@ -563,13 +563,13 @@ def forward(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         del pos_emb
         q, k, v = self.forward_qkv(query, key, value)
-        if cache.size(0) > 0:
+        if cache.size(0) > 0 and not self.training:
             key_cache, value_cache = torch.split(cache,
                                                  cache.size(-1) // 2,
                                                  dim=-1)
             k = torch.cat([key_cache, k], dim=2)
             v = torch.cat([value_cache, v], dim=2)
-        new_cache = torch.cat((k, v), dim=-1)
+        new_cache = torch.cat((k, v), dim=-1) if not self.training else cache

         rel_k = self.rel_k_embed(
             self._relative_indices(k.size(2), query.device))  # (t2, t2, d_k)
@@ -664,13 +664,13 @@ def forward(
         q = llama_apply_rotary_emb(q, pos_emb)
         k = llama_apply_rotary_emb(k, pos_emb)
         # see above
-        if cache.size(0) > 0:
+        if cache.size(0) > 0 and not self.training:
             key_cache, value_cache = torch.split(cache,
                                                  cache.size(-1) // 2,
                                                  dim=-1)
             k = torch.cat([key_cache, k], dim=2)
             v = torch.cat([value_cache, v], dim=2)
-        new_cache = torch.cat((k, v), dim=-1)
+        new_cache = torch.cat((k, v), dim=-1) if not self.training else cache

         if self.h_kv != self.h:
             k = torch.repeat_interleave(
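A quick behavioural check of the refactor, using the illustrative helper sketched above (the shapes are made up for the example):

# batch=1, heads=4, a 16-frame chunk, d_k=64.
k = torch.rand(1, 4, 16, 64)
v = torch.rand(1, 4, 16, 64)
empty_cache = torch.zeros(0, 0, 0, 0)

# Training mode: the incoming (empty) cache is passed through, no torch.cat.
_, _, c_train = update_kv_cache(k, v, empty_cache, training=True)
assert c_train.size(0) == 0

# Inference mode: the returned cache holds the concatenated keys and values.
_, _, c_infer = update_kv_cache(k, v, empty_cache, training=False)
assert c_infer.shape == (1, 4, 16, 128)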
