# benchmark_onnx_models.py (4.0 KB)
  1. import argparse
  2. import json
  3. import os
  4. import statistics
  5. import sys
  6. import time
  7. from typing import Dict, List
  8. TEXT_CASES = {
  9. "short": "Hello, this is a short benchmark sentence.",
  10. "medium": "Today we are benchmarking two ONNX speech models to compare latency, memory usage, and audio duration under the same voice and speed settings.",
  11. "long": "This benchmark uses a longer paragraph so we can observe how each ONNX model behaves when the text becomes more realistic for production usage. We want to compare cold start cost, steady state inference time, real time factor, and memory footprint, while keeping the voice and speed fixed to reduce noise in the results.",
  12. }
  13. def current_rss_mb() -> float:
  14. with open("/proc/self/status", "r", encoding="utf-8") as f:
  15. for line in f:
  16. if line.startswith("VmRSS:"):
  17. parts = line.split()
  18. return int(parts[1]) / 1024.0
  19. return 0.0
  20. def run_case(model_name: str, text: str, voice: str, speed: float, warm_runs: int, measured_runs: int) -> Dict:
  21. import numpy as np
  22. import speech_tts_onnx_opt as tts
  23. start_rss_mb = current_rss_mb()
  24. load_start = time.perf_counter()
  25. tts.load_model(force_reload=True, name=model_name)
  26. load_elapsed_s = time.perf_counter() - load_start
  27. post_load_rss_mb = current_rss_mb()
  28. warm_times: List[float] = []
  29. for _ in range(warm_runs):
  30. t0 = time.perf_counter()
  31. audio = tts.synthesize_audio(text=text, voice=voice, speed=speed, model_name=model_name)
  32. warm_times.append(time.perf_counter() - t0)
  33. run_times: List[float] = []
  34. audio_durations: List[float] = []
  35. peak_rss_mb = post_load_rss_mb
  36. last_audio = None
  37. for _ in range(measured_runs):
  38. t0 = time.perf_counter()
  39. audio = tts.synthesize_audio(text=text, voice=voice, speed=speed, model_name=model_name)
  40. elapsed = time.perf_counter() - t0
  41. run_times.append(elapsed)
  42. duration = float(len(audio) / tts.sample_rate)
  43. audio_durations.append(duration)
  44. peak_rss_mb = max(peak_rss_mb, current_rss_mb())
  45. last_audio = audio
  46. mean_time_s = statistics.mean(run_times)
  47. mean_audio_s = statistics.mean(audio_durations)
  48. rtf = mean_time_s / mean_audio_s if mean_audio_s > 0 else None
  49. result = {
  50. "model_name": model_name,
  51. "voice": voice,
  52. "speed": speed,
  53. "text_chars": len(text),
  54. "text_words": len(text.split()),
  55. "sample_rate": tts.sample_rate,
  56. "load_time_s": load_elapsed_s,
  57. "warmup_time_s_mean": statistics.mean(warm_times) if warm_times else None,
  58. "run_time_s": run_times,
  59. "run_time_s_mean": mean_time_s,
  60. "run_time_s_median": statistics.median(run_times),
  61. "run_time_s_min": min(run_times),
  62. "run_time_s_max": max(run_times),
  63. "audio_duration_s_mean": mean_audio_s,
  64. "rtf_mean": rtf,
  65. "rss_mb_before_load": start_rss_mb,
  66. "rss_mb_after_load": post_load_rss_mb,
  67. "rss_mb_peak": peak_rss_mb,
  68. "audio_samples_last_run": int(len(last_audio)) if last_audio is not None else 0,
  69. "audio_abs_max_last_run": float(np.max(np.abs(last_audio))) if last_audio is not None and len(last_audio) else 0.0,
  70. }
  71. return result
  72. def main() -> int:
  73. parser = argparse.ArgumentParser()
  74. parser.add_argument("--model", required=True)
  75. parser.add_argument("--text", required=True)
  76. parser.add_argument("--voice", default="af_heart")
  77. parser.add_argument("--speed", type=float, default=1.0)
  78. parser.add_argument("--warm-runs", type=int, default=1)
  79. parser.add_argument("--runs", type=int, default=3)
  80. args = parser.parse_args()
  81. result = run_case(
  82. model_name=args.model,
  83. text=args.text,
  84. voice=args.voice,
  85. speed=args.speed,
  86. warm_runs=args.warm_runs,
  87. measured_runs=args.runs,
  88. )
  89. json.dump(result, sys.stdout, ensure_ascii=False)
  90. sys.stdout.write("\n")
  91. return 0
  92. if __name__ == "__main__":
  93. raise SystemExit(main())