# renderer.py (3.5 KB)
from __future__ import annotations

import re
from typing import Dict, Iterable, Optional

from fastapi import HTTPException
from jinja2 import BaseLoader, Environment, StrictUndefined
from jinja2.exceptions import TemplateError
from jinja2.sandbox import SandboxedEnvironment

from ..schemas import (
    TemplatePreviewRequest,
    TemplateRecord,
    TemplateRenderRequest,
    VariableDefinition,
)
  13. class TemplateRenderer:
  14. """Render templates with strict undefined variables to catch errors early."""
  15. def __init__(self) -> None:
  16. self.env = Environment(
  17. loader=BaseLoader(),
  18. undefined=StrictUndefined,
  19. trim_blocks=True,
  20. lstrip_blocks=True,
  21. autoescape=False,
  22. )
  23. def render_from_record(self, record: TemplateRecord, parameters: dict) -> str:
  24. normalized = self._normalize_parameters(parameters, record.input_variables)
  25. return self._render(record.template_body, normalized)
  26. def preview(self, payload: TemplatePreviewRequest) -> str:
  27. normalized = self._normalize_parameters(payload.parameters)
  28. return self._render(payload.template_body, normalized)
  29. def _render(self, template_body: str, parameters: dict) -> str:
  30. try:
  31. template = self.env.from_string(template_body)
  32. rendered = template.render(**parameters)
  33. paragraphs = [p.strip() for p in rendered.split("\n\n") if p.strip()]
  34. return "\n\n".join(paragraphs)
  35. except TemplateError as exc:
  36. raise HTTPException(status_code=400, detail=f"模板渲染出错: {exc}") from exc
  37. def _normalize_parameters(
  38. self,
  39. parameters: Dict[str, object],
  40. definitions: Optional[Iterable[VariableDefinition]] = None,
  41. ) -> Dict[str, object]:
  42. type_map = {}
  43. for definition in definitions or []:
  44. # Accept both pydantic models and plain dicts from persisted records.
  45. if isinstance(definition, VariableDefinition):
  46. name = (definition.name or "").strip()
  47. data_type = (definition.data_type or "").lower()
  48. elif isinstance(definition, dict):
  49. name = (definition.get("name") or "").strip()
  50. data_type = (definition.get("data_type") or "").lower()
  51. else:
  52. continue
  53. if name:
  54. type_map[name] = data_type
  55. return {
  56. key: self._convert_value(value, type_map.get(key))
  57. for key, value in parameters.items()
  58. }
  59. def _convert_value(self, value: object, data_type: Optional[str]) -> object:
  60. if value is None or isinstance(value, (int, float, bool)):
  61. return value
  62. if isinstance(value, str):
  63. stripped = value.strip()
  64. if stripped == "":
  65. return None
  66. if (data_type or "").startswith("bool"):
  67. lowered = stripped.lower()
  68. if lowered in {"true", "1", "yes", "y", "on"}:
  69. return True
  70. if lowered in {"false", "0", "no", "n", "off"}:
  71. return False
  72. if (data_type or "") in {"number", "integer"} or self._looks_numeric(
  73. stripped
  74. ):
  75. try:
  76. return int(stripped) if stripped.isdigit() else float(stripped)
  77. except ValueError:
  78. return value
  79. return value
  80. def _looks_numeric(self, value: str) -> bool:
  81. return bool(re.fullmatch(r"[+-]?\d+(\.\d+)?", value))
  82. renderer = TemplateRenderer()