1#!/usr/bin/env python3
 2
import argparse
import json
import os
import sys
from collections import defaultdict

from safetensors import safe_open
 8
 9parser = argparse.ArgumentParser(description='Process model with specified path')
10parser.add_argument('--model-path', '-m', help='Path to the model')
11args = parser.parse_args()
12
13model_path = os.environ.get('MODEL_PATH', args.model_path)
14if model_path is None:
15    parser.error("Model path must be specified either via --model-path argument or MODEL_PATH environment variable")
16
def _print_tensors(handle, tensor_names):
    """Print one ``- name : shape = ..., dtype = ...`` line per tensor.

    ``handle`` is an open ``safe_open(..., framework="pt")`` file and
    ``tensor_names`` are keys inside it; they are printed in sorted order.
    """
    # NOTE(review): get_tensor reads each tensor's full data just to report its
    # metadata. handle.get_slice(name).get_shape()/.get_dtype() would avoid the
    # read, but returns a plain list and a safetensors dtype string instead of
    # torch.Size / torch.dtype, which would change the printed format.
    for tensor_name in sorted(tensor_names):
        tensor = handle.get_tensor(tensor_name)
        print(f"- {tensor_name} : shape = {tensor.shape}, dtype = {tensor.dtype}")


def _report_missing_weights(path):
    """Explain on stderr that no safetensors weights were found under ``path``."""
    print(f"Error: Neither 'model.safetensors.index.json' nor 'model.safetensors' found in {path}",
          file=sys.stderr)
    print("Available files:", file=sys.stderr)
    if os.path.isdir(path):
        for item in sorted(os.listdir(path)):
            print(f"  {item}", file=sys.stderr)
    elif os.path.exists(path):
        # e.g. the user passed the .safetensors file itself; os.listdir would
        # raise NotADirectoryError here.
        print(f"  {path} is not a directory", file=sys.stderr)
    else:
        print(f"  Directory {path} does not exist", file=sys.stderr)


# A sharded (multi-file) model is described by an index JSON; otherwise a
# single model.safetensors file is expected directly in model_path.
index_path = os.path.join(model_path, "model.safetensors.index.json")
single_file_path = os.path.join(model_path, "model.safetensors")

if os.path.exists(index_path):
    print("Multi-file model detected")

    # JSON is UTF-8; don't depend on the locale's default encoding.
    with open(index_path, 'r', encoding='utf-8') as f:
        index_data = json.load(f)

    # weight_map: tensor name -> shard file name (joined onto model_path below).
    weight_map = index_data.get("weight_map", {})

    # Group tensors by shard so each shard file is opened only once.
    file_tensors = defaultdict(list)
    for tensor_name, file_name in weight_map.items():
        file_tensors[file_name].append(tensor_name)

    print("Tensors in model:")

    # Visit shards in sorted file-name order, consistent with the sorted
    # tensor order inside each shard.
    for file_name in sorted(file_tensors):
        print(f"\n--- From {file_name} ---")
        with safe_open(os.path.join(model_path, file_name), framework="pt") as f:
            _print_tensors(f, file_tensors[file_name])

elif os.path.exists(single_file_path):
    print("Single-file model detected")

    with safe_open(single_file_path, framework="pt") as f:
        print("Tensors in model:")
        _print_tensors(f, f.keys())

else:
    _report_missing_weights(model_path)
    # sys.exit rather than the site-provided exit(), which is not available
    # under `python -S` or in frozen executables.
    sys.exit(1)