1import argparse
  2import glob
  3import os
  4import torch
  5from safetensors import safe_open
  6from safetensors.torch import save_file
  7from typing import Any, ContextManager, cast
  8
  9# Function to determine if file is a SafeTensor file
 10def is_safetensor_file(file_path):
 11    return file_path.endswith('.safetensors')
 12
 13
 14# Unified loading function
 15def load_model(file_path):
 16    if is_safetensor_file(file_path):
 17        tensors = {}
 18        with cast(ContextManager[Any], safe_open(file_path, framework="pt", device="cpu")) as f:
 19            for key in f.keys():
 20                tensors[key] = f.get_tensor(key).clone()
 21                # output shape
 22                print(f"{key} : {tensors[key].shape}")
 23        return tensors, 'safetensor'
 24    else:
 25        return torch.load(file_path, map_location=torch.device('cpu')), 'pytorch'
 26
 27
 28# Unified saving function
 29def save_model(model, file_path, file_type):
 30    if file_type == 'safetensor':
 31        # safe_save(model, file_path)
 32        save_file(model, file_path)
 33    else:
 34        torch.save(model, file_path)
 35
 36# Helpers to match weight names from specific components or
 37# determine if a saved shard contains that component
 38def is_vision_tower(weight_name):
 39    return (
 40        weight_name.startswith("model.vision_tower") or
 41        weight_name.startswith("vit.") or
 42        weight_name.startswith("vision_tower")
 43    )
 44
 45def is_newline(weight_name):
 46    return (
 47        weight_name.startswith("model.image_newline") or
 48        weight_name.startswith("image_newline")
 49    )
 50
 51def is_mm_projector(weight_name):
 52    return (
 53        weight_name.startswith("model.mm_projector") or
 54        weight_name.startswith("vision_proj.") or
 55        weight_name.startswith("multi_modal_projector")
 56    )
 57
 58def newline_criteria(checkpoint):
 59    return any(is_newline(k) for k in checkpoint.keys())
 60
 61def proj_criteria(checkpoint):
 62    return any(is_mm_projector(k) for k in checkpoint.keys())
 63
 64# Adapted function to clean vision tower from checkpoint
 65def clean_vision_tower_from_checkpoint(checkpoint_path):
 66    checkpoint, file_type = load_model(checkpoint_path)
 67    # file_type = 'pytorch'
 68    model_path = os.path.dirname(checkpoint_path)
 69    print(f"Searching for vision tower tensors in {checkpoint_path}")
 70    clip_tensors = [k for k, v in checkpoint.items() if is_vision_tower(k)]
 71
 72    if len(clip_tensors) > 0:
 73        print(f"Found {len(clip_tensors)} tensors to extract from {checkpoint_path}")
 74        # Adapted for file type
 75        clip_path = os.path.join(model_path, "llava.clip")
 76
 77        if os.path.exists(clip_path):
 78            print(f"Loading existing llava.clip from {clip_path}")
 79            existing_clip, _ = load_model(clip_path)
 80        else:
 81            print(f"Creating new llava.clip at {clip_path}")
 82            existing_clip = {}
 83        # Update existing_clip with new tensors, avoid duplicates
 84        for name in clip_tensors:
 85            simple_name = name[name.index('vision_model.'):] if 'vision_model.' in name else name
 86            print(f"Adding {simple_name} to llava.clip")
 87            if simple_name not in existing_clip:
 88                existing_clip[simple_name] = checkpoint[name]
 89
 90        # Save the updated clip tensors back to llava.clip
 91        save_model(existing_clip, clip_path, 'pytorch')
 92
 93        # Remove the tensors from the original checkpoint
 94        for name in clip_tensors:
 95            del checkpoint[name]
 96
 97        checkpoint_path = checkpoint_path
 98        return True
 99    return False
100
def find_relevant_checkpoints(checkpoint_paths, newline_criteria, projector):
    """Scan *checkpoint_paths* for the shards holding the tensors of interest.

    Returns ``(newline_path, projector_path)``: the first path whose loaded
    state dict satisfies *newline_criteria*, and the last path satisfying
    *projector*. Either element is None when no shard matched.
    """
    newline_path = None
    projector_path = None

    for candidate in checkpoint_paths:
        state_dict, _ = load_model(candidate)
        # Keep only the first newline match, but let later projector matches
        # overwrite earlier ones.
        if newline_path is None and newline_criteria(state_dict):
            newline_path = candidate
        if projector(state_dict):
            projector_path = candidate

    return newline_path, projector_path
113
114
# Command-line interface setup
ap = argparse.ArgumentParser()
ap.add_argument("-m", "--model", required=True, help="Path to LLaVA v1.5+ model")
ap.add_argument("-C", "--clean-vision-tower", action="store_true", help="Remove any vision tower from the model files")
args = ap.parse_args()

if args.clean_vision_tower:
    # Generalized to handle both PyTorch and SafeTensors models
    # Newest files first: shards are visited in reverse modification order.
    model_files = sorted(glob.glob(f"{args.model}/*"), key=os.path.getmtime, reverse=True)
    # checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and path.startswith('pytorch')) or (path.endswith('.safetensors') and path.startswith('model'))]
    # Keep only weight shards: '*.bin' whose basename contains 'pytorch', or
    # '*.safetensors' whose basename contains 'model'. The double split
    # extracts the basename whether the glob produced '/' or '\' separators.
    checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and 'pytorch' in path.split('/')[-1].split('\\')[-1]) or (path.endswith('.safetensors') and 'model' in path.split('/')[-1].split('\\')[-1])]
    for projector_checkpoint_path in checkpoint_paths:
        print(f"Cleaning {projector_checkpoint_path}")
        if not clean_vision_tower_from_checkpoint(projector_checkpoint_path):
            print(f"No vision tower found in {projector_checkpoint_path}")
            # we break once none is found, so far all models append them at the end
            # break
    print("Done! All vision tower tensors are removed from the model files and stored in llava.clip file.")

# Now we look for the projector in the last checkpoint
# Re-glob because the cleaning pass above may have rewritten shard files.
model_files = sorted(glob.glob(f"{args.model}/*"), key=os.path.getmtime, reverse=True)
checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and 'pytorch' in path.split('/')[-1].split('\\')[-1]) or (path.endswith('.safetensors') and 'model' in path.split('/')[-1].split('\\')[-1])]
# last_checkpoint_path = checkpoint_paths[0]
# first_checkpoint_path = checkpoint_paths[-1]
# Locate the shard holding the image-newline tensor and the shard holding
# the multimodal projector (they may be the same file, or absent).
newline_checkpoint_path, projector_checkpoint_path = find_relevant_checkpoints(checkpoint_paths, newline_criteria, proj_criteria)

print(f"Taking projector from {projector_checkpoint_path}")
first_mm_tensors = []
first_checkpoint = None
if newline_checkpoint_path is not None:
    print(f"Taking newline from {newline_checkpoint_path}")
    first_checkpoint, file_type = load_model(newline_checkpoint_path)
    first_mm_tensors = [k for k, v in first_checkpoint.items() if is_newline(k)]

# Load the checkpoint
mm_tensors = []
last_checkpoint = None
if projector_checkpoint_path is not None:
    last_checkpoint, file_type = load_model(projector_checkpoint_path)
    mm_tensors = [k for k, v in last_checkpoint.items() if is_mm_projector(k)]

if len(mm_tensors) == 0:
    # Nothing to extract — dump the tensor names found to help diagnose
    # whether this is actually a LLaVA-style model, then bail out.
    if last_checkpoint is not None:
        for k, v in last_checkpoint.items():
            print(k)
    print(f"Found {len(mm_tensors)} tensors to extract out of {len(last_checkpoint) if last_checkpoint is not None else 0} tensors.")
    print("No tensors found. Is this a LLaVA model?")
    exit()

print(f"Found {len(mm_tensors)} tensors to extract.")
print(f"Found additional {len(first_mm_tensors)} tensors to extract.")
# projector = {name: checkpoint.[name].float() for name in mm_tensors}
# Collect projector (and newline) tensors, upcast to float32 for export.
projector = {}
for name in mm_tensors:
    assert last_checkpoint is not None
    projector[name] = last_checkpoint[name].float()
for name in first_mm_tensors:
    assert first_checkpoint is not None
    projector[name] = first_checkpoint[name].float()

# The projector side-car is always written in PyTorch format.
if len(projector) > 0:
    save_model(projector, f"{args.model}/llava.projector", 'pytorch')

print("Done!")
print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.")
print(f"Also, use {args.model}/llava.projector to prepare a llava-encoder.gguf file.")