1#!/usr/bin/env python3
  2
  3import argparse
  4import os
  5import sys
  6import importlib
  7
  8from transformers import AutoTokenizer, AutoConfig, AutoModel
  9import torch
 10
 11# Add parent directory to path for imports
 12sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
 13from utils.common import save_output_data
 14
 15
 16def parse_arguments():
 17    parser = argparse.ArgumentParser(description='Run original embedding model')
 18    parser.add_argument(
 19        '--model-path',
 20        '-m',
 21        help='Path to the model'
 22    )
 23    parser.add_argument(
 24        '--prompts-file',
 25        '-p',
 26        help='Path to file containing prompts (one per line)'
 27    )
 28    parser.add_argument(
 29        '--use-sentence-transformers',
 30        action='store_true',
 31        help=('Use SentenceTransformer to apply all numbered layers '
 32              '(01_Pooling, 02_Dense, 03_Dense, 04_Normalize)')
 33    )
 34    parser.add_argument(
 35        '--device',
 36        '-d',
 37        help='Device to use (cpu, cuda, mps, auto)',
 38        default='auto'
 39    )
 40    return parser.parse_args()
 41
 42
 43def load_model_and_tokenizer(model_path, use_sentence_transformers=False, device="auto"):
 44    if device == "cpu":
 45        device_map = {"": "cpu"}
 46        print("Forcing CPU usage")
 47    elif device == "auto":
 48        # On Mac, "auto" device_map can cause issues with accelerate
 49        # So we detect the best device manually
 50        if torch.cuda.is_available():
 51            device_map = {"": "cuda"}
 52            print("Using CUDA")
 53        elif torch.backends.mps.is_available():
 54            device_map = {"": "mps"}
 55            print("Using MPS (Apple Metal)")
 56        else:
 57            device_map = {"": "cpu"}
 58            print("Using CPU")
 59    else:
 60        device_map = {"": device}
 61
 62    if use_sentence_transformers:
 63        from sentence_transformers import SentenceTransformer
 64        print("Using SentenceTransformer to apply all numbered layers")
 65        model = SentenceTransformer(model_path)
 66        tokenizer = model.tokenizer
 67        config = model[0].auto_model.config  # type: ignore
 68    else:
 69        tokenizer = AutoTokenizer.from_pretrained(model_path)
 70        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
 71
 72        # This can be used to override the sliding window size for manual testing. This
 73        # can be useful to verify the sliding window attention mask in the original model
 74        # and compare it with the converted .gguf model.
 75        if hasattr(config, 'sliding_window'):
 76            original_sliding_window = config.sliding_window
 77            print(f"Modified sliding window: {original_sliding_window} -> {config.sliding_window}")
 78
 79        unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME')
 80        print(f"Using unreleased model: {unreleased_model_name}")
 81        if unreleased_model_name:
 82            model_name_lower = unreleased_model_name.lower()
 83            unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
 84            class_name = f"{unreleased_model_name}Model"
 85            print(f"Importing unreleased model module: {unreleased_module_path}")
 86
 87            try:
 88                model_class = getattr(importlib.import_module(unreleased_module_path), class_name)
 89                model = model_class.from_pretrained(
 90                    model_path,
 91                    device_map=device_map,
 92                    offload_folder="offload",
 93                    trust_remote_code=True,
 94                    config=config
 95                )
 96            except (ImportError, AttributeError) as e:
 97                print(f"Failed to import or load model: {e}")
 98                sys.exit(1)
 99        else:
100            model = AutoModel.from_pretrained(
101                model_path,
102                device_map=device_map,
103                offload_folder="offload",
104                trust_remote_code=True,
105                config=config
106            )
107        print(f"Model class: {type(model)}")
108        print(f"Model file: {type(model).__module__}")
109
110        # Verify the model is using the correct sliding window
111        if hasattr(model.config, 'sliding_window'):  # type: ignore
112            print(f"Model's sliding_window: {model.config.sliding_window}")  # type: ignore
113        else:
114            print("Model config does not have sliding_window attribute")
115
116    return model, tokenizer, config
117
118
def get_prompt(args):
    """Return the prompt text to embed.

    If ``args.prompts_file`` is set, the whole file is read as UTF-8 and
    returned with surrounding whitespace stripped. A missing or unreadable
    file prints an error and exits with status 1. Without a prompts file,
    the built-in prompt "Hello world today" is returned.
    """
    prompts_path = args.prompts_file
    if not prompts_path:
        return "Hello world today"

    try:
        with open(prompts_path, 'r', encoding='utf-8') as handle:
            contents = handle.read()
    except FileNotFoundError:
        print(f"Error: Prompts file '{prompts_path}' not found")
        sys.exit(1)
    except Exception as exc:
        print(f"Error reading prompts file: {exc}")
        sys.exit(1)

    return contents.strip()
132
133
def _tokenize_and_print(tokenizer, texts):
    """Tokenize *texts* as a padded PyTorch batch and print the first sequence's tokens.

    Returns ``(encoded, token_ids)``: the tokenizer output and the first
    sequence's ids as a list of Python ints.
    """
    encoded = tokenizer(
        texts,
        padding=True,
        truncation=True,
        return_tensors="pt"
    )
    token_ids = encoded['input_ids'][0].cpu().tolist()
    for token_id, token_str in zip(token_ids, tokenizer.convert_ids_to_tokens(token_ids)):
        print(f"{token_id:6d} -> '{token_str}'")
    return encoded, token_ids


def _print_embedding_preview(all_embeddings):
    """Print the first and last 3 values of every embedding row.

    A 1-D array is treated as a single embedding and reshaped to one row.
    Returns ``(embeddings_2d, n_embd, n_embd_count)``.
    """
    if len(all_embeddings.shape) == 1:
        all_embeddings = all_embeddings.reshape(1, -1)
    n_embd_count, n_embd = all_embeddings.shape[0], all_embeddings.shape[1]  # type: ignore

    # Clamp so tiny embeddings never fall back to negative (wrap-around) indices.
    head_end = min(3, n_embd)
    tail_start = max(n_embd - 3, 0)

    print()
    for j, embedding in enumerate(all_embeddings):
        print(f"embedding {j}: ", end="")
        for value in embedding[:head_end]:
            print(f"{value:9.6f} ", end="")
        print(" ... ", end="")
        for value in embedding[tail_start:n_embd]:
            print(f"{value:9.6f} ", end="")
        print()
    print()

    return all_embeddings, n_embd, n_embd_count


def main():
    """Run the original model on one prompt and save the resulting embeddings.

    The model path comes from the ``EMBEDDING_MODEL_PATH`` environment variable
    (if non-empty) or ``--model-path``. SentenceTransformer mode is enabled by
    ``--use-sentence-transformers`` or ``USE_SENTENCE_TRANSFORMERS`` set to
    1/true/yes. In that mode the pooled output of ``encode`` is saved;
    otherwise the per-token ``last_hidden_state`` of the first sequence is.
    """
    args = parse_arguments()

    # An empty env var is treated as unset so --model-path can still apply.
    model_path = os.environ.get('EMBEDDING_MODEL_PATH') or args.model_path
    if not model_path:
        print("Error: Model path must be specified either via --model-path argument "
              "or EMBEDDING_MODEL_PATH environment variable")
        sys.exit(1)

    use_st = (
        args.use_sentence_transformers or os.environ.get('USE_SENTENCE_TRANSFORMERS', '').lower() in ('1', 'true', 'yes')
    )

    model, tokenizer, _ = load_model_and_tokenizer(model_path, use_st, args.device)

    # normpath strips a trailing separator; otherwise basename("dir/") would be ''.
    model_name = os.path.basename(os.path.normpath(model_path))

    prompt_text = get_prompt(args)
    texts = [prompt_text]

    with torch.no_grad():
        if use_st:
            all_embeddings = model.encode(texts, convert_to_numpy=True)  # Shape: [batch_size, hidden_size]
            _, token_ids = _tokenize_and_print(tokenizer, texts)

            print(f"Embeddings shape (after all SentenceTransformer layers): {all_embeddings.shape}")
            print(f"Embedding dimension: {all_embeddings.shape[1] if len(all_embeddings.shape) > 1 else all_embeddings.shape[0]}")  # type: ignore
        else:
            # Standard approach: use base model output only.
            encoded, token_ids = _tokenize_and_print(tokenizer, texts)

            # Move inputs to the same device as the model.
            device = next(model.parameters()).device
            encoded = {k: v.to(device) for k, v in encoded.items()}
            outputs = model(**encoded)
            hidden_states = outputs.last_hidden_state  # Shape: [batch_size, seq_len, hidden_size]

            all_embeddings = hidden_states[0].float().cpu().numpy()  # Shape: [seq_len, hidden_size]

            print(f"Hidden states shape: {hidden_states.shape}")
            print(f"All embeddings shape: {all_embeddings.shape}")
            print(f"Embedding dimension: {all_embeddings.shape[1]}")

        all_embeddings, n_embd, n_embd_count = _print_embedding_preview(all_embeddings)

        flattened_embeddings = all_embeddings.flatten()
        print(f"Total values: {len(flattened_embeddings)} ({n_embd_count} embeddings × {n_embd} dimensions)")
        print("")

        save_output_data(flattened_embeddings, token_ids, prompt_text, model_name, type_suffix="-embeddings")
240
241
if __name__ == "__main__":
    # Run the embedding dump only when executed as a script, not when imported.
    main()