1#!/usr/bin/env python3
  2
  3import argparse
  4import os
  5import importlib
  6import torch
  7import numpy as np
  8
  9from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
 10from pathlib import Path
 11
 12unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME')
 13
 14parser = argparse.ArgumentParser(description='Process model with specified path')
 15parser.add_argument('--model-path', '-m', help='Path to the model')
 16args = parser.parse_args()
 17
 18model_path = os.environ.get('MODEL_PATH', args.model_path)
 19if model_path is None:
 20    parser.error("Model path must be specified either via --model-path argument or MODEL_PATH environment variable")
 21
 22config = AutoConfig.from_pretrained(model_path)
 23
 24print("Model type:       ", config.model_type)
 25print("Vocab size:       ", config.vocab_size)
 26print("Hidden size:      ", config.hidden_size)
 27print("Number of layers: ", config.num_hidden_layers)
 28print("BOS token id:     ", config.bos_token_id)
 29print("EOS token id:     ", config.eos_token_id)
 30
 31print("Loading model and tokenizer using AutoTokenizer:", model_path)
 32tokenizer = AutoTokenizer.from_pretrained(model_path)
 33
 34if unreleased_model_name:
 35    model_name_lower = unreleased_model_name.lower()
 36    unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
 37    class_name = f"{unreleased_model_name}ForCausalLM"
 38    print(f"Importing unreleased model module: {unreleased_module_path}")
 39
 40    try:
 41        model_class = getattr(importlib.import_module(unreleased_module_path), class_name)
 42        model = model_class.from_pretrained(model_path)
 43    except (ImportError, AttributeError) as e:
 44        print(f"Failed to import or load model: {e}")
 45        print("Falling back to AutoModelForCausalLM")
 46        model = AutoModelForCausalLM.from_pretrained(model_path)
 47else:
 48    model = AutoModelForCausalLM.from_pretrained(model_path)
 49print(f"Model class: {type(model)}")
 50#print(f"Model file: {type(model).__module__}")
 51
 52model_name = os.path.basename(model_path)
 53print(f"Model name: {model_name}")
 54
 55prompt = "Hello world today"
 56input_ids = tokenizer(prompt, return_tensors="pt").input_ids
 57print(f"Input tokens: {input_ids}")
 58print(f"Input text: {repr(prompt)}")
 59print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}")
 60
 61with torch.no_grad():
 62    outputs = model(input_ids, output_hidden_states=True)
 63
 64    # Extract hidden states from the last layer
 65    # outputs.hidden_states is a tuple of (num_layers + 1) tensors
 66    # Index -1 gets the last layer, shape: [batch_size, seq_len, hidden_size]
 67    last_hidden_states = outputs.hidden_states[-1]
 68
 69    # Get embeddings for all tokens
 70    token_embeddings = last_hidden_states[0].float().cpu().numpy()  # Remove batch dimension
 71
 72    print(f"Hidden states shape: {last_hidden_states.shape}")
 73    print(f"Token embeddings shape: {token_embeddings.shape}")
 74    print(f"Hidden dimension: {token_embeddings.shape[-1]}")
 75    print(f"Number of tokens: {token_embeddings.shape[0]}")
 76
 77    # Save raw token embeddings
 78    data_dir = Path("data")
 79    data_dir.mkdir(exist_ok=True)
 80    bin_filename = data_dir / f"pytorch-{model_name}-embeddings.bin"
 81    txt_filename = data_dir / f"pytorch-{model_name}-embeddings.txt"
 82
 83    # Save all token embeddings as binary
 84    print(token_embeddings)
 85    token_embeddings.astype(np.float32).tofile(bin_filename)
 86
 87    # Save as text for inspection
 88    with open(txt_filename, "w") as f:
 89        for i, embedding in enumerate(token_embeddings):
 90            for j, val in enumerate(embedding):
 91                f.write(f"{i} {j} {val:.6f}\n")
 92
 93    # Print embeddings per token in the requested format
 94    print("\nToken embeddings:")
 95    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
 96    for i, embedding in enumerate(token_embeddings):
 97        # Format: show first few values, ..., then last few values
 98        if len(embedding) > 10:
 99            # Show first 3 and last 3 values with ... in between
100            first_vals = " ".join(f"{val:8.6f}" for val in embedding[:3])
101            last_vals = " ".join(f"{val:8.6f}" for val in embedding[-3:])
102            print(f"embedding {i}: {first_vals}  ... {last_vals}")
103        else:
104            # If embedding is short, show all values
105            vals = " ".join(f"{val:8.6f}" for val in embedding)
106            print(f"embedding {i}: {vals}")
107
108    # Also show token info for reference
109    print(f"\nToken reference:")
110    for i, token in enumerate(tokens):
111        print(f"  Token {i}: {repr(token)}")
112
113    print(f"Saved bin logits to: {bin_filename}")
114    print(f"Saved txt logist to: {txt_filename}")