1#!/usr/bin/env python3
2
3import argparse
4import json
5import os
6import re
7import sys
8from pathlib import Path
9from typing import Optional
10from safetensors import safe_open
11
12
# File name of a single-file safetensors checkpoint; used when no index file
# is present in the model directory.
MODEL_SAFETENSORS_FILE = "model.safetensors"
# File name of the JSON index whose "weight_map" entry maps each tensor name to
# the file (relative to the model directory) that contains it.
MODEL_SAFETENSORS_INDEX = "model.safetensors.index.json"
15
16
def get_weight_map(model_path: Path) -> Optional[dict[str, str]]:
    """Return the tensor-name -> file-name mapping from the model's index file.

    Reads ``MODEL_SAFETENSORS_INDEX`` inside *model_path* and returns its
    ``"weight_map"`` entry.

    Args:
        model_path: Directory that may contain the safetensors index file.

    Returns:
        The ``weight_map`` dict (an empty dict when the index has no usable
        ``weight_map``), or ``None`` when no index file exists. Callers treat
        ``None`` as "fall back to the single-file checkpoint".

    Exits:
        With status 1, after printing a message to stderr, if the index file
        cannot be read or is not valid UTF-8 JSON.
    """
    index_file = model_path / MODEL_SAFETENSORS_INDEX

    if not index_file.exists():
        return None

    try:
        # JSON is UTF-8 by specification; don't depend on the locale encoding.
        with open(index_file, "r", encoding="utf-8") as f:
            index = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        print(f"Error reading {index_file}: {e}", file=sys.stderr)
        sys.exit(1)

    # An index without a usable "weight_map" is still an index: returning None
    # here would make callers wrongly fall back to the single-file checkpoint.
    weight_map = index.get("weight_map") if isinstance(index, dict) else None
    return weight_map if isinstance(weight_map, dict) else {}
26
27
def get_all_tensor_names(model_path: Path) -> list[str]:
    """Return every tensor name stored in the checkpoint at *model_path*.

    Uses the index's weight map when an index file exists; otherwise opens the
    single-file checkpoint (``MODEL_SAFETENSORS_FILE``) with ``safe_open`` and
    lists its keys.

    Exits with status 1, after printing a message to stderr, if the single
    file cannot be read or no checkpoint file is found.
    """
    weight_map = get_weight_map(model_path)
    if weight_map is not None:
        return list(weight_map)

    single_file = model_path / MODEL_SAFETENSORS_FILE
    if single_file.exists():
        try:
            with safe_open(single_file, framework="pt", device="cpu") as f:
                return list(f.keys())
        except Exception as e:  # CLI boundary: report any read failure, then exit
            # Diagnostics go to stderr so they never pollute piped data output.
            print(f"Error reading {single_file}: {e}", file=sys.stderr)
            sys.exit(1)

    print(f"Error: No safetensors files found in {model_path}", file=sys.stderr)
    sys.exit(1)
45
46
def find_tensor_file(model_path: Path, tensor_name: str) -> Optional[str]:
    """Name the checkpoint file (relative to *model_path*) that should hold *tensor_name*.

    If an index file is present, the answer comes from its weight map, and is
    ``None`` for tensors the map does not list. Without an index, the
    single-file checkpoint is assumed to hold every tensor, so its file name is
    returned if that file exists. Returns ``None`` when neither applies.
    """
    index_map = get_weight_map(model_path)
    if index_map is None:
        candidate = model_path / MODEL_SAFETENSORS_FILE
        return candidate.name if candidate.exists() else None
    return index_map.get(tensor_name)
58
59
def normalize_tensor_name(tensor_name: str) -> str:
    """Replace every dot-preceded, all-digit segment of *tensor_name* with ``#``.

    Collapses per-layer tensor names into a single pattern, e.g.
    ``"model.layers.12.mlp.weight"`` -> ``"model.layers.#.mlp.weight"``.
    Consecutive numeric segments are all replaced
    (``"features.0.0.weight"`` -> ``"features.#.#.weight"``). A leading
    segment is left unchanged because no dot precedes it.
    """
    # The lookarounds check for the surrounding dots without consuming them, so
    # adjacent numeric segments (".0.1.") are each matched. The previous
    # r"\.\d+\." consumed the trailing dot and therefore skipped every second
    # consecutive segment.
    return re.sub(r"(?<=\.)\d+(?=\.|$)", "#", tensor_name)
64
65
def list_all_tensors(model_path: Path, unique: bool = False):
    """Print the model's tensor names to stdout, one per line, in sorted order.

    When *unique* is true, each name is passed through
    ``normalize_tensor_name`` and only the first occurrence of each resulting
    pattern is printed. Output order follows the sorted raw names.
    """
    ordered_names = sorted(get_all_tensor_names(model_path))

    if not unique:
        for name in ordered_names:
            print(name)
        return

    # dict keeps insertion order, so this de-duplicates the patterns while
    # preserving where each one first appeared.
    patterns = dict.fromkeys(normalize_tensor_name(name) for name in ordered_names)
    for pattern in patterns:
        print(pattern)
79
80
def print_tensor_info(model_path: Path, tensor_name: str):
    """Print the name, containing file, and shape of one tensor to stdout.

    The containing file is located with ``find_tensor_file`` and opened with
    ``safe_open``. The shape comes from ``get_slice(tensor_name).get_shape()``.

    Exits with status 1, after printing a message to stderr, if the tensor
    cannot be located, is missing from its file, or the file cannot be read.
    """
    tensor_file = find_tensor_file(model_path, tensor_name)

    if tensor_file is None:
        print(f"Error: Could not find tensor '{tensor_name}' in model index", file=sys.stderr)
        print(f"Model path: {model_path}", file=sys.stderr)
        sys.exit(1)

    file_path = model_path / tensor_file

    # Only the safetensors access sits inside the try. The sys.exit below
    # raises SystemExit, which `except Exception` does not intercept.
    try:
        with safe_open(file_path, framework="pt", device="cpu") as f:
            if tensor_name not in f.keys():
                print(f"Error: Tensor '{tensor_name}' not found in {tensor_file}", file=sys.stderr)
                sys.exit(1)
            shape = f.get_slice(tensor_name).get_shape()
    except FileNotFoundError:
        print(f"Error: The file '{file_path}' was not found.", file=sys.stderr)
        sys.exit(1)
    except Exception as e:  # CLI boundary: report any read failure, then exit
        print(f"An error occurred: {e}", file=sys.stderr)
        sys.exit(1)

    print(f"Tensor: {tensor_name}")
    print(f"File: {tensor_file}")
    print(f"Shape: {shape}")
109
110
def main():
    """Command-line entry point: parse arguments, resolve the model directory, dispatch.

    The model directory comes from ``--model-path`` or, failing that, the
    ``MODEL_PATH`` environment variable. With ``--list``, the unique tensor
    name patterns are printed. Otherwise ``tensor_name`` is required and its
    info is printed. Usage errors are reported on stderr with exit status 1.
    """
    parser = argparse.ArgumentParser(
        description="Print tensor information from a safetensors model"
    )
    parser.add_argument(
        "tensor_name",
        nargs="?",  # optional (if --list is used for example)
        help="Name of the tensor to inspect"
    )
    parser.add_argument(
        "-m", "--model-path",
        type=Path,
        help="Path to the model directory (default: MODEL_PATH environment variable)"
    )
    parser.add_argument(
        "-l", "--list",
        action="store_true",
        help="List unique tensor patterns in the model (layer numbers replaced with #)"
    )

    args = parser.parse_args()

    model_path = args.model_path
    if model_path is None:
        model_path_str = os.environ.get("MODEL_PATH")
        # An empty MODEL_PATH would become Path("") == Path("."), silently
        # inspecting the current directory; treat it the same as unset.
        if not model_path_str:
            print("Error: --model-path not provided and MODEL_PATH environment variable not set", file=sys.stderr)
            sys.exit(1)
        model_path = Path(model_path_str)

    if not model_path.exists():
        print(f"Error: Model path does not exist: {model_path}", file=sys.stderr)
        sys.exit(1)

    if not model_path.is_dir():
        print(f"Error: Model path is not a directory: {model_path}", file=sys.stderr)
        sys.exit(1)

    if args.list:
        list_all_tensors(model_path, unique=True)
    else:
        if args.tensor_name is None:
            print("Error: tensor_name is required when not using --list", file=sys.stderr)
            sys.exit(1)
        print_tensor_info(model_path, args.tensor_name)
156
157
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()