# convert the https://huggingface.co/novateur/WavTokenizer-large-speech-75token checkpoint to HF format
# the goal is to be able to run convert_hf_to_gguf.py on the result to create a GGUF file with the WavTokenizer decoder
  3#
  4# TODO: this script is LLM-generated and probably very inefficient and should be rewritten
  5
  6import torch
  7import json
  8import os
  9import sys
 10import re
 11
 12from safetensors.torch import save_file
 13
 14# default
 15model_path = './model.pt'
 16
 17# read from CLI
 18if len(sys.argv) > 1:
 19    model_path = sys.argv[1]
 20
 21# get the directory of the input model
 22path_dst = os.path.dirname(model_path)
 23
 24print(f"Loading model from {model_path}")
 25
 26model = torch.load(model_path, map_location='cpu')
 27
 28#print(model)
 29
 30# print all keys
 31for key in model.keys():
 32    print(key)
 33    if key == 'hyper_parameters':
 34        #print(model[key])
 35        # dump as json pretty
 36        print(json.dumps(model[key], indent=4))
 37    #if key != 'state_dict' and key != 'optimizer_states':
 38    #    print(model[key])
 39
 40# Check if the loaded model is a state_dict or a model instance
 41if isinstance(model, torch.nn.Module):
 42    state_dict = model.state_dict()
 43else:
 44    state_dict = model
 45
 46# Print the structure of the state_dict to understand its format
 47print("State dictionary keys:")
 48for key in state_dict.keys():
 49    print(key)
 50
 51# Ensure the state_dict is flat and contains only torch.Tensor objects
 52def flatten_state_dict(state_dict, parent_key='', sep='.'):
 53    items = []
 54    items_new = []
 55
 56    for k, v in state_dict.items():
 57        new_key = f"{parent_key}{sep}{k}" if parent_key else k
 58        if isinstance(v, torch.Tensor):
 59            items.append((new_key, v))
 60        elif isinstance(v, dict):
 61            items.extend(flatten_state_dict(v, new_key, sep=sep).items())
 62            return dict(items)
 63
 64    size_total_mb = 0
 65
 66    for key, value in list(items):
 67        # keep only what we need for inference
 68        if not key.startswith('state_dict.feature_extractor.encodec.quantizer.') and \
 69           not key.startswith('state_dict.backbone.') and \
 70           not key.startswith('state_dict.head.out'):
 71               print('Skipping key: ', key)
 72               continue
 73
 74        new_key = key
 75
 76        new_key = new_key.replace('state_dict.', '')
 77        new_key = new_key.replace('pos_net', 'posnet')
 78
 79        # check if matches "backbone.posnet.%d.bias" or "backbone.posnet.%d.weight"
 80        if new_key.startswith("backbone.posnet."):
 81            match = re.match(r"backbone\.posnet\.(\d+)\.(bias|weight)", new_key)
 82            if match:
 83               new_key = f"backbone.posnet.{match.group(1)}.norm.{match.group(2)}"
 84
 85        # "feature_extractor.encodec.quantizer.vq.layers.0._codebook.embed" -> "backbone.embedding.weight"
 86        if new_key == "feature_extractor.encodec.quantizer.vq.layers.0._codebook.embed":
 87            new_key = "backbone.embedding.weight"
 88
 89        # these are the only rows used
 90        # ref: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/wav_tokenizer/audio_codec.py#L100
 91        if new_key.endswith("norm.scale.weight"):
 92            new_key = new_key.replace("norm.scale.weight", "norm.weight")
 93            value = value[0]
 94
 95        if new_key.endswith("norm.shift.weight"):
 96            new_key = new_key.replace("norm.shift.weight", "norm.bias")
 97            value = value[0]
 98
 99        if new_key.endswith("gamma"):
100            new_key = new_key.replace("gamma", "gamma.weight")
101
102        # convert from 1D [768] to 2D [768, 1] so that ggml_add can broadcast the bias
103        if (new_key.endswith("norm.weight") or new_key.endswith("norm1.weight") or new_key.endswith("norm2.weight") or new_key.endswith(".bias")) and (new_key.startswith("backbone.posnet") or new_key.startswith("backbone.embed.bias")):
104            value = value.unsqueeze(1)
105
106        if new_key.endswith("dwconv.bias"):
107            value = value.unsqueeze(1)
108
109        size_mb = value.element_size() * value.nelement() / (1024 * 1024)
110        print(f"{size_mb:8.2f} MB - {new_key}: {value.shape}")
111
112        size_total_mb += size_mb
113
114        #print(key, '->', new_key, ': ', value)
115        #print(key, '->', new_key)
116
117        items_new.append((new_key, value))
118
119    print(f"Total size: {size_total_mb:8.2f} MB")
120
121    return dict(items_new)
122
flattened_state_dict = flatten_state_dict(state_dict)

# Convert the model to the safetensors format.
# os.path.join (not string concatenation) so an empty path_dst means the current directory.
output_path = os.path.join(path_dst, 'model.safetensors')
save_file(flattened_state_dict, output_path)

print(f"Model has been successfully converted and saved to {output_path}")

# Calculate the total size of the .safetensors file
total_size = os.path.getsize(output_path)

# Create metadata for the index.json file.
# NOTE(review): HF sharded checkpoints normally use "model.safetensors.index.json" with a
# per-tensor weight_map; this simplified layout does not follow that -- confirm whether
# anything downstream reads this file before changing it.
metadata = {
    "total_size": total_size,
    "weight_map": {
        "model.safetensors": ["*"]  # Assuming all weights are in one file
    },
}

# Save the metadata to index.json
index_path = os.path.join(path_dst, 'index.json')
with open(index_path, 'w') as f:
    json.dump(metadata, f, indent=4)

print(f"Metadata has been saved to {index_path}")

config = {
    "architectures": [
        "WavTokenizerDec"
    ],
    "hidden_size": 1282,
    "n_embd_features": 512,
    "n_ff": 2304,
    "vocab_size": 4096,
    "n_head": 1,
    "layer_norm_epsilon": 1e-6,
    "group_norm_epsilon": 1e-6,
    "group_norm_groups": 32,
    "max_position_embeddings": 8192, # ?
    "n_layer": 12,
    "posnet": {
        "n_embd": 768,
        "n_layer": 6
    },
    "convnext": {
        "n_embd": 768,
        "n_layer": 12
    },
}

config_path = os.path.join(path_dst, 'config.json')
with open(config_path, 'w') as f:
    json.dump(config, f, indent=4)

# report the path actually written (the old message dropped the separator)
print(f"Config has been saved to {config_path}")