| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110 |
import argparse
import json
import os
import statistics
import sys
import time
from typing import Dict, List, Optional
# Built-in benchmark inputs of increasing length, keyed by a size label.
# Keeping the texts fixed lets results from different models be compared.
TEXT_CASES = dict(
    short="Hello, this is a short benchmark sentence.",
    medium="Today we are benchmarking two ONNX speech models to compare latency, memory usage, and audio duration under the same voice and speed settings.",
    long="This benchmark uses a longer paragraph so we can observe how each ONNX model behaves when the text becomes more realistic for production usage. We want to compare cold start cost, steady state inference time, real time factor, and memory footprint, while keeping the voice and speed fixed to reduce noise in the results.",
)
- def current_rss_mb() -> float:
- with open("/proc/self/status", "r", encoding="utf-8") as f:
- for line in f:
- if line.startswith("VmRSS:"):
- parts = line.split()
- return int(parts[1]) / 1024.0
- return 0.0
- def run_case(model_name: str, text: str, voice: str, speed: float, warm_runs: int, measured_runs: int) -> Dict:
- import numpy as np
- import speech_tts_onnx_opt as tts
- start_rss_mb = current_rss_mb()
- load_start = time.perf_counter()
- tts.load_model(force_reload=True, name=model_name)
- load_elapsed_s = time.perf_counter() - load_start
- post_load_rss_mb = current_rss_mb()
- warm_times: List[float] = []
- for _ in range(warm_runs):
- t0 = time.perf_counter()
- audio = tts.synthesize_audio(text=text, voice=voice, speed=speed, model_name=model_name)
- warm_times.append(time.perf_counter() - t0)
- run_times: List[float] = []
- audio_durations: List[float] = []
- peak_rss_mb = post_load_rss_mb
- last_audio = None
- for _ in range(measured_runs):
- t0 = time.perf_counter()
- audio = tts.synthesize_audio(text=text, voice=voice, speed=speed, model_name=model_name)
- elapsed = time.perf_counter() - t0
- run_times.append(elapsed)
- duration = float(len(audio) / tts.sample_rate)
- audio_durations.append(duration)
- peak_rss_mb = max(peak_rss_mb, current_rss_mb())
- last_audio = audio
- mean_time_s = statistics.mean(run_times)
- mean_audio_s = statistics.mean(audio_durations)
- rtf = mean_time_s / mean_audio_s if mean_audio_s > 0 else None
- result = {
- "model_name": model_name,
- "voice": voice,
- "speed": speed,
- "text_chars": len(text),
- "text_words": len(text.split()),
- "sample_rate": tts.sample_rate,
- "load_time_s": load_elapsed_s,
- "warmup_time_s_mean": statistics.mean(warm_times) if warm_times else None,
- "run_time_s": run_times,
- "run_time_s_mean": mean_time_s,
- "run_time_s_median": statistics.median(run_times),
- "run_time_s_min": min(run_times),
- "run_time_s_max": max(run_times),
- "audio_duration_s_mean": mean_audio_s,
- "rtf_mean": rtf,
- "rss_mb_before_load": start_rss_mb,
- "rss_mb_after_load": post_load_rss_mb,
- "rss_mb_peak": peak_rss_mb,
- "audio_samples_last_run": int(len(last_audio)) if last_audio is not None else 0,
- "audio_abs_max_last_run": float(np.max(np.abs(last_audio))) if last_audio is not None and len(last_audio) else 0.0,
- }
- return result
- def main() -> int:
- parser = argparse.ArgumentParser()
- parser.add_argument("--model", required=True)
- parser.add_argument("--text", required=True)
- parser.add_argument("--voice", default="af_heart")
- parser.add_argument("--speed", type=float, default=1.0)
- parser.add_argument("--warm-runs", type=int, default=1)
- parser.add_argument("--runs", type=int, default=3)
- args = parser.parse_args()
- result = run_case(
- model_name=args.model,
- text=args.text,
- voice=args.voice,
- speed=args.speed,
- warm_runs=args.warm_runs,
- measured_runs=args.runs,
- )
- json.dump(result, sys.stdout, ensure_ascii=False)
- sys.stdout.write("\n")
- return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
|