1#!/usr/bin/env python3
  2
  3import argparse
  4import json
  5import os
  6import re
  7import sys
  8from pathlib import Path
  9from typing import Optional
 10from safetensors import safe_open
 11
 12
 13MODEL_SAFETENSORS_FILE = "model.safetensors"
 14MODEL_SAFETENSORS_INDEX = "model.safetensors.index.json"
 15
 16
 17def get_weight_map(model_path: Path) -> Optional[dict[str, str]]:
 18    index_file = model_path / MODEL_SAFETENSORS_INDEX
 19
 20    if index_file.exists():
 21        with open(index_file, 'r') as f:
 22            index = json.load(f)
 23            return index.get("weight_map", {})
 24
 25    return None
 26
 27
 28def get_all_tensor_names(model_path: Path) -> list[str]:
 29    weight_map = get_weight_map(model_path)
 30
 31    if weight_map is not None:
 32        return list(weight_map.keys())
 33
 34    single_file = model_path / MODEL_SAFETENSORS_FILE
 35    if single_file.exists():
 36        try:
 37            with safe_open(single_file, framework="pt", device="cpu") as f:
 38                return list(f.keys())
 39        except Exception as e:
 40            print(f"Error reading {single_file}: {e}")
 41            sys.exit(1)
 42
 43    print(f"Error: No safetensors files found in {model_path}")
 44    sys.exit(1)
 45
 46
 47def find_tensor_file(model_path: Path, tensor_name: str) -> Optional[str]:
 48    weight_map = get_weight_map(model_path)
 49
 50    if weight_map is not None:
 51        return weight_map.get(tensor_name)
 52
 53    single_file = model_path / MODEL_SAFETENSORS_FILE
 54    if single_file.exists():
 55        return single_file.name
 56
 57    return None
 58
 59
 60def normalize_tensor_name(tensor_name: str) -> str:
 61    normalized = re.sub(r'\.\d+\.', '.#.', tensor_name)
 62    normalized = re.sub(r'\.\d+$', '.#', normalized)
 63    return normalized
 64
 65
 66def list_all_tensors(model_path: Path, unique: bool = False):
 67    tensor_names = get_all_tensor_names(model_path)
 68
 69    if unique:
 70        seen = set()
 71        for tensor_name in sorted(tensor_names):
 72            normalized = normalize_tensor_name(tensor_name)
 73            if normalized not in seen:
 74                seen.add(normalized)
 75                print(normalized)
 76    else:
 77        for tensor_name in sorted(tensor_names):
 78            print(tensor_name)
 79
 80
 81def print_tensor_info(model_path: Path, tensor_name: str):
 82    tensor_file = find_tensor_file(model_path, tensor_name)
 83
 84    if tensor_file is None:
 85        print(f"Error: Could not find tensor '{tensor_name}' in model index")
 86        print(f"Model path: {model_path}")
 87        sys.exit(1)
 88
 89    file_path = model_path / tensor_file
 90
 91    try:
 92        with safe_open(file_path, framework="pt", device="cpu") as f:
 93            if tensor_name in f.keys():
 94                tensor_slice = f.get_slice(tensor_name)
 95                shape = tensor_slice.get_shape()
 96                print(f"Tensor: {tensor_name}")
 97                print(f"File:   {tensor_file}")
 98                print(f"Shape:  {shape}")
 99            else:
100                print(f"Error: Tensor '{tensor_name}' not found in {tensor_file}")
101                sys.exit(1)
102
103    except FileNotFoundError:
104        print(f"Error: The file '{file_path}' was not found.")
105        sys.exit(1)
106    except Exception as e:
107        print(f"An error occurred: {e}")
108        sys.exit(1)
109
110
def _resolve_model_path(cli_value: Optional[Path]) -> Path:
    """Return --model-path if given, else $MODEL_PATH; print an error and exit(1) if neither is set."""
    if cli_value is not None:
        return cli_value

    env_value = os.environ.get("MODEL_PATH")
    if env_value is None:
        print("Error: --model-path not provided and MODEL_PATH environment variable not set")
        sys.exit(1)
    return Path(env_value)


def main():
    """Command-line entry point: list tensor patterns or inspect one tensor.

    The model directory comes from ``--model-path`` or, failing that, the
    ``MODEL_PATH`` environment variable. It must exist and be a directory.
    """
    parser = argparse.ArgumentParser(
        description="Print tensor information from a safetensors model"
    )
    # Positional but optional, because it is not needed together with --list.
    parser.add_argument("tensor_name", nargs="?", help="Name of the tensor to inspect")
    parser.add_argument(
        "-m",
        "--model-path",
        type=Path,
        help="Path to the model directory (default: MODEL_PATH environment variable)",
    )
    parser.add_argument(
        "-l",
        "--list",
        action="store_true",
        help="List unique tensor patterns in the model (layer numbers replaced with #)",
    )
    args = parser.parse_args()

    model_path = _resolve_model_path(args.model_path)

    if not model_path.exists():
        print(f"Error: Model path does not exist: {model_path}")
        sys.exit(1)
    if not model_path.is_dir():
        print(f"Error: Model path is not a directory: {model_path}")
        sys.exit(1)

    if not args.list and args.tensor_name is None:
        print("Error: tensor_name is required when not using --list")
        sys.exit(1)

    if args.list:
        list_all_tensors(model_path, unique=True)
    else:
        print_tensor_info(model_path, args.tensor_name)


if __name__ == "__main__":
    main()