# main2-old.py (8.6 KB)
import asyncio
import hashlib
import logging
import os
import shutil
import uuid
from typing import AsyncGenerator

from fastapi import FastAPI, Request, File, UploadFile, HTTPException, Form, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.responses import RedirectResponse
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from openai import OpenAI
from pydantic import BaseModel
from pydub import AudioSegment  # used for audio processing
  16. app = FastAPI()
  17. # 配置允许的跨域源,* 表示允许所有
  18. origins = [
  19. "*",
  20. # 若要限制特定域名,可以在这里添加,例如:
  21. # "http://localhost",
  22. # "http://localhost:8000",
  23. ]
  24. app.add_middleware(
  25. CORSMiddleware,
  26. allow_origins=origins, # 允许的来源
  27. allow_credentials=True,
  28. allow_methods=["*"], # 允许的方法
  29. allow_headers=["*"], # 允许的请求头
  30. )
  31. # 指定上传文件保存的目录
  32. UPLOAD_DIRECTORY = "static/files"
  33. if not os.path.exists(UPLOAD_DIRECTORY):
  34. os.makedirs(UPLOAD_DIRECTORY)
  35. # 配置静态文件服务,使上传的PDF可以通过URL访问
  36. app.mount("/static/files", StaticFiles(directory=UPLOAD_DIRECTORY), name="static_files")
  37. app.mount("/static/web", StaticFiles(directory="static/web"), name="static_web") # 假设 viewer.html 在 static/web
  38. # 挂载静态文件
  39. app.mount("/static", StaticFiles(directory="static"), name="static")
  40. # 根路径重定向到 PDF.js viewer
  41. @app.get("/")
  42. def root():
  43. return RedirectResponse(url="/static/web/viewer.html?file=/static/files/compress.pdf")
  44. # 如果需要自定义 PDF 文件上传或动态渲染,可以在此添加更多路由
  45. def sanitize_filename(name: str) -> str:
  46. return "".join(c for c in name if c.isalnum() or c in (' ', '.', '_', '-')).rstrip()
  47. @app.post("/upload-pdf")
  48. async def upload_pdf(file: UploadFile = File(...), custom_name: str = Form(...)):
  49. if file.content_type != 'application/pdf':
  50. raise HTTPException(status_code=400, detail="文件类型必须是 PDF")
  51. # 清理文件名
  52. sanitized_name = sanitize_filename(custom_name)
  53. if not sanitized_name:
  54. return JSONResponse(status_code=400, content={"success": False, "error": "无效的文件名"})
  55. # 添加 .pdf 扩展名
  56. unique_filename = f"{sanitized_name}.pdf"
  57. file_path = os.path.join(UPLOAD_DIRECTORY, unique_filename)
  58. # 如果文件已存在,添加 UUID 以确保唯一性
  59. if os.path.exists(file_path):
  60. return JSONResponse(status_code=400, content={"success": False, "error": "文件名已存在,请使用其他名称"})
  61. try:
  62. with open(file_path, "wb") as buffer:
  63. shutil.copyfileobj(file.file, buffer)
  64. except Exception as e:
  65. raise HTTPException(status_code=500, detail="上传过程中出错")
  66. finally:
  67. file.file.close()
  68. # 构建文件的相对路径
  69. file_relative_path = f"/static/files/{unique_filename}"
  70. return JSONResponse(content={"success": True, "file_path": file_relative_path})
  71. @app.get("/list-pdfs")
  72. async def list_pdfs():
  73. try:
  74. files = os.listdir(UPLOAD_DIRECTORY)
  75. # 过滤出PDF文件并构建可访问的URL
  76. pdf_files = [
  77. {
  78. "name": file,
  79. "url": f"/static/files/{file}"
  80. }
  81. for file in files if file.lower().endswith(".pdf")
  82. ]
  83. return JSONResponse(content={"success": True, "files": pdf_files})
  84. except Exception as e:
  85. raise HTTPException(status_code=500, detail="无法获取文件列表")
  86. class TextToSpeechRequest(BaseModel):
  87. user_input: str
  88. # 配置OpenAI客户端
  89. api_key = "sk-bpaahUHgzoriWpjV24524eC7BbBf47D5A4Ce59EbFdB57f35" # 请确保使用环境变量存储API密钥
  90. client = OpenAI(
  91. base_url="https://api.wlai.vip/v1",
  92. api_key=api_key
  93. )
  94. # 音频缓存目录
  95. CACHE_DIR = "audio_cache"
  96. os.makedirs(CACHE_DIR, exist_ok=True)
  97. @app.post("/text-to-speech/")
  98. async def text_to_speech(request: TextToSpeechRequest):
  99. user_input = request.user_input
  100. try:
  101. # 生成文本的hash值作为缓存文件名
  102. # print(user_input)
  103. text_hash = hashlib.md5(user_input.encode('utf-8')).hexdigest()
  104. audio_path = os.path.join(CACHE_DIR, f"{text_hash}.mp3")
  105. if os.path.exists(audio_path):
  106. # 如果缓存存在,直接返回缓存的音频
  107. # print("have")
  108. with open(audio_path, "rb") as f:
  109. audio_data = f.read()
  110. return Response(content=audio_data, media_type="audio/mpeg")
  111. else:
  112. # 如果缓存不存在,调用OpenAI API生成音频
  113. with client.audio.speech.with_streaming_response.create(
  114. model="tts-1",
  115. voice="nova",
  116. input=user_input,
  117. ) as response:
  118. response.stream_to_file(audio_path)
  119. with open(audio_path, "rb") as f:
  120. audio_data = f.read()
  121. return Response(content=audio_data, media_type="audio/mpeg")
  122. except Exception as e:
  123. raise HTTPException(status_code=500, detail=str(e))
  124. # 整页阅读,分块
  125. # 最大字符数,以根据需求调整
  126. MAX_CHUNK_SIZE = 200 # 每个块的最大字符数
  127. def split_text_into_chunks(text: str, max_chunk_size: int = MAX_CHUNK_SIZE) -> list:
  128. """
  129. 将文本分割成不超过 max_chunk_size 个字符的块。
  130. 尝试在句号、感叹号或问号处断开,以避免中间断句。
  131. """
  132. import re
  133. sentences = re.split('(?<=[.!?]) +', text)
  134. chunks = []
  135. current_chunk = ""
  136. for sentence in sentences:
  137. if len(current_chunk) + len(sentence) + 1 <= max_chunk_size:
  138. current_chunk += " " + sentence if current_chunk else sentence
  139. else:
  140. if current_chunk:
  141. chunks.append(current_chunk)
  142. if len(sentence) > max_chunk_size:
  143. # 如果单个句子超过最大长度,强制分割
  144. for i in range(0, len(sentence), max_chunk_size):
  145. chunks.append(sentence[i:i + max_chunk_size])
  146. current_chunk = ""
  147. else:
  148. current_chunk = sentence
  149. if current_chunk:
  150. chunks.append(current_chunk)
  151. return chunks
  152. async def generate_tts_audio(chunk: str) -> str:
  153. """
  154. 生成给定文本块的语音音频,并缓存到文件系统中。
  155. 返回音频文件的路径。
  156. """
  157. text_hash = hashlib.md5(chunk.encode('utf-8')).hexdigest()
  158. audio_path = os.path.join(CACHE_DIR, f"{text_hash}.mp3")
  159. if not os.path.exists(audio_path):
  160. try:
  161. with client.audio.speech.with_streaming_response.create(
  162. model="tts-1",
  163. voice="nova",
  164. input=chunk,
  165. ) as response:
  166. response.stream_to_file(audio_path)
  167. except Exception as e:
  168. raise HTTPException(status_code=500, detail=f"TTS生成失败: {str(e)}")
  169. return audio_path
  170. def concatenate_audios(audio_paths: list, output_path: str) -> None:
  171. """
  172. 将多个音频文件按顺序拼接成一个音频文件。
  173. """
  174. combined = AudioSegment.empty()
  175. for path in audio_paths:
  176. audio = AudioSegment.from_mp3(path)
  177. combined += audio
  178. combined.export(output_path, format="mp3")
  179. @app.post("/page-to-speech/")
  180. async def page_to_speech(request: TextToSpeechRequest):
  181. user_input = request.user_input.strip()
  182. if not user_input:
  183. raise HTTPException(status_code=400, detail="输入文本为空。")
  184. # 生成整个文本的hash值作为整体缓存文件名
  185. full_text_hash = hashlib.md5(user_input.encode('utf-8')).hexdigest()
  186. full_audio_path = os.path.join(CACHE_DIR, f"{full_text_hash}_full.mp3")
  187. if os.path.exists(full_audio_path):
  188. # 如果整体缓存存在,直接返回
  189. return StreamingResponse(open(full_audio_path, "rb"), media_type="audio/mpeg")
  190. # 分割文本为多个块
  191. chunks = split_text_into_chunks(user_input)
  192. audio_paths = []
  193. async def audio_generator() -> AsyncGenerator[bytes, None]:
  194. for chunk in chunks:
  195. audio_path = await generate_tts_audio(chunk)
  196. audio_paths.append(audio_path)
  197. with open(audio_path, "rb") as f:
  198. yield f.read()
  199. await asyncio.sleep(0) # 让事件循环有机会处理其它任务
  200. # 异步生成并缓存整体音频
  201. async def create_full_audio():
  202. await asyncio.gather(*(generate_tts_audio(chunk) for chunk in chunks))
  203. concatenate_audios(audio_paths, full_audio_path)
  204. asyncio.create_task(create_full_audio())
  205. return StreamingResponse(audio_generator(), media_type="audio/mpeg")