1#!/usr/bin/env python3
2
3import argparse
4import os
5import importlib
6import torch
7import numpy as np
8
9from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
10from pathlib import Path
11
# Optional name of a not-yet-released model class to import from transformers.
unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME')

arg_parser = argparse.ArgumentParser(description='Process model with specified path')
arg_parser.add_argument('--model-path', '-m', help='Path to the model')
cli_args = arg_parser.parse_args()

# The MODEL_PATH environment variable takes precedence over the CLI flag.
model_path = os.environ.get('MODEL_PATH', cli_args.model_path)
if model_path is None:
    arg_parser.error("Model path must be specified either via --model-path argument or MODEL_PATH environment variable")
21
# Load only the config first so basic metadata can be shown before the
# (potentially slow) weight load below.
config = AutoConfig.from_pretrained(model_path)

# Report basic architecture metadata from the model config.
for label, value in [
    ("Model type: ", config.model_type),
    ("Vocab size: ", config.vocab_size),
    ("Hidden size: ", config.hidden_size),
    ("Number of layers: ", config.num_hidden_layers),
    ("BOS token id: ", config.bos_token_id),
    ("EOS token id: ", config.eos_token_id),
]:
    print(label, value)
30
print("Loading model and tokenizer using AutoTokenizer:", model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Prefer the explicitly-named unreleased model class when requested;
# otherwise (or on failure) fall back to AutoModelForCausalLM.
model = None
if unreleased_model_name:
    lowered = unreleased_model_name.lower()
    unreleased_module_path = f"transformers.models.{lowered}.modular_{lowered}"
    class_name = f"{unreleased_model_name}ForCausalLM"
    print(f"Importing unreleased model module: {unreleased_module_path}")

    try:
        module = importlib.import_module(unreleased_module_path)
        model = getattr(module, class_name).from_pretrained(model_path)
    except (ImportError, AttributeError) as e:
        print(f"Failed to import or load model: {e}")
        print("Falling back to AutoModelForCausalLM")

if model is None:
    model = AutoModelForCausalLM.from_pretrained(model_path)
print(f"Model class: {type(model)}")

# Model name (last path component) is used to name the output files below.
model_name = os.path.basename(model_path)
print(f"Model name: {model_name}")

# Tokenize a fixed prompt and show the resulting ids and token strings.
prompt = "Hello world today"
encoded = tokenizer(prompt, return_tensors="pt")
input_ids = encoded.input_ids
print(f"Input tokens: {input_ids}")
print(f"Input text: {repr(prompt)}")
print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}")
60
# Run a single forward pass and export the last layer's per-token hidden
# states ("embeddings") to data/ as both binary float32 and readable text,
# e.g. for comparison against another implementation of the same model.
with torch.no_grad():
    outputs = model(input_ids, output_hidden_states=True)

    # Extract hidden states from the last layer.
    # outputs.hidden_states is a tuple of (num_layers + 1) tensors;
    # index -1 gets the last layer, shape: [batch_size, seq_len, hidden_size]
    last_hidden_states = outputs.hidden_states[-1]

    # Drop the batch dimension and move to host memory for numpy export.
    token_embeddings = last_hidden_states[0].float().cpu().numpy()

    print(f"Hidden states shape: {last_hidden_states.shape}")
    print(f"Token embeddings shape: {token_embeddings.shape}")
    print(f"Hidden dimension: {token_embeddings.shape[-1]}")
    print(f"Number of tokens: {token_embeddings.shape[0]}")

    # Output locations for the raw token embeddings.
    data_dir = Path("data")
    data_dir.mkdir(exist_ok=True)
    bin_filename = data_dir / f"pytorch-{model_name}-embeddings.bin"
    txt_filename = data_dir / f"pytorch-{model_name}-embeddings.txt"

    # Binary dump: flat float32 values, row-major [token, dim].
    print(token_embeddings)
    token_embeddings.astype(np.float32).tofile(bin_filename)

    # Text dump for inspection: one "<token_idx> <dim_idx> <value>" per line.
    with open(txt_filename, "w") as f:
        for i, embedding in enumerate(token_embeddings):
            for j, val in enumerate(embedding):
                f.write(f"{i} {j} {val:.6f}\n")

    # Print an abbreviated per-token preview of the embeddings.
    print("\nToken embeddings:")
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    for i, embedding in enumerate(token_embeddings):
        if len(embedding) > 10:
            # Show first 3 and last 3 values with ... in between.
            first_vals = " ".join(f"{val:8.6f}" for val in embedding[:3])
            last_vals = " ".join(f"{val:8.6f}" for val in embedding[-3:])
            print(f"embedding {i}: {first_vals} ... {last_vals}")
        else:
            # If the embedding is short, show all values.
            vals = " ".join(f"{val:8.6f}" for val in embedding)
            print(f"embedding {i}: {vals}")

    # Token index -> token string, for cross-referencing the preview above.
    print("\nToken reference:")
    for i, token in enumerate(tokens):
        print(f"  Token {i}: {repr(token)}")

    # Fixed messages: the files contain embeddings, not logits (previous
    # text said "logits" / misspelled "logist").
    print(f"Saved bin embeddings to: {bin_filename}")
    print(f"Saved txt embeddings to: {txt_filename}")