Skip to content

Commit

Permalink
[bin/export_gpu] fix streaming export onnx
Browse files Browse the repository at this point in the history
  • Loading branch information
Mddct committed Nov 8, 2024
1 parent 4c2d2f6 commit 4cc33da
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions wenet/bin/export_onnx_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,13 @@
from __future__ import print_function

import argparse
import logging
import os
import sys

import torch
import yaml
import logging

import torch.nn.functional as F
import yaml
from wenet.transformer.ctc import CTC
from wenet.transformer.decoder import TransformerDecoder
from wenet.transformer.encoder import BaseEncoder
Expand Down Expand Up @@ -169,11 +168,14 @@ def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache,
r_att_cache = []
r_cnn_cache = []
for i, layer in enumerate(self.encoder.encoders):
i_kv_cache = att_cache[i:i + 1]
size = att_cache.size(-1) // 2
kv_cache = (i_kv_cache[:, :, :, :size], i_kv_cache[:, :, :, size:])
xs, _, new_att_cache, new_cnn_cache = layer(
xs,
masks,
pos_emb,
att_cache=att_cache[i],
att_cache=kv_cache,
cnn_cache=cnn_cache[i],
)
# shape(new_att_cache) is (B, head, attention_key_size, d_k * 2),
Expand Down Expand Up @@ -1241,8 +1243,8 @@ def export_rescoring_decoder(model, configs, args, logger, decoder_onnx_path,
if args.fp16:
try:
import onnxmltools
from onnxmltools.utils.float16_converter import (
convert_float_to_float16, )
from onnxmltools.utils.float16_converter import \
convert_float_to_float16
except ImportError:
print("Please install onnxmltools!")
sys.exit(1)
Expand Down

0 comments on commit 4cc33da

Please sign in to comment.