"""
测试 speech token 重建质量
用 CosyVoice3 的 flow + hift 重建音频。
默认模式: 从预计算的 utt2speech_token.pt + utt2embedding.pt 读取。
--extract_tokens 模式: 实时提取 token (speech_tokenizer) 和 embedding (campplus),
只需 wav.scp 即可运行,utt2spk 可选(有则用同说话人 prompt)。
用法:
# 重建前 5 条 (读预计算 token + embedding)
python test_token_reconstruction.py
# 实时提取 token + embedding,只需 wav.scp
python test_token_reconstruction.py --extract_tokens
# 指定 utt / 重建条数
python test_token_reconstruction.py --utt_id 0_320_5
python test_token_reconstruction.py --num_utts 10 --extract_tokens
"""
import argparse
import math
import os
import random
import sys
import numpy as np
import onnxruntime
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import whisper
from hyperpyyaml import load_hyperpyyaml
os.chdir(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, '.')
COSYVOICE_MODEL_DIR = '../pretrained_models/cosyvoice2'
DATA_DIR = 'data/test'
# whisper tokenizer 单次最长 30s (对应 3000 帧)
WHISPER_MAX_SEC = 30
def load_wav(wav_path, target_sr):
    """Load an audio file, downmix to mono, and resample to ``target_sr``.

    Returns a float tensor of shape [1, N].
    """
    waveform, orig_sr = torchaudio.load(wav_path)
    # Average all channels into a single mono track when needed.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    if orig_sr == target_sr:
        return waveform
    return torchaudio.transforms.Resample(orig_sr, target_sr)(waveform)
def extract_speech_feat(wav_path, feat_extractor):
    """Audio file -> CosyVoice mel spectrogram (80-bin, 24 kHz), shape [1, T, 80]."""
    audio = load_wav(wav_path, 24000)
    mel = feat_extractor(audio)                 # [1, 80, T]
    mel = mel.squeeze(dim=0).transpose(0, 1)    # [T, 80]
    return mel.unsqueeze(dim=0)                 # [1, T, 80]
def load_flow_hift(model_dir, device):
    """Load the flow and hift modules from a CosyVoice model directory.

    Tries the yaml configs shipped in ``model_dir`` first, then falls back to
    local train/conf copies (some released yaml files still reference
    matcha.* and raise ImportError when parsed).

    Returns (flow, hift, configs) with both modules on ``device`` in eval mode.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    yaml_candidates = [
        os.path.join(model_dir, 'cosyvoice2.yaml'),
        os.path.join(model_dir, 'cosyvoice3.yaml'),
        os.path.join(here, 'conf', 'cosyvoice2.yaml'),
        os.path.join(here, 'conf', 'cosyvoice3.yaml'),
    ]
    overrides = {
        'llm': None,  # only flow + hift are needed for reconstruction
        'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN'),
    }
    configs, used_yaml, last_err = None, None, None
    for candidate in yaml_candidates:
        if not os.path.exists(candidate):
            continue
        try:
            with open(candidate, 'r') as fh:
                parsed = load_hyperpyyaml(fh, overrides=overrides)
        except ImportError as err:
            # Released yaml may reference matcha.*; fall through to the
            # local conf/cosyvoice*.yaml candidates.
            last_err = err
            continue
        configs, used_yaml = parsed, candidate
        break
    if configs is None:
        if last_err is not None:
            raise last_err
        raise FileNotFoundError(f'找不到可用的 cosyvoice yaml,已尝试: {yaml_candidates}')
    print(f' config: {used_yaml}')
    flow = configs['flow'].to(device).eval()
    hift = configs['hift'].to(device).eval()
    flow.load_state_dict(
        torch.load(os.path.join(model_dir, 'flow.pt'), map_location=device),
        strict=True)
    # Checkpoint keys carry a 'generator.' prefix; strip it to match the module.
    raw_hift = torch.load(os.path.join(model_dir, 'hift.pt'), map_location=device)
    hift.load_state_dict(
        {name.replace('generator.', ''): tensor for name, tensor in raw_hift.items()},
        strict=True)
    return flow, hift, configs
def build_onnx_session(onnx_path):
    """Create an onnxruntime session (speech tokenizer or campplus embedding)."""
    assert os.path.exists(onnx_path), f'找不到模型: {onnx_path}'
    opts = onnxruntime.SessionOptions()
    opts.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    opts.intra_op_num_threads = 1
    # Prefer CUDA when torch can see a GPU, otherwise stay on CPU.
    if torch.cuda.is_available():
        providers = ['CUDAExecutionProvider']
    else:
        providers = ['CPUExecutionProvider']
    sess = onnxruntime.InferenceSession(onnx_path, sess_options=opts, providers=providers)
    print(f' loaded: {onnx_path}')
    return sess
def extract_embedding_from_wav(wav_path, session):
    """Extract a speaker embedding (campplus) from an audio file.

    Returns the embedding as a flat List[float].
    """
    waveform = load_wav(wav_path, 16000)  # [1, N] at 16 kHz
    fbank = kaldi.fbank(waveform, num_mel_bins=80, dither=0, sample_frequency=16000)
    # Per-utterance mean normalization over the time axis.
    fbank = fbank - fbank.mean(dim=0, keepdim=True)
    feed = {session.get_inputs()[0].name: fbank.unsqueeze(dim=0).cpu().numpy()}
    outputs = session.run(None, feed)
    return outputs[0].flatten().tolist()
def extract_tokens_from_wav(wav_path, session):
    """Extract speech tokens from an audio file; returns List[int].

    Audio longer than WHISPER_MAX_SEC seconds is split into fixed-size
    segments (the last one may be shorter) and tokenized segment by segment.
    """
    audio = load_wav(wav_path, 16000)  # [1, N] at 16 kHz
    seg_samples = WHISPER_MAX_SEC * 16000
    n_samples = audio.shape[1]
    if n_samples <= seg_samples:
        segments = [audio]
    else:
        n_segs = math.ceil(n_samples / seg_samples)
        segments = [audio[:, k * seg_samples:(k + 1) * seg_samples] for k in range(n_segs)]
    name_feat = session.get_inputs()[0].name
    name_len = session.get_inputs()[1].name
    token_ids = []
    for seg in segments:
        mel = whisper.log_mel_spectrogram(seg, n_mels=128)  # [1, 128, T]
        outputs = session.run(None, {
            name_feat: mel.detach().cpu().numpy(),
            name_len: np.array([mel.shape[2]], dtype=np.int32),
        })
        token_ids.extend(outputs[0].flatten().tolist())
    return token_ids
def flow_inference(flow, gen_token, embedding, prompt_token, prompt_feat, device):
    """Run flow inference.

    Compatible with MaskedDiffWithXvec (v1) and CausalMaskedDiffWithXvec
    (v2/v3) — the two variants take different extra keyword arguments.
    Returns the predicted mel.
    """
    def _len32(t):
        # Length tensor (int32, batch of 1) expected by the flow interface.
        return torch.tensor([t.shape[1]], dtype=torch.int32).to(device)

    kwargs = {
        'token': gen_token.to(device),
        'token_len': _len32(gen_token),
        'prompt_token': prompt_token.to(device),
        'prompt_token_len': _len32(prompt_token),
        'prompt_feat': prompt_feat.to(device),
        'prompt_feat_len': _len32(prompt_feat),
        'embedding': embedding.to(device),
    }
    if type(flow).__name__ == 'MaskedDiffWithXvec':
        # v1 signature: takes flow_cache, no streaming/finalize flags.
        mel, _ = flow.inference(**kwargs, flow_cache=torch.zeros(0))
    else:
        # v2/v3: CausalMaskedDiffWithXvec.
        mel, _ = flow.inference(**kwargs, streaming=False, finalize=True)
    return mel
def hift_inference(hift, tts_mel):
    """Run vocoder inference on a mel; returns the waveform tensor.

    Compatible with HiFTGenerator (v1/v2) and CausalHiFTGenerator (v3).
    """
    if type(hift).__name__ == 'CausalHiFTGenerator':
        # v3 variant requires the finalize flag.
        audio, _ = hift.inference(speech_feat=tts_mel, finalize=True)
        return audio
    # v1/v2 HiFTGenerator returns (speech, source).
    audio, _ = hift.inference(speech_feat=tts_mel)
    return audio
def token2wav(flow, hift, gen_token, embedding, prompt_token, prompt_feat, device):
    """Tokens -> mel (flow) -> waveform (hift); returns the waveform on CPU."""
    with torch.inference_mode():
        mel = flow_inference(flow, gen_token, embedding, prompt_token, prompt_feat, device)
        wav = hift_inference(hift, mel)
    return wav.cpu()
def load_utt2spk(data_dir):
    """Read the Kaldi-style utt2spk file under ``data_dir``.

    Returns (utt2spk, spk2utt): utt -> speaker mapping and its inverse,
    speaker -> list of utts. Malformed lines (fewer than two fields) are
    skipped.
    """
    utt2spk = {}
    with open(os.path.join(data_dir, 'utt2spk'), 'r') as fh:
        for raw in fh:
            fields = raw.strip().split(None, 1)
            if len(fields) != 2:
                continue
            utt, spk = fields
            utt2spk[utt] = spk
    # Invert after the full read so duplicate utt lines resolve to their
    # last speaker, matching the dict semantics above.
    spk2utt = {}
    for utt, spk in utt2spk.items():
        if spk not in spk2utt:
            spk2utt[spk] = []
        spk2utt[spk].append(utt)
    return utt2spk, spk2utt
def pick_ref_utt(utt_id, spk2utt, utt2spk, valid_set):
    """Randomly pick a *different*, valid utt from the same speaker.

    Returns None when the speaker is unknown or no candidate exists.
    """
    speaker = utt2spk.get(utt_id)
    if speaker is None:
        return None
    pool = [u for u in spk2utt.get(speaker, []) if u in valid_set and u != utt_id]
    if not pool:
        return None
    return random.choice(pool)
def main():
    """Reconstruct audio from speech tokens, using a same-speaker prompt.

    Two modes:
      * default      — read precomputed utt2speech_token.pt / utt2embedding.pt
      * --extract_tokens — extract tokens + embeddings on the fly (ONNX),
        needing only wav.scp (utt2spk optional).
    Writes ``<utt>_recon.wav`` and ``<utt>_orig.wav`` into --output_dir.
    """
    parser = argparse.ArgumentParser(description='从训练数据的 speech token 重建音频(同说话人 prompt)')
    parser.add_argument('--model_dir', default=COSYVOICE_MODEL_DIR)
    parser.add_argument('--data_dir', default=DATA_DIR)
    parser.add_argument('--utt_id', default=None, help='指定 utt id')
    parser.add_argument('--num_utts', type=int, default=5)
    parser.add_argument('--output_dir', default='recon_output')
    parser.add_argument('--max_prompt_frames', type=int, default=150, help='prompt mel 最大帧数 (约 3s)')
    parser.add_argument('--extract_tokens', action='store_true',
                        help='实时用模型 tokenizer 提取 token,而不是读预计算的 utt2speech_token.pt')
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args()
    random.seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('加载 flow + hift 模型...')
    flow, hift, configs = load_flow_hift(args.model_dir, device)
    sample_rate = configs.get('sample_rate', 24000)
    print(f' 加载完成 (device={device})')
    # ONNX sessions are only needed for --extract_tokens mode.
    tokenizer_session = None
    embedding_session = None
    if args.extract_tokens:
        print('加载 speech tokenizer...')
        tokenizer_candidates = [
            os.path.join(args.model_dir, 'speech_tokenizer_v2.onnx'),
            os.path.join(args.model_dir, 'speech_tokenizer_v3.onnx'),
        ]
        tokenizer_path = next((p for p in tokenizer_candidates if os.path.exists(p)), None)
        if tokenizer_path is None:
            raise FileNotFoundError(f'找不到 speech tokenizer,已尝试: {tokenizer_candidates}')
        tokenizer_session = build_onnx_session(tokenizer_path)
        print('加载 campplus embedding 模型...')
        embedding_session = build_onnx_session(os.path.join(args.model_dir, 'campplus.onnx'))
    print(f'加载训练数据: {args.data_dir}')
    wav_scp = {}
    with open(os.path.join(args.data_dir, 'wav.scp'), 'r') as f:
        for line in f:
            parts = line.strip().split(None, 1)
            if len(parts) == 2:
                wav_scp[parts[0]] = parts[1]
    utt2embedding = {}
    utt2token = {}
    utt2spk, spk2utt = {}, {}
    if args.extract_tokens:
        # Extract mode: only wav.scp is required; tokens and embeddings are
        # computed on the fly. utt2spk is optional — with it we can pick a
        # same-speaker prompt, without it we fall back to a self-prefix prompt.
        utt2spk_path = os.path.join(args.data_dir, 'utt2spk')
        if os.path.exists(utt2spk_path):
            utt2spk, spk2utt = load_utt2spk(args.data_dir)
        # Defined unconditionally so the no-utt2spk path cannot NameError.
        valid_set = set(wav_scp.keys())
    else:
        # Precomputed mode: needs utt2embedding.pt + utt2speech_token.pt + utt2spk.
        utt2embedding = torch.load(os.path.join(args.data_dir, 'utt2embedding.pt'), map_location='cpu')
        utt2token = torch.load(os.path.join(args.data_dir, 'utt2speech_token.pt'), map_location='cpu')
        utt2spk, spk2utt = load_utt2spk(args.data_dir)
        valid_set = set(utt2embedding.keys()) & set(wav_scp.keys()) & set(utt2token.keys())
    if args.utt_id:
        utt_ids = [args.utt_id]
    else:
        utt_ids = sorted(valid_set)[:args.num_utts]
    print(f' wav: {len(wav_scp)} 条')
    if not args.extract_tokens:
        print(f' embedding: {len(utt2embedding)} 条, token: {len(utt2token)} 条')
    else:
        print(f' (token + embedding 将实时提取)')
    print(f' 将重建 {len(utt_ids)} 条: {utt_ids}')
    os.makedirs(args.output_dir, exist_ok=True)
    for utt_id in utt_ids:
        print(f'\n--- {utt_id} ---')
        # ── tokens for the current utt ─────────────────────────────────────
        if args.extract_tokens:
            wav_path = wav_scp.get(utt_id)
            if not wav_path or not os.path.exists(wav_path):
                print(f' 跳过: 音频不存在 {wav_path}')
                continue
            print(f' 提取 token: {wav_path}')
            tokens = extract_tokens_from_wav(wav_path, tokenizer_session)
        else:
            tokens = utt2token[utt_id]
        gen_token = torch.tensor([tokens], dtype=torch.int32)
        print(f' gen tokens: {gen_token.shape[1]}')
        # ── speaker embedding ──────────────────────────────────────────────
        if args.extract_tokens:
            wav_path_for_emb = wav_scp.get(utt_id)
            print(f' 提取 embedding: {wav_path_for_emb}')
            emb = extract_embedding_from_wav(wav_path_for_emb, embedding_session)
        else:
            emb = utt2embedding[utt_id]
        embedding = torch.tensor([emb])
        # ── pick a same-speaker reference utt as the prompt ────────────────
        ref_utt = pick_ref_utt(utt_id, spk2utt, utt2spk, valid_set)
        if ref_utt is None:
            # Fallback: use a prefix of the utt itself as the prompt.
            print(' 警告: 无同说话人参考,用自身前缀作 prompt')
            wav_path = wav_scp.get(utt_id)
            if wav_path and os.path.exists(wav_path):
                prompt_feat = extract_speech_feat(wav_path, configs['feat_extractor'])
                if prompt_feat.shape[1] > args.max_prompt_frames:
                    prompt_feat = prompt_feat[:, :args.max_prompt_frames, :]
            else:
                prompt_feat = torch.zeros(1, 0, 80)
            prompt_token_len = prompt_feat.shape[1] // flow.token_mel_ratio
            prompt_token = gen_token[:, :prompt_token_len]
            gen_token = gen_token[:, prompt_token_len:]
            if gen_token.shape[1] == 0:
                print(' 跳过: 音频太短')
                continue
        else:
            print(f' prompt utt (同说话人): {ref_utt}')
            ref_wav_path = wav_scp.get(ref_utt)
            # prompt mel
            ref_wav_ok = bool(ref_wav_path) and os.path.exists(ref_wav_path)
            if ref_wav_ok:
                prompt_feat = extract_speech_feat(ref_wav_path, configs['feat_extractor'])
                if prompt_feat.shape[1] > args.max_prompt_frames:
                    prompt_feat = prompt_feat[:, :args.max_prompt_frames, :]
            else:
                print(f' 警告: 参考音频不存在 {ref_wav_path},使用空 prompt')
                prompt_feat = torch.zeros(1, 0, 80)
            # prompt token
            prompt_token_len = prompt_feat.shape[1] // flow.token_mel_ratio
            if args.extract_tokens:
                # BUGFIX: when the reference wav is missing, prompt_feat is
                # empty and prompt_token_len == 0 — skip extraction instead
                # of crashing on a nonexistent path.
                ref_tokens = (extract_tokens_from_wav(ref_wav_path, tokenizer_session)
                              if ref_wav_ok else [])
            else:
                ref_tokens = utt2token.get(ref_utt, [])
            ref_token_tensor = torch.tensor([ref_tokens], dtype=torch.int32)
            prompt_token = ref_token_tensor[:, :prompt_token_len]
        print(f' prompt: {prompt_token.shape[1]} tokens / {prompt_feat.shape[1]} mel frames, '
              f'gen: {gen_token.shape[1]} tokens')
        # ── reconstruction ─────────────────────────────────────────────────
        tts_speech = token2wav(flow, hift, gen_token, embedding, prompt_token, prompt_feat, device)
        out_path = os.path.join(args.output_dir, f'{utt_id}_recon.wav')
        torchaudio.save(out_path, tts_speech.reshape(1, -1), sample_rate)
        print(f' 重建: {out_path} ({tts_speech.shape[-1] / sample_rate:.2f}s)')
        # Save the original audio alongside for easy A/B comparison.
        wav_path = wav_scp.get(utt_id)
        if wav_path and os.path.exists(wav_path):
            orig_out = os.path.join(args.output_dir, f'{utt_id}_orig.wav')
            orig = load_wav(wav_path, sample_rate)
            torchaudio.save(orig_out, orig, sample_rate)
            print(f' 原始: {orig_out} ({orig.shape[1] / sample_rate:.2f}s)')
    print(f'\n完成! 结果在 {args.output_dir}/')
# Script entry point.
if __name__ == '__main__':
    main()
# uv run python test_token_reconstruction.py --num_utts 10 --extract_tokens
# Listening notes (stray pasted prose, commented out so the file stays importable):
#   origin.wav
#   cosyvoice2.wav
#   cosyvoice3.wav
# 第一个是GT,转为tokens交给flow+vocoder输出真实音频,实际测试下来cosyvoice3对于喘息有比较明显的破音。
# 测试代码: