1#!/usr/bin/env python3
  2import sys
  3import json
  4import argparse
  5import jinja2.ext as jinja2_ext
  6from PySide6.QtWidgets import (
  7    QApplication,
  8    QMainWindow,
  9    QWidget,
 10    QVBoxLayout,
 11    QHBoxLayout,
 12    QLabel,
 13    QPlainTextEdit,
 14    QTextEdit,
 15    QPushButton,
 16    QFileDialog,
 17)
 18from PySide6.QtGui import QColor, QColorConstants, QTextCursor, QTextFormat
 19from PySide6.QtCore import Qt, QRect, QSize
 20from jinja2 import TemplateSyntaxError
 21from jinja2.sandbox import ImmutableSandboxedEnvironment
 22from datetime import datetime
 23
 24
 25def format_template_content(template_content):
 26    """Format the Jinja template content using Jinja2's lexer."""
 27    if not template_content.strip():
 28        return template_content
 29
 30    env = ImmutableSandboxedEnvironment()
 31    tc_rstrip = template_content.rstrip()
 32    tokens = list(env.lex(tc_rstrip))
 33    result = ""
 34    indent_level = 0
 35    i = 0
 36
 37    while i < len(tokens):
 38        token = tokens[i]
 39        _, token_type, token_value = token
 40
 41        if token_type == "block_begin":
 42            block_start = i
 43            # Collect all tokens for this block construct
 44            construct_content = token_value
 45            end_token_type = token_type.replace("_begin", "_end")
 46            j = i + 1
 47            while j < len(tokens) and tokens[j][1] != end_token_type:
 48                construct_content += tokens[j][2]
 49                j += 1
 50
 51            if j < len(tokens):  # Found the end token
 52                construct_content += tokens[j][2]
 53                i = j  # Skip to the end token
 54
 55                # Check for control structure keywords for indentation
 56                stripped_content = construct_content.strip()
 57                instr = block_start + 1
 58                while tokens[instr][1] == "whitespace":
 59                    instr = instr + 1
 60
 61                instruction_token = tokens[instr][2]
 62                start_control_tokens = ["if", "for", "macro", "call", "block"]
 63                end_control_tokens = ["end" + t for t in start_control_tokens]
 64                is_control_start = any(
 65                    instruction_token.startswith(kw) for kw in start_control_tokens
 66                )
 67                is_control_end = any(
 68                    instruction_token.startswith(kw) for kw in end_control_tokens
 69                )
 70
 71                # Adjust indentation for control structures
 72                # For control end blocks, decrease indent BEFORE adding the content
 73                if is_control_end:
 74                    indent_level = max(0, indent_level - 1)
 75
 76                # Remove all previous whitespace before this block
 77                result = result.rstrip()
 78
 79                # Add proper indent, but only if this is not the first token
 80                added_newline = False
 81                if result:  # Only add newline and indent if there's already content
 82                    result += (
 83                        "\n" + "  " * indent_level
 84                    )  # Use 2 spaces per indent level
 85                    added_newline = True
 86                else:  # For the first token, don't add any indent
 87                    result += ""
 88
 89                # Add the block content
 90                result += stripped_content
 91
 92                # Add '-' after '%' if it wasn't there and we added a newline or indent
 93                if (
 94                    added_newline
 95                    and stripped_content.startswith("{%")
 96                    and not stripped_content.startswith("{%-")
 97                ):
 98                    # Add '-' at the beginning
 99                    result = (
100                        result[: result.rfind("{%")]
101                        + "{%-"
102                        + result[result.rfind("{%") + 2 :]
103                    )
104                if stripped_content.endswith("%}") and not stripped_content.endswith(
105                    "-%}"
106                ):
107                    # Only add '-' if this is not the last token or if there's content after
108                    if i + 1 < len(tokens) and tokens[i + 1][1] != "eof":
109                        result = result[:-2] + "-%}"
110
111                # For control start blocks, increase indent AFTER adding the content
112                if is_control_start:
113                    indent_level += 1
114            else:
115                # Malformed template, just add the token
116                result += token_value
117        elif token_type == "variable_begin":
118            # Collect all tokens for this variable construct
119            construct_content = token_value
120            end_token_type = token_type.replace("_begin", "_end")
121            j = i + 1
122            while j < len(tokens) and tokens[j][1] != end_token_type:
123                construct_content += tokens[j][2]
124                j += 1
125
126            if j < len(tokens):  # Found the end token
127                construct_content += tokens[j][2]
128                i = j  # Skip to the end token
129
130                # For variable constructs, leave them alone
131                # Do not add indent or whitespace before or after them
132                result += construct_content
133            else:
134                # Malformed template, just add the token
135                result += token_value
136        elif token_type == "data":
137            # Handle data (text between Jinja constructs)
138            # For data content, preserve it as is
139            result += token_value
140        else:
141            # Handle any other tokens
142            result += token_value
143
144        i += 1
145
146    # Clean up trailing newlines and spaces
147    result = result.rstrip()
148
149    # Copy the newline / space count from the original
150    if (trailing_length := len(template_content) - len(tc_rstrip)):
151        result += template_content[-trailing_length:]
152
153    return result
154
155
156# ------------------------
157# Line Number Widget
158# ------------------------
class LineNumberArea(QWidget):
    """Left-hand gutter of a ``CodeEditor`` that shows line numbers.

    The widget holds no layout logic of its own: both its preferred width and
    its painting are delegated to the editor passed to the constructor.
    """

    def __init__(self, editor):
        # The editor is also the Qt parent of the gutter.
        super().__init__(editor)
        self.code_editor = editor

    def sizeHint(self):
        """Preferred size: the editor-computed gutter width, zero height."""
        gutter_width = self.code_editor.line_number_area_width()
        return QSize(gutter_width, 0)

    def paintEvent(self, event):
        """Let the owning editor draw the numbers for the exposed region."""
        editor = self.code_editor
        editor.line_number_area_paint_event(event)
169
170
class CodeEditor(QPlainTextEdit):
    """Plain-text editor with a line-number gutter and line highlighting.

    The gutter shows 1-based *block* numbers (one block per newline-separated
    line). The highlight helpers address lines the same way, so they stay
    correct when long lines are soft-wrapped onto several visual lines.
    """

    def __init__(self):
        super().__init__()
        self.line_number_area = LineNumberArea(self)

        # Keep the gutter's width and contents in sync with the editor.
        self.blockCountChanged.connect(self.update_line_number_area_width)
        self.updateRequest.connect(self.update_line_number_area)
        self.cursorPositionChanged.connect(self.highlight_current_line)

        self.update_line_number_area_width(0)
        self.highlight_current_line()

    def line_number_area_width(self):
        """Return the gutter width in pixels for the largest line number."""
        digits = len(str(self.blockCount()))
        return 3 + self.fontMetrics().horizontalAdvance("9") * digits

    def update_line_number_area_width(self, _):
        """Reserve space on the left of the viewport for the gutter."""
        self.setViewportMargins(self.line_number_area_width(), 0, 0, 0)

    def update_line_number_area(self, rect, dy):
        """Scroll or repaint the gutter in response to ``updateRequest``."""
        if dy:
            self.line_number_area.scroll(0, dy)
        else:
            self.line_number_area.update(
                0, rect.y(), self.line_number_area.width(), rect.height()
            )

        if rect.contains(self.viewport().rect()):
            self.update_line_number_area_width(0)

    def resizeEvent(self, event):
        """Keep the gutter aligned with the editor's contents rectangle."""
        super().resizeEvent(event)
        cr = self.contentsRect()
        self.line_number_area.setGeometry(
            QRect(cr.left(), cr.top(), self.line_number_area_width(), cr.height())
        )

    def line_number_area_paint_event(self, event):
        """Paint 1-based block numbers for the blocks inside ``event.rect()``."""
        from PySide6.QtGui import QPainter

        painter = QPainter(self.line_number_area)
        try:
            painter.fillRect(event.rect(), QColorConstants.LightGray)
            painter.setPen(QColorConstants.Black)

            block = self.firstVisibleBlock()
            block_number = block.blockNumber()
            top = int(
                self.blockBoundingGeometry(block).translated(self.contentOffset()).top()
            )
            bottom = top + int(self.blockBoundingRect(block).height())
            number_width = self.line_number_area.width() - 2
            line_height = self.fontMetrics().height()

            while block.isValid() and top <= event.rect().bottom():
                if block.isVisible() and bottom >= event.rect().top():
                    painter.drawText(
                        0,
                        top,
                        number_width,
                        line_height,
                        Qt.AlignmentFlag.AlignRight,
                        str(block_number + 1),
                    )
                block = block.next()
                top = bottom
                bottom = top + int(self.blockBoundingRect(block).height())
                block_number += 1
        finally:
            # End painting explicitly instead of relying on garbage
            # collection of the painter (e.g. when an exception keeps the
            # frame alive).
            painter.end()

    def highlight_current_line(self):
        """Reset extra selections to just the current-line highlight."""
        extra_selections = []
        if not self.isReadOnly():
            selection = QTextEdit.ExtraSelection()
            line_color = QColorConstants.Yellow.lighter(160)
            selection.format.setBackground(line_color)  # pyright: ignore[reportAttributeAccessIssue]
            selection.format.setProperty(QTextFormat.Property.FullWidthSelection, True)  # pyright: ignore[reportAttributeAccessIssue]
            selection.cursor = self.textCursor()  # pyright: ignore[reportAttributeAccessIssue]
            selection.cursor.clearSelection()  # pyright: ignore[reportAttributeAccessIssue]
            extra_selections.append(selection)
        self.setExtraSelections(extra_selections)

    def _add_extra_selection(self, cursor: QTextCursor, color: QColor):
        """Append a background highlight (``color`` lightened) for ``cursor``."""
        extra = QTextEdit.ExtraSelection()
        extra.format.setBackground(color.lighter(160))  # pyright: ignore[reportAttributeAccessIssue]
        extra.cursor = cursor  # pyright: ignore[reportAttributeAccessIssue]
        self.setExtraSelections(self.extraSelections() + [extra])

    def highlight_position(self, lineno: int, col: int, color: QColor):
        """Highlight one character at 1-based ``lineno`` / ``col``.

        ``col`` values below 1 are treated as column 1. Columns past the end
        of the line are clamped to the end of that line (nothing visible is
        highlighted) instead of spilling into the following line. Unknown
        line numbers are ignored.
        """
        # Block numbers match the gutter; findBlockByLineNumber would count
        # soft-wrapped visual lines instead.
        block = self.document().findBlockByNumber(lineno - 1)
        if not block.isValid():
            return

        # block.length() includes the trailing paragraph separator.
        text_length = block.length() - 1
        offset = min(max(0, col - 1), text_length)
        cursor = QTextCursor(block)
        cursor.setPosition(block.position() + offset)
        if offset < text_length:
            cursor.movePosition(
                QTextCursor.MoveOperation.NextCharacter,
                QTextCursor.MoveMode.KeepAnchor,
            )
        self._add_extra_selection(cursor, color)

    def highlight_line(self, lineno: int, color: QColor):
        """Highlight the whole text of 1-based line ``lineno``.

        All soft-wrapped parts of the line are included. Unknown line numbers
        are ignored.
        """
        block = self.document().findBlockByNumber(lineno - 1)
        if not block.isValid():
            return

        # QTextCursor(block) starts at the beginning of the block; extend the
        # selection to its end so wrapped continuation lines are covered too.
        cursor = QTextCursor(block)
        cursor.movePosition(
            QTextCursor.MoveOperation.EndOfBlock,
            QTextCursor.MoveMode.KeepAnchor,
        )
        self._add_extra_selection(cursor, color)

    def clear_highlighting(self):
        """Drop error highlights, keeping only the current-line highlight."""
        self.highlight_current_line()
284
285
286# ------------------------
287# Main App
288# ------------------------
class JinjaTester(QMainWindow):
    """Main window: edit a template and a JSON context, then format or render."""

    def __init__(self):
        super().__init__()
        self.setWindowTitle("Jinja Template Tester")
        self.resize(1200, 800)

        central = QWidget()
        main_layout = QVBoxLayout(central)

        # -------- Top input area --------
        input_layout = QHBoxLayout()

        # Template editor with label
        template_layout = QVBoxLayout()
        template_label = QLabel("Jinja2 Template")
        template_layout.addWidget(template_label)
        self.template_edit = CodeEditor()
        template_layout.addWidget(self.template_edit)
        input_layout.addLayout(template_layout)

        # JSON editor with label
        json_layout = QVBoxLayout()
        json_label = QLabel("Context (JSON)")
        json_layout.addWidget(json_label)
        self.json_edit = CodeEditor()
        self.json_edit.setPlainText("""
{
    "add_generation_prompt": true,
    "bos_token": "",
    "eos_token": "",
    "messages": [
        {
            "role": "user",
            "content": "What is the capital of Poland?"
        }
    ]
}
        """.strip())
        json_layout.addWidget(self.json_edit)
        input_layout.addLayout(json_layout)

        main_layout.addLayout(input_layout)

        # -------- Rendered output area --------
        output_label = QLabel("Rendered Output")
        main_layout.addWidget(output_label)
        self.output_edit = QPlainTextEdit()
        self.output_edit.setReadOnly(True)
        main_layout.addWidget(self.output_edit)

        # -------- Render button and status --------
        btn_layout = QHBoxLayout()

        # Load template button
        self.load_btn = QPushButton("Load Template")
        self.load_btn.clicked.connect(self.load_template)
        btn_layout.addWidget(self.load_btn)

        # Format template button
        self.format_btn = QPushButton("Format")
        self.format_btn.clicked.connect(self.format_template)
        btn_layout.addWidget(self.format_btn)

        self.render_btn = QPushButton("Render")
        self.render_btn.clicked.connect(self.render_template)
        btn_layout.addWidget(self.render_btn)
        main_layout.addLayout(btn_layout)

        # Status label below buttons
        self.status_label = QLabel("Ready")
        main_layout.addWidget(self.status_label)

        self.setCentralWidget(central)

    @staticmethod
    def _build_environment():
        """Create the sandboxed environment used for rendering.

        It enables ``trim_blocks`` / ``lstrip_blocks`` and the ``loopcontrols``
        extension. It registers a ``tojson`` filter backed by ``json.dumps``
        (``ensure_ascii`` defaults to False) and the ``strftime_now`` and
        ``raise_exception`` globals.
        """

        def raise_exception(text: str) -> str:
            raise RuntimeError(text)

        def tojson(
            x, indent=None, separators=None, sort_keys=False, ensure_ascii=False
        ):
            return json.dumps(
                x,
                indent=indent,
                separators=separators,
                sort_keys=sort_keys,
                ensure_ascii=ensure_ascii,
            )

        env = ImmutableSandboxedEnvironment(
            trim_blocks=True,
            lstrip_blocks=True,
            extensions=[jinja2_ext.loopcontrols],
        )
        env.filters["tojson"] = tojson
        # The parameter keeps the name ``format`` so templates can pass it by
        # keyword.
        env.globals["strftime_now"] = lambda format: datetime.now().strftime(format)
        env.globals["raise_exception"] = raise_exception
        return env

    @staticmethod
    def _template_lineno(exc):
        """Return the template line of the innermost ``<template>`` frame.

        Returns None when no frame in ``exc``'s traceback comes from the
        template.
        """
        lineno = None
        tb = exc.__traceback__
        while tb is not None:
            if tb.tb_frame.f_code.co_filename == "<template>":
                # Keep walking: the innermost template frame (e.g. inside a
                # macro) is where the failing expression actually is.
                lineno = tb.tb_lineno
            tb = tb.tb_next
        return lineno

    def render_template(self):
        """Render the template with the JSON context and show result or error."""
        self.template_edit.clear_highlighting()
        self.output_edit.clear()

        template_str = self.template_edit.toPlainText()
        json_str = self.json_edit.toPlainText()

        # Parse JSON context; an empty editor means an empty context.
        try:
            context = json.loads(json_str) if json_str.strip() else {}
        except (ValueError, RecursionError) as e:  # JSONDecodeError is a ValueError
            self.status_label.setText(f"❌ JSON Error: {e}")
            return
        if not isinstance(context, dict):
            # render() needs a mapping of variable names to values.
            self.status_label.setText(
                "❌ JSON Error: context must be a JSON object, "
                f"got {type(context).__name__}"
            )
            return

        env = self._build_environment()
        try:
            template = env.from_string(template_str)
            output = template.render(context)
        except TemplateSyntaxError as e:
            self.status_label.setText(f"❌ Syntax Error (line {e.lineno}): {e.message}")
            if e.lineno:
                self.template_edit.highlight_line(e.lineno, QColor("red"))
        except Exception as e:
            # UI boundary: report any runtime error instead of letting it
            # escape the slot, and point at the template line if possible.
            lineno = self._template_lineno(e)
            error_msg = f"Runtime Error: {type(e).__name__}: {e}"
            if lineno:
                error_msg = f"Runtime Error at line {lineno} in template: {type(e).__name__}: {e}"
                self.template_edit.highlight_line(lineno, QColor("orange"))

            self.output_edit.setPlainText(error_msg)
            self.status_label.setText(f"❌ {error_msg}")
        else:
            self.output_edit.setPlainText(output)
            self.status_label.setText("✅ Render successful")

    def load_template(self):
        """Load a Jinja template from a file using a file dialog."""
        file_path, _ = QFileDialog.getOpenFileName(
            self,
            "Load Jinja Template",
            "",
            "Template Files (*.jinja *.j2 *.html *.txt);;All Files (*)",
        )
        if not file_path:
            return  # dialog cancelled

        try:
            with open(file_path, "r", encoding="utf-8") as file:
                content = file.read()
        except (OSError, UnicodeDecodeError) as e:
            self.status_label.setText(f"❌ Error loading file: {str(e)}")
            return

        self.template_edit.setPlainText(content)
        self.status_label.setText(f"✅ Loaded template from {file_path}")

    def format_template(self):
        """Reformat the template in place; the edit can be undone."""
        self.template_edit.clear_highlighting()
        template_content = self.template_edit.toPlainText()
        if not template_content.strip():
            self.status_label.setText("⚠️ Template is empty")
            return

        try:
            formatted_content = format_template_content(template_content)
        except TemplateSyntaxError as e:
            self.status_label.setText(
                f"❌ Error formatting template (line {e.lineno}): {e.message}"
            )
            if e.lineno:
                self.template_edit.highlight_line(e.lineno, QColor("red"))
            return
        except Exception as e:
            # UI boundary: report instead of letting the error escape the slot.
            self.status_label.setText(f"❌ Error formatting template: {str(e)}")
            return

        # Replace the whole document through a cursor: setPlainText() would
        # clear the undo/redo history, making the reformat irreversible.
        cursor = self.template_edit.textCursor()
        cursor.select(QTextCursor.SelectionType.Document)
        cursor.insertText(formatted_content)
        self.status_label.setText("✅ Template formatted")
460
461
if __name__ == "__main__":
    if len(sys.argv) > 1:
        # CLI mode: format or render a template file without starting Qt.
        parser = argparse.ArgumentParser(description="Jinja Template Tester")
        parser.add_argument(
            "--template", required=True, help="Path to Jinja template file"
        )
        parser.add_argument("--context", required=True, help="JSON string for context")
        parser.add_argument(
            "--action",
            choices=["format", "render"],
            default="render",
            help="Action to perform",
        )
        args = parser.parse_args()

        # Load template
        with open(args.template, "r", encoding="utf-8") as f:
            template_content = f.read()

        # Load JSON; report bad input as a usage error instead of a traceback.
        try:
            context = json.loads(args.context)
        except ValueError as e:
            parser.error(f"--context is not valid JSON: {e}")
        if not isinstance(context, dict):
            parser.error("--context must be a JSON object")
        # Add missing variables
        context.setdefault("bos_token", "")
        context.setdefault("eos_token", "")
        context.setdefault("add_generation_prompt", False)

        if args.action == "format":
            formatted = format_template_content(template_content)
            print(formatted)  # noqa: NP100
        elif args.action == "render":
            # Same configuration as the GUI renderer, so CLI and GUI output
            # agree: trim/lstrip blocks, loop controls, and the tojson /
            # strftime_now / raise_exception helpers.
            def raise_exception(text: str) -> str:
                raise RuntimeError(text)

            def tojson(
                x, indent=None, separators=None, sort_keys=False, ensure_ascii=False
            ):
                return json.dumps(
                    x,
                    indent=indent,
                    separators=separators,
                    sort_keys=sort_keys,
                    ensure_ascii=ensure_ascii,
                )

            env = ImmutableSandboxedEnvironment(
                trim_blocks=True,
                lstrip_blocks=True,
                extensions=[jinja2_ext.loopcontrols],
            )
            env.filters["tojson"] = tojson
            env.globals["strftime_now"] = lambda format: datetime.now().strftime(format)
            env.globals["raise_exception"] = raise_exception

            template = env.from_string(template_content)
            output = template.render(context)
            print(output)  # noqa: NP100

    else:
        # GUI mode
        app = QApplication(sys.argv)
        window = JinjaTester()
        window.show()
        sys.exit(app.exec())