1# convert the https://huggingface.co/novateur/WavTokenizer-large-speech-75token to HF format
2# the goal is to be able to reuse the convert_hf_to_gguf.py after that to create a GGUF file with the WavTokenizer decoder
3#
4# TODO: this script is LLM-generated and probably very inefficient and should be rewritten
5
6import torch
7import json
8import os
9import sys
10import re
11
12from safetensors.torch import save_file
13
# default input checkpoint, used when no path is given on the command line
model_path = './model.pt'

# read from CLI
if len(sys.argv) > 1:
    model_path = sys.argv[1]

# Outputs are written next to the input model. os.path.dirname() returns ''
# for a bare file name such as "model.pt"; fall back to '.' so that paths built
# from path_dst point at the current directory and not at the filesystem root.
path_dst = os.path.dirname(model_path) or '.'

print(f"Loading model from {model_path}")

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code - only run this script on checkpoints from a trusted source.
model = torch.load(model_path, map_location='cpu')

# Check if the loaded object is a model instance or a (checkpoint) dictionary.
# This must happen before any dict-style access: an nn.Module has no .keys().
if isinstance(model, torch.nn.Module):
    state_dict = model.state_dict()
else:
    # print all top-level keys of the loaded dictionary
    for key in model:
        print(key)
        if key == 'hyper_parameters':
            # dump the hyper-parameters as pretty-printed JSON
            print(json.dumps(model[key], indent=4))
    state_dict = model

# Print the structure of the state_dict to understand its format
print("State dictionary keys:")
for key in state_dict:
    print(key)
50
# Flattened-key prefixes of the tensors that are kept for inference
_KEEP_PREFIXES = (
    'state_dict.feature_extractor.encodec.quantizer.',
    'state_dict.backbone.',
    'state_dict.head.out',
)

# matches "backbone.posnet.%d.bias" or "backbone.posnet.%d.weight"
_POSNET_NORM_RE = re.compile(r"backbone\.posnet\.(\d+)\.(bias|weight)")


def _iter_tensors(tree, parent_key='', sep='.'):
    """Yield (flattened_key, tensor) for every torch.Tensor in a nested dict."""
    for k, v in tree.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else str(k)
        if isinstance(v, torch.Tensor):
            yield new_key, v
        elif isinstance(v, dict):
            yield from _iter_tensors(v, new_key, sep)
        # values that are neither tensors nor dicts are ignored


def _convert_tensor(key, value):
    """Rename one kept checkpoint tensor and adjust its shape; return (new_key, value)."""
    new_key = key.replace('state_dict.', '')
    new_key = new_key.replace('pos_net', 'posnet')

    # "backbone.posnet.%d.{bias,weight}" -> "backbone.posnet.%d.norm.{bias,weight}"
    if new_key.startswith("backbone.posnet."):
        match = _POSNET_NORM_RE.match(new_key)
        if match:
            new_key = f"backbone.posnet.{match.group(1)}.norm.{match.group(2)}"

    # "feature_extractor.encodec.quantizer.vq.layers.0._codebook.embed" -> "backbone.embedding.weight"
    if new_key == "feature_extractor.encodec.quantizer.vq.layers.0._codebook.embed":
        new_key = "backbone.embedding.weight"

    # these are the only rows used
    # ref: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/wav_tokenizer/audio_codec.py#L100
    if new_key.endswith("norm.scale.weight"):
        new_key = new_key.replace("norm.scale.weight", "norm.weight")
        value = value[0]

    if new_key.endswith("norm.shift.weight"):
        new_key = new_key.replace("norm.shift.weight", "norm.bias")
        value = value[0]

    if new_key.endswith("gamma"):
        new_key = new_key.replace("gamma", "gamma.weight")

    # convert from 1D [768] to 2D [768, 1] so that ggml_add can broadcast the bias
    if new_key.endswith(("norm.weight", "norm1.weight", "norm2.weight", ".bias")) and \
            new_key.startswith(("backbone.posnet", "backbone.embed.bias")):
        value = value.unsqueeze(1)

    if new_key.endswith("dwconv.bias"):
        value = value.unsqueeze(1)

    return new_key, value


# Ensure the state_dict is flat and contains only torch.Tensor objects
def flatten_state_dict(state_dict, parent_key='', sep='.'):
    """Flatten a nested checkpoint dict and keep only the renamed inference tensors.

    The whole dictionary is traversed first, so the result does not depend on
    the order of the keys or on where the nested dicts appear. Filtering and
    renaming are then applied exactly once to the flattened entries.

    Args:
        state_dict: (possibly nested) dict; only torch.Tensor leaves are used.
        parent_key: prefix prepended to every flattened key.
        sep: separator placed between nesting levels. The inference filter
            matches dotted prefixes such as 'state_dict.backbone.'.

    Returns:
        dict mapping the converted tensor names to the (possibly sliced or
        reshaped) tensors. Entries whose flattened key does not start with one
        of the inference prefixes are skipped.
    """
    items_new = []
    size_total_mb = 0

    for key, value in _iter_tensors(state_dict, parent_key, sep):
        # keep only what we need for inference
        if not key.startswith(_KEEP_PREFIXES):
            print('Skipping key: ', key)
            continue

        new_key, value = _convert_tensor(key, value)

        size_mb = value.element_size() * value.nelement() / (1024 * 1024)
        print(f"{size_mb:8.2f} MB - {new_key}: {value.shape}")

        size_total_mb += size_mb

        items_new.append((new_key, value))

    print(f"Total size: {size_total_mb:8.2f} MB")

    return dict(items_new)
122
flattened_state_dict = flatten_state_dict(state_dict)


# Convert the model to the safetensors format.
# os.path.join is used so that an empty path_dst resolves to the current
# directory instead of producing an absolute path such as "/model.safetensors".
output_path = os.path.join(path_dst, 'model.safetensors')
save_file(flattened_state_dict, output_path)

print(f"Model has been successfully converted and saved to {output_path}")

# Calculate the total size of the .safetensors file
total_size = os.path.getsize(output_path)

# Create the weight map
weight_map = {
    "model.safetensors": ["*"]  # Assuming all weights are in one file
}

# Create metadata for the index.json file
metadata = {
    "total_size": total_size,
    "weight_map": weight_map
}

# Save the metadata to index.json
index_path = os.path.join(path_dst, 'index.json')
with open(index_path, 'w') as f:
    json.dump(metadata, f, indent=4)

print(f"Metadata has been saved to {index_path}")

# Hard-coded model configuration written next to the weights as config.json
config = {
    "architectures": [
        "WavTokenizerDec"
    ],
    "hidden_size": 1282,
    "n_embd_features": 512,
    "n_ff": 2304,
    "vocab_size": 4096,
    "n_head": 1,
    "layer_norm_epsilon": 1e-6,
    "group_norm_epsilon": 1e-6,
    "group_norm_groups": 32,
    "max_position_embeddings": 8192,  # ?
    "n_layer": 12,
    "posnet": {
        "n_embd": 768,
        "n_layer": 6
    },
    "convnext": {
        "n_embd": 768,
        "n_layer": 12
    },
}

config_path = os.path.join(path_dst, 'config.json')
with open(config_path, 'w') as f:
    json.dump(config, f, indent=4)

# report the path that was actually written (the separator used to be missing)
print(f"Config has been saved to {config_path}")