Skip to content

经过测试cosyvoice3 flow的重建效果似乎不如v2 #1870

@trimonster233

Description

@trimonster233

origin.wav

cosyvoice2.wav

cosyvoice3.wav

第一个是GT,转为tokens交给flow+vocoder输出真实音频,实际测试下来cosyvoice3对于喘息有比较明显的破音。

测试代码:

"""
测试 speech token 重建质量
用 CosyVoice3 的 flow + hift 重建音频。

默认模式: 从预计算的 utt2speech_token.pt + utt2embedding.pt 读取。
--extract_tokens 模式: 实时提取 token (speech_tokenizer) 和 embedding (campplus),
  只需 wav.scp 即可运行,utt2spk 可选(有则用同说话人 prompt)。

用法:
  # 重建前 5 条 (读预计算 token + embedding)
  python test_token_reconstruction.py

  # 实时提取 token + embedding,只需 wav.scp
  python test_token_reconstruction.py --extract_tokens

  # 指定 utt / 重建条数
  python test_token_reconstruction.py --utt_id 0_320_5
  python test_token_reconstruction.py --num_utts 10 --extract_tokens
"""

import argparse
import math
import os
import random
import sys
import numpy as np
import onnxruntime
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import whisper
from hyperpyyaml import load_hyperpyyaml

# Run relative to this script's directory so the relative default paths
# below (model dir, data dir) resolve regardless of the caller's cwd.
os.chdir(os.path.dirname(os.path.abspath(__file__)))
# Make local project modules importable (needed by the hyperpyyaml config).
sys.path.insert(0, '.')

COSYVOICE_MODEL_DIR = '../pretrained_models/cosyvoice2'
DATA_DIR = 'data/test'

# The whisper tokenizer handles at most 30 s of audio per pass (3000 mel frames).
WHISPER_MAX_SEC = 30


def load_wav(wav_path, target_sr):
    """Load an audio file, downmix to mono, and resample to target_sr.

    Returns a [1, N] float waveform tensor.
    """
    waveform, orig_sr = torchaudio.load(wav_path)
    # Downmix multi-channel audio to a single channel.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    # Resample only when the file's rate differs from the requested one.
    if orig_sr != target_sr:
        resampler = torchaudio.transforms.Resample(orig_sr, target_sr)
        waveform = resampler(waveform)
    return waveform


def extract_speech_feat(wav_path, feat_extractor):
    """Audio file -> cosyvoice mel spectrogram (80-bin, 24 kHz), shape [1, T, 80]."""
    audio = load_wav(wav_path, 24000)
    mel = feat_extractor(audio)
    # Drop the batch dim, move time to the first axis, then re-add the batch dim.
    mel = mel.squeeze(dim=0).transpose(0, 1)
    return mel.unsqueeze(dim=0)


def load_flow_hift(model_dir, device):
    """Load the flow and hift modules from a CosyVoice model directory.

    Tries several yaml locations in order (model dir first, then the local
    conf/ directory next to this script), instantiates `flow` and `hift`
    from the hyperpyyaml config with the LLM disabled, loads their
    checkpoints strictly, and returns (flow, hift, configs).

    Raises:
        ImportError: re-raised from the last failed yaml if every existing
            candidate fails to load (e.g. released yamls referencing matcha.*).
        FileNotFoundError: if no candidate yaml exists at all.
    """
    yaml_candidates = [
        os.path.join(model_dir, 'cosyvoice2.yaml'),
        os.path.join(model_dir, 'cosyvoice3.yaml'),
        os.path.join(os.path.dirname(os.path.abspath(__file__)), 'conf', 'cosyvoice2.yaml'),
        os.path.join(os.path.dirname(os.path.abspath(__file__)), 'conf', 'cosyvoice3.yaml'),
    ]
    configs = None
    last_err = None
    used_yaml = None
    for yaml_path in yaml_candidates:
        if not os.path.exists(yaml_path):
            continue
        try:
            with open(yaml_path, 'r') as f:
                # Only flow/hift are needed here, so the LLM is disabled.
                configs = load_hyperpyyaml(f, overrides={
                    'llm': None,
                    'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN'),
                })
            used_yaml = yaml_path
            break
        except ImportError as e:
            # Some released yaml files still reference matcha.*.
            # Fallback to local train/conf/cosyvoice*.yaml in that case.
            last_err = e
            continue
    if configs is None:
        if last_err is not None:
            raise last_err
        raise FileNotFoundError(f'找不到可用的 cosyvoice yaml,已尝试: {yaml_candidates}')
    print(f'  config: {used_yaml}')
    flow = configs['flow'].to(device).eval()
    hift = configs['hift'].to(device).eval()
    flow_state = torch.load(os.path.join(model_dir, 'flow.pt'), map_location=device)
    flow.load_state_dict(flow_state, strict=True)
    # The hift checkpoint prefixes keys with 'generator.'; strip that prefix
    # so the keys match the bare generator module loaded from the config.
    hift_state = {k.replace('generator.', ''): v for k, v in
                  torch.load(os.path.join(model_dir, 'hift.pt'), map_location=device).items()}
    hift.load_state_dict(hift_state, strict=True)
    return flow, hift, configs


def build_onnx_session(onnx_path):
    """Create an onnxruntime InferenceSession (tokenizer or campplus embedding)."""
    assert os.path.exists(onnx_path), f'找不到模型: {onnx_path}'
    opts = onnxruntime.SessionOptions()
    opts.intra_op_num_threads = 1
    opts.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    # Prefer GPU inference when a CUDA device is visible to torch.
    if torch.cuda.is_available():
        providers = ['CUDAExecutionProvider']
    else:
        providers = ['CPUExecutionProvider']
    session = onnxruntime.InferenceSession(onnx_path, sess_options=opts, providers=providers)
    print(f'  loaded: {onnx_path}')
    return session


def extract_embedding_from_wav(wav_path, session):
    """Extract a speaker embedding (campplus) from an audio file; returns List[float]."""
    wav16k = load_wav(wav_path, 16000)  # [1, N] at 16 kHz
    fbank = kaldi.fbank(wav16k, num_mel_bins=80, dither=0, sample_frequency=16000)
    # Mean-normalize the filterbank features over time.
    fbank = fbank - fbank.mean(dim=0, keepdim=True)
    inputs = {session.get_inputs()[0].name: fbank.unsqueeze(dim=0).cpu().numpy()}
    outputs = session.run(None, inputs)
    return outputs[0].flatten().tolist()


def extract_tokens_from_wav(wav_path, session):
    """
    Extract speech tokens from an audio file. Audio longer than 30 s is
    processed in consecutive non-overlapping chunks (the last one may be
    shorter). Returns List[int].
    """
    wav16k = load_wav(wav_path, 16000)  # [1, N] at 16 kHz
    n_samples = wav16k.shape[1]
    chunk_samples = WHISPER_MAX_SEC * 16000

    # Slice into <=30 s segments; a single segment covers short audio.
    segments = [wav16k[:, start:start + chunk_samples]
                for start in range(0, n_samples, chunk_samples)]

    feat_name = session.get_inputs()[0].name
    len_name = session.get_inputs()[1].name
    tokens = []
    for segment in segments:
        mel = whisper.log_mel_spectrogram(segment, n_mels=128)  # [1, 128, T]
        result = session.run(None, {
            feat_name: mel.detach().cpu().numpy(),
            len_name: np.array([mel.shape[2]], dtype=np.int32),
        })[0]
        tokens.extend(result.flatten().tolist())
    return tokens


def flow_inference(flow, gen_token, embedding, prompt_token, prompt_feat, device):
    """Run flow inference; compatible with MaskedDiffWithXvec (v1) and
    CausalMaskedDiffWithXvec (v2/v3), distinguished by class name."""
    def seq_len(t):
        # Length tensor for a [1, T, ...] input.
        return torch.tensor([t.shape[1]], dtype=torch.int32).to(device)

    kwargs = {
        'token': gen_token.to(device),
        'token_len': seq_len(gen_token),
        'prompt_token': prompt_token.to(device),
        'prompt_token_len': seq_len(prompt_token),
        'prompt_feat': prompt_feat.to(device),
        'prompt_feat_len': seq_len(prompt_feat),
        'embedding': embedding.to(device),
    }
    if type(flow).__name__ == 'MaskedDiffWithXvec':
        # v1 API: takes a flow_cache, has no streaming/finalize flags.
        mel, _ = flow.inference(**kwargs, flow_cache=torch.zeros(0))
    else:
        # v2/v3: CausalMaskedDiffWithXvec.
        mel, _ = flow.inference(**kwargs, streaming=False, finalize=True)
    return mel


def hift_inference(hift, tts_mel):
    """Run the vocoder; compatible with HiFTGenerator (v1/v2) and
    CausalHiFTGenerator (v3), distinguished by class name."""
    if type(hift).__name__ == 'CausalHiFTGenerator':
        # v3 vocoder needs an explicit finalize flag.
        result = hift.inference(speech_feat=tts_mel, finalize=True)
    else:
        # v1/v2 HiFTGenerator returns (speech, source).
        result = hift.inference(speech_feat=tts_mel)
    speech, _ = result
    return speech


def token2wav(flow, hift, gen_token, embedding, prompt_token, prompt_feat, device):
    """Tokens -> mel (flow) -> waveform (hift); returns the waveform on CPU."""
    with torch.inference_mode():
        mel = flow_inference(flow, gen_token, embedding, prompt_token, prompt_feat, device)
        wav = hift_inference(hift, mel)
    return wav.cpu()


def load_utt2spk(data_dir):
    """Read the Kaldi-style utt2spk file from data_dir.

    Returns (utt2spk, spk2utt): a dict mapping utt -> speaker, and the
    inverted dict-of-lists. Malformed lines are silently skipped.
    """
    utt2spk = {}
    with open(os.path.join(data_dir, 'utt2spk'), 'r') as f:
        for raw in f:
            fields = raw.strip().split(None, 1)
            if len(fields) != 2:
                continue
            utt2spk[fields[0]] = fields[1]
    spk2utt = {}
    for utt_id, spk in utt2spk.items():
        spk2utt.setdefault(spk, []).append(utt_id)
    return utt2spk, spk2utt


def pick_ref_utt(utt_id, spk2utt, utt2spk, valid_set):
    """Randomly pick a valid utterance from the same speaker, excluding
    utt_id itself. Returns None when the speaker is unknown or no
    candidate exists."""
    spk = utt2spk.get(utt_id)
    if spk is None:
        return None
    pool = [u for u in spk2utt.get(spk, []) if u in valid_set and u != utt_id]
    if not pool:
        return None
    return random.choice(pool)


def main():
    """Entry point: rebuild audio from speech tokens (same-speaker prompt).

    Loads flow + hift (and, with --extract_tokens, the ONNX speech
    tokenizer and campplus embedding models), then for each selected
    utterance picks a same-speaker reference utterance as prompt —
    falling back to the utterance's own prefix — and writes both the
    reconstructed and the original wav to --output_dir.
    """
    parser = argparse.ArgumentParser(description='从训练数据的 speech token 重建音频(同说话人 prompt)')
    parser.add_argument('--model_dir', default=COSYVOICE_MODEL_DIR)
    parser.add_argument('--data_dir', default=DATA_DIR)
    parser.add_argument('--utt_id', default=None, help='指定 utt id')
    parser.add_argument('--num_utts', type=int, default=5)
    parser.add_argument('--output_dir', default='recon_output')
    parser.add_argument('--max_prompt_frames', type=int, default=150, help='prompt mel 最大帧数 (约 3s)')
    parser.add_argument('--extract_tokens', action='store_true',
                        help='实时用模型 tokenizer 提取 token,而不是读预计算的 utt2speech_token.pt')
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args()

    # Seed only affects the random same-speaker prompt selection.
    random.seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print('加载 flow + hift 模型...')
    flow, hift, configs = load_flow_hift(args.model_dir, device)
    sample_rate = configs.get('sample_rate', 24000)
    print(f'  加载完成 (device={device})')

    # ONNX sessions (loaded only when --extract_tokens is set)
    tokenizer_session = None
    embedding_session = None
    if args.extract_tokens:
        print('加载 speech tokenizer...')
        tokenizer_candidates = [
            os.path.join(args.model_dir, 'speech_tokenizer_v2.onnx'),
            os.path.join(args.model_dir, 'speech_tokenizer_v3.onnx'),
        ]
        tokenizer_path = next((p for p in tokenizer_candidates if os.path.exists(p)), None)
        if tokenizer_path is None:
            raise FileNotFoundError(f'找不到 speech tokenizer,已尝试: {tokenizer_candidates}')
        tokenizer_session = build_onnx_session(tokenizer_path)
        print('加载 campplus embedding 模型...')
        embedding_session = build_onnx_session(os.path.join(args.model_dir, 'campplus.onnx'))

    print(f'加载训练数据: {args.data_dir}')

    # wav.scp: utt id -> audio path (malformed lines skipped)
    wav_scp = {}
    with open(os.path.join(args.data_dir, 'wav.scp'), 'r') as f:
        for line in f:
            parts = line.strip().split(None, 1)
            if len(parts) == 2:
                wav_scp[parts[0]] = parts[1]

    utt2embedding = {}
    utt2token = {}
    utt2spk, spk2utt = {}, {}

    if args.extract_tokens:
        # extract mode: only wav.scp is required; tokens and embeddings are
        # extracted on the fly.
        # utt2spk is optional: with it, a same-speaker prompt is used;
        # without it, the utterance's own prefix serves as prompt.
        utt2spk_path = os.path.join(args.data_dir, 'utt2spk')
        if os.path.exists(utt2spk_path):
            utt2spk, spk2utt = load_utt2spk(args.data_dir)
        valid_set = set(wav_scp.keys())
    else:
        # precomputed mode: needs utt2embedding.pt + utt2speech_token.pt + utt2spk
        utt2embedding = torch.load(os.path.join(args.data_dir, 'utt2embedding.pt'), map_location='cpu')
        utt2token = torch.load(os.path.join(args.data_dir, 'utt2speech_token.pt'), map_location='cpu')
        utt2spk, spk2utt = load_utt2spk(args.data_dir)
        valid_set = set(utt2embedding.keys()) & set(wav_scp.keys()) & set(utt2token.keys())

    if args.utt_id:
        utt_ids = [args.utt_id]
    else:
        utt_ids = sorted(valid_set)[:args.num_utts]

    print(f'  wav: {len(wav_scp)} 条')
    if not args.extract_tokens:
        print(f'  embedding: {len(utt2embedding)} 条, token: {len(utt2token)} 条')
    else:
        print(f'  (token + embedding 将实时提取)')
    print(f'  将重建 {len(utt_ids)} 条: {utt_ids}')

    os.makedirs(args.output_dir, exist_ok=True)

    for utt_id in utt_ids:
        print(f'\n--- {utt_id} ---')

        # ── tokens for the current utt ─────────────────────────────────────
        if args.extract_tokens:
            wav_path = wav_scp.get(utt_id)
            if not wav_path or not os.path.exists(wav_path):
                print(f'  跳过: 音频不存在 {wav_path}')
                continue
            print(f'  提取 token: {wav_path}')
            tokens = extract_tokens_from_wav(wav_path, tokenizer_session)
        else:
            tokens = utt2token[utt_id]
        gen_token = torch.tensor([tokens], dtype=torch.int32)
        print(f'  gen tokens: {gen_token.shape[1]}')

        # ── embedding ──────────────────────────────────────────────────────
        if args.extract_tokens:
            wav_path_for_emb = wav_scp.get(utt_id)
            print(f'  提取 embedding: {wav_path_for_emb}')
            emb = extract_embedding_from_wav(wav_path_for_emb, embedding_session)
        else:
            emb = utt2embedding[utt_id]
        embedding = torch.tensor([emb])

        # ── pick a same-speaker reference utt as prompt ────────────────────
        ref_utt = pick_ref_utt(utt_id, spk2utt, utt2spk, valid_set)

        if ref_utt is None:
            # fallback: use the utterance's own prefix as prompt
            print('  警告: 无同说话人参考,用自身前缀作 prompt')
            wav_path = wav_scp.get(utt_id)
            if wav_path and os.path.exists(wav_path):
                prompt_feat = extract_speech_feat(wav_path, configs['feat_extractor'])
                if prompt_feat.shape[1] > args.max_prompt_frames:
                    prompt_feat = prompt_feat[:, :args.max_prompt_frames, :]
            else:
                prompt_feat = torch.zeros(1, 0, 80)
            # flow.token_mel_ratio maps mel frames to token count — assumed
            # to be an exact integer ratio of the model.
            prompt_token_len = prompt_feat.shape[1] // flow.token_mel_ratio
            prompt_token = gen_token[:, :prompt_token_len]
            gen_token = gen_token[:, prompt_token_len:]
            if gen_token.shape[1] == 0:
                print('  跳过: 音频太短')
                continue
        else:
            print(f'  prompt utt (同说话人): {ref_utt}')
            ref_wav_path = wav_scp.get(ref_utt)

            # prompt mel
            if ref_wav_path and os.path.exists(ref_wav_path):
                prompt_feat = extract_speech_feat(ref_wav_path, configs['feat_extractor'])
                if prompt_feat.shape[1] > args.max_prompt_frames:
                    prompt_feat = prompt_feat[:, :args.max_prompt_frames, :]
            else:
                print(f'  警告: 参考音频不存在 {ref_wav_path},使用空 prompt')
                prompt_feat = torch.zeros(1, 0, 80)

            # prompt token
            prompt_token_len = prompt_feat.shape[1] // flow.token_mel_ratio
            if args.extract_tokens:
                # NOTE(review): the exists() check above only guards the mel
                # branch; if ref_wav_path is None/missing this call would
                # fail — verify ref audio availability in extract mode.
                ref_tokens = extract_tokens_from_wav(ref_wav_path, tokenizer_session)
            else:
                ref_tokens = utt2token.get(ref_utt, [])
            ref_token_tensor = torch.tensor([ref_tokens], dtype=torch.int32)
            prompt_token = ref_token_tensor[:, :prompt_token_len]

        print(f'  prompt: {prompt_token.shape[1]} tokens / {prompt_feat.shape[1]} mel frames, '
              f'gen: {gen_token.shape[1]} tokens')

        # ── reconstruction ─────────────────────────────────────────────────
        tts_speech = token2wav(flow, hift, gen_token, embedding, prompt_token, prompt_feat, device)

        out_path = os.path.join(args.output_dir, f'{utt_id}_recon.wav')
        torchaudio.save(out_path, tts_speech.reshape(1, -1), sample_rate)
        print(f'  重建: {out_path} ({tts_speech.shape[-1] / sample_rate:.2f}s)')

        # also save the original audio for easy comparison
        wav_path = wav_scp.get(utt_id)
        if wav_path and os.path.exists(wav_path):
            orig_out = os.path.join(args.output_dir, f'{utt_id}_orig.wav')
            orig = load_wav(wav_path, sample_rate)
            torchaudio.save(orig_out, orig, sample_rate)
            print(f'  原始: {orig_out} ({orig.shape[1] / sample_rate:.2f}s)')

    print(f'\n完成! 结果在 {args.output_dir}/')


if __name__ == '__main__':
    main()

# Example invocation:
# uv run python test_token_reconstruction.py --num_utts 10 --extract_tokens

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions