1#!/usr/bin/env python3
  2"""
  3# Generated by Claude AI
  4
  5Script to completely regenerate the GGML remoting codebase from YAML configuration.
  6
This script reads ggmlremoting_functions.yaml and regenerates all the header
files for the GGML remoting layer.
  9
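Usage:
  python regenerate_remoting.py

Run it either from the project root (output goes under the configured
base_path, default ggml/src) or from the ggml-virtgpu directory itself. The
YAML file is looked up relative to the current working directory.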
 12
 13The script will:
 141. Read ggmlremoting_functions.yaml configuration
 152. Generate updated header files
3. Format the generated files with clang-format (when it is installed)
 174. Show a summary of what was generated
 18"""
 19
 20import yaml
 21from typing import Dict, List, Any
 22from pathlib import Path
 23import os
 24import subprocess
 25import shutil
 26import logging
 27
 28NL = '\n' # can't have f"{'\n'}" in f-strings
 29
 30
 31class RemotingCodebaseGenerator:
 32    def __init__(self, yaml_path: str = "ggmlremoting_functions.yaml"):
 33        """Initialize the generator with the YAML configuration."""
 34        self.yaml_path = yaml_path
 35
 36        if not Path(yaml_path).exists():
 37            raise FileNotFoundError(f"Configuration file {yaml_path} not found")
 38
 39        with open(yaml_path, 'r') as f:
 40            self.config = yaml.safe_load(f)
 41
 42        self.functions = self.config['functions']
 43        self.naming_patterns = self.config['naming_patterns']
 44        self.config_data = self.config['config']
 45
 46        # Check if clang-format is available
 47        self.clang_format_available = self._check_clang_format_available()
 48
 49    def _check_clang_format_available(self) -> bool:
 50        """Check if clang-format is available in the system PATH."""
 51        return shutil.which("clang-format") is not None
 52
 53    def _format_file_with_clang_format(self, file_path: Path) -> bool:
 54        """Format a file with clang-format -i. Returns True if successful, False otherwise."""
 55        if not self.clang_format_available:
 56            return False
 57
 58        try:
 59            subprocess.run(
 60                ["clang-format", "-i", str(file_path)],
 61                check=True,
 62                capture_output=True,
 63                text=True
 64            )
 65            return True
 66        except subprocess.CalledProcessError:
 67            logging.exception(f"   ⚠️  clang-format failed for {file_path}")
 68            return False
 69        except Exception as e:
 70            logging.exception(f"   ⚠️  Unexpected error formatting {file_path}: {e}")
 71            return False
 72
 73    def generate_enum_name(self, group_name: str, function_name: str) -> str:
 74        """Generate the APIR_COMMAND_TYPE enum name for a function."""
 75        prefix = self.naming_patterns['enum_prefix']
 76        return f"{prefix}{group_name.upper()}_{function_name.upper()}"
 77
 78    def generate_backend_function_name(self, group_name: str, function_name: str) -> str:
 79        """Generate the backend function name."""
 80        function_key = f"{group_name}_{function_name}"
 81        overrides = self.naming_patterns.get('backend_function_overrides', {})
 82
 83        if function_key in overrides:
 84            return overrides[function_key]
 85
 86        prefix = self.naming_patterns['backend_function_prefix']
 87        return f"{prefix}{group_name}_{function_name}"
 88
 89    def generate_frontend_function_name(self, group_name: str, function_name: str) -> str:
 90        """Generate the frontend function name."""
 91        prefix = self.naming_patterns['frontend_function_prefix']
 92        return f"{prefix}{group_name}_{function_name}"
 93
 94    def get_enabled_functions(self) -> List[Dict[str, Any]]:
 95        """Get all enabled functions with their metadata."""
 96        functions = []
 97        enum_value = 0
 98
 99        for group_name, group_data in self.functions.items():
100            group_description = group_data['group_description']
101
102            for function_name, func_metadata in group_data['functions'].items():
103                # Handle case where func_metadata is None or empty (functions with only comments)
104                if func_metadata is None:
105                    func_metadata = {}
106
107                # Functions are enabled by default unless explicitly disabled
108                if func_metadata.get('enabled', True):
109                    functions.append({
110                        'group_name': group_name,
111                        'function_name': function_name,
112                        'enum_name': self.generate_enum_name(group_name, function_name),
113                        'enum_value': enum_value,
114                        'backend_function': self.generate_backend_function_name(group_name, function_name),
115                        'frontend_function': self.generate_frontend_function_name(group_name, function_name),
116                        'frontend_return': func_metadata.get('frontend_return', 'void'),
117                        'frontend_extra_params': func_metadata.get('frontend_extra_params', []),
118                        'group_description': group_description,
119                        'deprecated': func_metadata.get('deprecated', False),
120                    })
121                    enum_value += 1
122
123        return functions
124
125    def generate_apir_backend_header(self) -> str:
126        """Generate the complete apir_backend.h file."""
127        functions = self.get_enabled_functions()
128
129        # Generate the enum section
130        enum_lines = ["typedef enum ApirBackendCommandType {"]
131        current_group = None
132
133        for func in functions:
134            # Add comment for new group
135            if func['group_name'] != current_group:
136                enum_lines.append("")
137                enum_lines.append(f"  /* {func['group_description']} */")
138                current_group = func['group_name']
139
140            enum_lines.append(f"  {func['enum_name']} = {func['enum_value']},")
141
142        # Add the count
143        total_count = len(functions)
144        enum_lines.append("\n  // last command_type index + 1")
145        enum_lines.append(f"  APIR_BACKEND_DISPATCH_TABLE_COUNT = {total_count},")
146        enum_lines.append("} ApirBackendCommandType;")
147
148        # Full header template
149        header_content = NL.join(enum_lines) + "\n"
150
151        return header_content
152
153    def generate_backend_dispatched_header(self) -> str:
154        """Generate the complete backend-dispatched.h file."""
155        functions = self.get_enabled_functions()
156
157        # Function declarations
158        decl_lines = []
159        current_group = None
160
161        for func in functions:
162            if func['group_name'] != current_group:
163                decl_lines.append(f"\n/* {func['group_description']} */")
164                current_group = func['group_name']
165
166            signature = "uint32_t"
167            params = "apir_encoder *enc, apir_decoder *dec, virgl_apir_context *ctx"
168            if func['deprecated']:
169                decl_lines.append(f"/* {func['enum_name']} is deprecated. Keeping the handler for backward compatibility. */")
170
171            decl_lines.append(f"{signature} {func['backend_function']}({params});")
172
173        # Switch cases
174        switch_lines = []
175        current_group = None
176
177        for func in functions:
178            if func['group_name'] != current_group:
179                switch_lines.append(f"  /* {func['group_description']} */")
180                current_group = func['group_name']
181
182            deprecated = " (DEPRECATED)" if func['deprecated'] else ""
183
184            switch_lines.append(f"  case {func['enum_name']}: return \"{func['backend_function']}{deprecated}\";")
185
186        # Dispatch table
187        table_lines = []
188        current_group = None
189
190        for func in functions:
191            if func['group_name'] != current_group:
192                table_lines.append(f"\n  /* {func['group_description']} */")
193                table_lines.append("")
194                current_group = func['group_name']
195
196            deprecated = " /* DEPRECATED */" if func['deprecated'] else ""
197            table_lines.append(f"  /* {func['enum_name']}  = */ {func['backend_function']}{deprecated},")
198
199        header_content = f'''\
200#pragma once
201
202{NL.join(decl_lines)}
203
204static inline const char *backend_dispatch_command_name(ApirBackendCommandType type)
205{{
206  switch (type) {{
207{NL.join(switch_lines)}
208
209  default: return "unknown";
210  }}
211}}
212
213extern "C" {{
214static const backend_dispatch_t apir_backend_dispatch_table[APIR_BACKEND_DISPATCH_TABLE_COUNT] = {{
215  {NL.join(table_lines)}
216}};
217}}
218'''
219        return header_content
220
221    def generate_virtgpu_forward_header(self) -> str:
222        """Generate the complete virtgpu-forward.gen.h file."""
223        functions = self.get_enabled_functions()
224
225        decl_lines = []
226        current_group = None
227
228        for func in functions:
229            if func['group_name'] != current_group:
230                decl_lines.append("")
231                decl_lines.append(f"/* {func['group_description']} */")
232                current_group = func['group_name']
233
234            if func['deprecated']:
235                decl_lines.append(f"/* {func['frontend_function']} is deprecated. */")
236                continue
237
238            # Build parameter list
239            params = [self.naming_patterns['frontend_base_param']]
240            params.extend(func['frontend_extra_params'])
241            param_str = ', '.join(params)
242
243            decl_lines.append(f"{func['frontend_return']} {func['frontend_function']}({param_str});")
244
245        header_content = f'''\
246#pragma once
247{NL.join(decl_lines)}
248'''
249        return header_content
250
251    def regenerate_codebase(self) -> None:
252        """Regenerate the entire remoting codebase."""
253        logging.info("🔄 Regenerating GGML Remoting Codebase...")
254        logging.info("=" * 50)
255
256        # Detect if we're running from frontend directory
257        current_dir = os.getcwd()
258        is_frontend_dir = current_dir.endswith('ggml-virtgpu')
259
260        if is_frontend_dir:
261            # Running from ggml/src/ggml-virtgpu-apir
262            logging.info("📍 Detected frontend directory execution")
263            frontend_base = Path(".")
264        else:
265            # Running from project root (fallback to original behavior)
266            logging.info("📍 Detected project root execution")
267            base_path = self.config_data.get('base_path', 'ggml/src')
268            frontend_base = Path(base_path) / "ggml-virtgpu"
269
270        # Compute final file paths
271        backend_base = frontend_base / "backend"
272        apir_backend_path = backend_base / "shared" / "apir_backend.gen.h"
273        backend_dispatched_path = backend_base / "backend-dispatched.gen.h"
274        virtgpu_forward_path = frontend_base / "virtgpu-forward.gen.h"
275
276        # Create output directories for each file
277        apir_backend_path.parent.mkdir(parents=True, exist_ok=True)
278        backend_dispatched_path.parent.mkdir(parents=True, exist_ok=True)
279        virtgpu_forward_path.parent.mkdir(parents=True, exist_ok=True)
280
281        # Generate header files
282        logging.info("📁 Generating header files...")
283
284        apir_backend_content = self.generate_apir_backend_header()
285        apir_backend_path.write_text(apir_backend_content)
286        logging.info(f"   ✅ {apir_backend_path.resolve()}")
287
288        backend_dispatched_content = self.generate_backend_dispatched_header()
289        backend_dispatched_path.write_text(backend_dispatched_content)
290        logging.info(f"   ✅ {backend_dispatched_path.resolve()}")
291
292        virtgpu_forward_content = self.generate_virtgpu_forward_header()
293        virtgpu_forward_path.write_text(virtgpu_forward_content)
294        logging.info(f"   ✅ {virtgpu_forward_path.resolve()}")
295
296        # Format generated files with clang-format
297        generated_files = [apir_backend_path, backend_dispatched_path, virtgpu_forward_path]
298
299        if not self.clang_format_available:
300            logging.warning("\n⚠️clang-format not found in PATH. Generated files will not be formatted.\n"
301                            "   Install clang-format to enable automatic code formatting.")
302        else:
303            logging.info("\n🎨 Formatting files with clang-format...")
304            for file_path in generated_files:
305                if self._format_file_with_clang_format(file_path):
306                    logging.info(f"   ✅ Formatted {file_path.name}")
307                else:
308                    logging.warning(f"   ❌ Failed to format {file_path.name}")
309
310        # Generate summary
311        functions = self.get_enabled_functions()
312        total_functions = len(functions)
313
314        logging.info("\n📊 Generation Summary:")
315        logging.info("=" * 50)
316        logging.info(f"   Total functions: {total_functions}")
317        logging.info(f"   Function groups: {len(self.functions)}")
318        logging.info("   Header files: 3")
319        logging.info(f"   Working directory: {current_dir}")
320
321
def main():
    """Command-line entry point: regenerate the headers, exit with status 1 on error."""
    # The root logger defaults to WARNING with no handler, which would drop
    # every logging.info() call, including all progress output and the summary.
    # basicConfig() does nothing if the root logger already has handlers.
    logging.basicConfig(level=logging.INFO, format="%(message)s")
    try:
        generator = RemotingCodebaseGenerator()
        generator.regenerate_codebase()
    except Exception as e:
        logging.exception(f"❌ Error: {e}")
        # Same effect as sys.exit(1) without depending on the site-provided
        # exit() builtin, which is absent under `python -S`.
        raise SystemExit(1)
329
330
# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()