| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697 |
- from __future__ import annotations
- import re
- from typing import Dict, Iterable, Optional
- from jinja2 import BaseLoader, Environment, StrictUndefined
- from jinja2.exceptions import TemplateError
- from fastapi import HTTPException
- from ..schemas import (
- TemplatePreviewRequest,
- TemplateRecord,
- TemplateRenderRequest,
- VariableDefinition,
- )
class TemplateRenderer:
    """Render Jinja2 templates with strict undefined variables to catch errors early.

    Templates are rendered in a sandboxed Jinja2 environment because
    ``preview`` renders a caller-supplied template body; a plain Jinja2
    ``Environment`` would let such a template reach Python internals.
    """

    # Accepted spellings for boolean-typed variables (compared case-insensitively).
    _TRUE_WORDS = frozenset({"true", "1", "yes", "y", "on"})
    _FALSE_WORDS = frozenset({"false", "0", "no", "n", "off"})
    # Declared types for which numeric parsing is always attempted.
    _NUMERIC_TYPES = frozenset({"number", "integer"})
    # Declared types whose values are passed through verbatim, so that e.g.
    # "007" keeps its leading zeros instead of becoming the integer 7.
    _STRING_TYPES = frozenset({"string", "str", "text"})
    # Plain decimal numbers such as "42", "-3" or "+1.5" (no exponent).
    _NUMERIC_RE = re.compile(r"[+-]?\d+(\.\d+)?")

    def __init__(self) -> None:
        # Local import keeps the sandbox dependency next to its only use;
        # jinja2.sandbox ships with jinja2, which this module already requires.
        from jinja2.sandbox import SandboxedEnvironment

        self.env = SandboxedEnvironment(
            loader=BaseLoader(),
            undefined=StrictUndefined,
            trim_blocks=True,
            lstrip_blocks=True,
            autoescape=False,
        )

    def render_from_record(self, record: TemplateRecord, parameters: dict) -> str:
        """Render a stored template, coercing parameters by its declared variables.

        Raises:
            HTTPException: 400 if the template cannot be compiled or rendered.
        """
        normalized = self._normalize_parameters(parameters, record.input_variables)
        return self._render(record.template_body, normalized)

    def preview(self, payload: TemplatePreviewRequest) -> str:
        """Render an ad-hoc template body from ``payload``.

        No variable definitions are available here, so parameter types are
        inferred heuristically (numeric-looking strings become numbers).

        Raises:
            HTTPException: 400 if the template cannot be compiled or rendered.
        """
        normalized = self._normalize_parameters(payload.parameters)
        return self._render(payload.template_body, normalized)

    def _render(self, template_body: str, parameters: dict) -> str:
        """Render ``template_body`` with ``parameters`` and collapse blank paragraphs.

        The output is split on blank lines (two consecutive newlines); each
        paragraph is stripped and empty ones are dropped, so sections omitted
        by conditionals do not leave runs of blank lines behind.

        Raises:
            HTTPException: 400 when the template fails to compile or render.
                Besides Jinja2's ``TemplateError`` (undefined variables under
                ``StrictUndefined``, syntax errors, sandbox violations), this
                covers Python errors raised by template expressions such as
                ``{{ 1 / 0 }}`` or ``{{ "a" + 1 }}``. Those are caused by the
                template author, not the server.
        """
        try:
            template = self.env.from_string(template_body)
            rendered = template.render(**parameters)
        except (TemplateError, TypeError, ValueError, ArithmeticError) as exc:
            raise HTTPException(status_code=400, detail=f"模板渲染出错: {exc}") from exc
        paragraphs = (part.strip() for part in rendered.split("\n\n"))
        return "\n\n".join(part for part in paragraphs if part)

    def _normalize_parameters(
        self,
        parameters: Optional[Dict[str, object]],
        definitions: Optional[Iterable[VariableDefinition]] = None,
    ) -> Dict[str, object]:
        """Coerce raw parameter values according to their declared types.

        Args:
            parameters: Raw values keyed by variable name; ``None`` is treated
                as "no parameters".
            definitions: Variable definitions, either ``VariableDefinition``
                models or plain dicts (as loaded from persisted records).
                Entries of any other type, or without a name, are ignored.

        Returns:
            A new dict mapping every key of ``parameters`` to its converted
            value (see ``_convert_value``).
        """
        type_map: Dict[str, str] = {}
        for definition in definitions or []:
            # Accept both plain dicts from persisted records and pydantic models.
            if isinstance(definition, dict):
                raw_name = definition.get("name")
                raw_type = definition.get("data_type")
            elif isinstance(definition, VariableDefinition):
                raw_name = definition.name
                raw_type = definition.data_type
            else:
                continue
            name = str(raw_name or "").strip()
            if name:
                type_map[name] = str(raw_type or "").strip().lower()
        return {
            key: self._convert_value(value, type_map.get(key))
            for key, value in (parameters or {}).items()
        }

    def _convert_value(self, value: object, data_type: Optional[str]) -> object:
        """Convert one raw parameter value according to ``data_type``.

        Rules, in order:

        * Non-string values (including ``None``) are returned unchanged.
        * Blank or whitespace-only strings become ``None``.
        * For boolean types (``data_type`` starting with "bool"), recognised
          true/false words become ``True``/``False``.
        * For string types ("string", "str", "text") the original string is
          returned untouched.
        * For numeric types ("number", "integer"), or when the string looks
          like a plain decimal number, it is parsed as ``int`` when possible
          (including signed values such as "-5"), otherwise as ``float``.
          If parsing fails, the original string is returned.
        * Any other string is returned unchanged (not stripped).
        """
        if not isinstance(value, str):
            return value
        stripped = value.strip()
        if not stripped:
            return None
        kind = (data_type or "").strip().lower()
        if kind.startswith("bool"):
            lowered = stripped.lower()
            if lowered in self._TRUE_WORDS:
                return True
            if lowered in self._FALSE_WORDS:
                return False
        if kind in self._STRING_TYPES:
            return value
        if kind in self._NUMERIC_TYPES or self._looks_numeric(stripped):
            return self._parse_number(stripped, default=value)
        return value

    @staticmethod
    def _parse_number(text: str, default: object) -> object:
        """Parse ``text`` as ``int`` if possible, else ``float``; return ``default`` on failure."""
        try:
            return int(text)
        except ValueError:
            pass
        try:
            return float(text)
        except ValueError:
            return default

    def _looks_numeric(self, value: str) -> bool:
        """Return True if ``value`` is a plain decimal such as "42", "-3" or "+1.5"."""
        return self._NUMERIC_RE.fullmatch(value) is not None
# Shared module-level renderer; created at import time, so the Jinja2
# environment is built once and reused by every caller of this module.
renderer = TemplateRenderer()
|