1import argparse
2import glob
3import os
4import torch
5from safetensors import safe_open
6from safetensors.torch import save_file
7from typing import Any, ContextManager, cast
8
9# Function to determine if file is a SafeTensor file
def is_safetensor_file(file_path):
    """Return True when *file_path* names a SafeTensors archive (by extension)."""
    suffix = '.safetensors'
    return file_path.endswith(suffix)
12
13
14# Unified loading function
def load_model(file_path):
    """Load a checkpoint and return ``(state_dict, format_tag)``.

    ``format_tag`` is ``'safetensor'`` for ``.safetensors`` archives and
    ``'pytorch'`` for anything ``torch.load`` can read. SafeTensors entries
    are cloned onto CPU and each tensor's shape is printed as it is read.
    """
    if not is_safetensor_file(file_path):
        return torch.load(file_path, map_location=torch.device('cpu')), 'pytorch'

    tensors = {}
    handle = cast(ContextManager[Any], safe_open(file_path, framework="pt", device="cpu"))
    with handle as f:
        for key in f.keys():
            tensors[key] = f.get_tensor(key).clone()
            # output shape
            print(f"{key} : {tensors[key].shape}")
    return tensors, 'safetensor'
26
27
28# Unified saving function
def save_model(model, file_path, file_type):
    """Write *model* (a name->tensor dict) to *file_path*.

    ``file_type == 'safetensor'`` selects the SafeTensors writer; any other
    value falls back to ``torch.save``.
    """
    if file_type != 'safetensor':
        torch.save(model, file_path)
    else:
        save_file(model, file_path)
35
36# Helpers to match weight names from specific components or
37# determine if a saved shard contains that component
def is_vision_tower(weight_name):
    """True for tensor names belonging to the vision encoder (CLIP/ViT)."""
    vision_prefixes = ("model.vision_tower", "vit.", "vision_tower")
    return weight_name.startswith(vision_prefixes)
44
def is_newline(weight_name):
    """True for the image-newline embedding tensor names."""
    return weight_name.startswith(("model.image_newline", "image_newline"))
50
def is_mm_projector(weight_name):
    """True for tensor names belonging to the multimodal projector."""
    projector_prefixes = (
        "model.mm_projector",
        "vision_proj.",
        "multi_modal_projector",
    )
    return weight_name.startswith(projector_prefixes)
57
def newline_criteria(checkpoint):
    """True when *checkpoint* contains an image-newline tensor."""
    for key in checkpoint:
        if is_newline(key):
            return True
    return False
60
def proj_criteria(checkpoint):
    """True when *checkpoint* contains multimodal-projector tensors."""
    for key in checkpoint:
        if is_mm_projector(key):
            return True
    return False
63
64# Adapted function to clean vision tower from checkpoint
def clean_vision_tower_from_checkpoint(checkpoint_path):
    """Move vision-tower tensors out of a checkpoint into <model_dir>/llava.clip.

    Loads the checkpoint, merges every vision-tower tensor into the (possibly
    pre-existing) ``llava.clip`` file next to it, strips those tensors from the
    checkpoint and writes the cleaned checkpoint back in its original format.

    Returns True when at least one vision-tower tensor was found and extracted,
    False otherwise.
    """
    checkpoint, file_type = load_model(checkpoint_path)
    model_path = os.path.dirname(checkpoint_path)
    print(f"Searching for vision tower tensors in {checkpoint_path}")
    clip_tensors = [k for k in checkpoint if is_vision_tower(k)]

    if len(clip_tensors) > 0:
        print(f"Found {len(clip_tensors)} tensors to extract from {checkpoint_path}")
        # llava.clip accumulates tensors across all shards of the model.
        clip_path = os.path.join(model_path, "llava.clip")

        if os.path.exists(clip_path):
            print(f"Loading existing llava.clip from {clip_path}")
            existing_clip, _ = load_model(clip_path)
        else:
            print(f"Creating new llava.clip at {clip_path}")
            existing_clip = {}

        # Update existing_clip with new tensors; keep the first occurrence of a
        # name and normalise names to start at 'vision_model.' when possible.
        for name in clip_tensors:
            simple_name = name[name.index('vision_model.'):] if 'vision_model.' in name else name
            print(f"Adding {simple_name} to llava.clip")
            if simple_name not in existing_clip:
                existing_clip[simple_name] = checkpoint[name]

        # llava.clip is always stored as a plain PyTorch file.
        save_model(existing_clip, clip_path, 'pytorch')

        # Remove the tensors from the original checkpoint and persist it.
        for name in clip_tensors:
            del checkpoint[name]

        # BUG FIX: the cleaned checkpoint was never written back, so the vision
        # tower silently remained in the model file on disk.
        save_model(checkpoint, checkpoint_path, file_type)
        return True
    return False
100
def find_relevant_checkpoints(checkpoint_paths, newline_criteria, projector):
    """Scan *checkpoint_paths* and return ``(newline_path, projector_path)``.

    The newline path is the FIRST checkpoint satisfying *newline_criteria*;
    the projector path is the LAST one satisfying *projector*. Either element
    may be None when no checkpoint matches.
    """
    newline_path = None
    projector_path = None

    for candidate in checkpoint_paths:
        state, _ = load_model(candidate)
        if newline_criteria(state) and newline_path is None:
            newline_path = candidate
        if projector(state):
            projector_path = candidate

    return newline_path, projector_path
113
114
115# Command-line interface setup
# Command-line interface setup
ap = argparse.ArgumentParser()
ap.add_argument("-m", "--model", required=True, help="Path to LLaVA v1.5+ model")
ap.add_argument("-C", "--clean-vision-tower", action="store_true", help="Remove any vision tower from the model files")
args = ap.parse_args()

if args.clean_vision_tower:
    # Generalized to handle both PyTorch and SafeTensors models.
    # Sorted newest-first by mtime so the most recently written shards are
    # visited first.
    model_files = sorted(glob.glob(f"{args.model}/*"), key=os.path.getmtime, reverse=True)
    # checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and path.startswith('pytorch')) or (path.endswith('.safetensors') and path.startswith('model'))]
    # Match on the basename; splitting on both '/' and '\\' keeps this working
    # for Windows-style paths as well.
    checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and 'pytorch' in path.split('/')[-1].split('\\')[-1]) or (path.endswith('.safetensors') and 'model' in path.split('/')[-1].split('\\')[-1])]
    for projector_checkpoint_path in checkpoint_paths:
        print(f"Cleaning {projector_checkpoint_path}")
        if not clean_vision_tower_from_checkpoint(projector_checkpoint_path):
            print(f"No vision tower found in {projector_checkpoint_path}")
        # we break once none is found, so far all models append them at the end
        # break
    print("Done! All vision tower tensors are removed from the model files and stored in llava.clip file.")

# Now we look for the projector in the last checkpoint
# (re-glob because the cleaning step above may have rewritten files and
# therefore changed mtimes)
model_files = sorted(glob.glob(f"{args.model}/*"), key=os.path.getmtime, reverse=True)
checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and 'pytorch' in path.split('/')[-1].split('\\')[-1]) or (path.endswith('.safetensors') and 'model' in path.split('/')[-1].split('\\')[-1])]
# last_checkpoint_path = checkpoint_paths[0]
# first_checkpoint_path = checkpoint_paths[-1]
newline_checkpoint_path, projector_checkpoint_path = find_relevant_checkpoints(checkpoint_paths, newline_criteria, proj_criteria)

print(f"Taking projector from {projector_checkpoint_path}")
first_mm_tensors = []
first_checkpoint = None
if newline_checkpoint_path is not None:
    print(f"Taking newline from {newline_checkpoint_path}")
    first_checkpoint, file_type = load_model(newline_checkpoint_path)
    first_mm_tensors = [k for k, v in first_checkpoint.items() if is_newline(k)]

# Load the checkpoint that holds the multimodal projector
mm_tensors = []
last_checkpoint = None
if projector_checkpoint_path is not None:
    last_checkpoint, file_type = load_model(projector_checkpoint_path)
    mm_tensors = [k for k, v in last_checkpoint.items() if is_mm_projector(k)]

if len(mm_tensors) == 0:
    # Nothing matched: dump the available keys to help diagnose, then bail out.
    if last_checkpoint is not None:
        for k, v in last_checkpoint.items():
            print(k)
    print(f"Found {len(mm_tensors)} tensors to extract out of {len(last_checkpoint) if last_checkpoint is not None else 0} tensors.")
    print("No tensors found. Is this a LLaVA model?")
    exit()

print(f"Found {len(mm_tensors)} tensors to extract.")
print(f"Found additional {len(first_mm_tensors)} tensors to extract.")
# projector = {name: checkpoint.[name].float() for name in mm_tensors}
# Collect projector (and newline) tensors, upcast to float32 for conversion.
projector = {}
for name in mm_tensors:
    assert last_checkpoint is not None
    projector[name] = last_checkpoint[name].float()
for name in first_mm_tensors:
    assert first_checkpoint is not None
    projector[name] = first_checkpoint[name].float()

if len(projector) > 0:
    # The projector file is always written in plain PyTorch format.
    save_model(projector, f"{args.model}/llava.projector", 'pytorch')

print("Done!")
print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.")
print(f"Also, use {args.model}/llava.projector to prepare a llava-encoder.gguf file.")