summaryrefslogtreecommitdiff
path: root/llama.cpp/tools/tts/convert_pt_to_hf.py
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/tools/tts/convert_pt_to_hf.py')
-rw-r--r--llama.cpp/tools/tts/convert_pt_to_hf.py180
1 files changed, 180 insertions, 0 deletions
diff --git a/llama.cpp/tools/tts/convert_pt_to_hf.py b/llama.cpp/tools/tts/convert_pt_to_hf.py
new file mode 100644
index 0000000..ebd55d9
--- /dev/null
+++ b/llama.cpp/tools/tts/convert_pt_to_hf.py
@@ -0,0 +1,180 @@
+# convert the https://huggingface.co/novateur/WavTokenizer-large-speech-75token to HF format
+# the goal is to be able to reuse the convert_hf_to_gguf.py after that to create a GGUF file with the WavTokenizer decoder
+#
+# TODO: this script is LLM-generated and probably very inefficient and should be rewritten
+
+import torch
+import json
+import os
+import sys
+import re
+
+from safetensors.torch import save_file
+
+# default
+model_path = './model.pt'
+
+# read from CLI
+if len(sys.argv) > 1:
+ model_path = sys.argv[1]
+
+# get the directory of the input model
+path_dst = os.path.dirname(model_path)
+
+print(f"Loading model from {model_path}")
+
+model = torch.load(model_path, map_location='cpu')
+
+#print(model)
+
+# print all keys
+for key in model.keys():
+ print(key)
+ if key == 'hyper_parameters':
+ #print(model[key])
+ # dump as json pretty
+ print(json.dumps(model[key], indent=4))
+ #if key != 'state_dict' and key != 'optimizer_states':
+ # print(model[key])
+
+# Check if the loaded model is a state_dict or a model instance
+if isinstance(model, torch.nn.Module):
+ state_dict = model.state_dict()
+else:
+ state_dict = model
+
+# Print the structure of the state_dict to understand its format
+print("State dictionary keys:")
+for key in state_dict.keys():
+ print(key)
+
+# Ensure the state_dict is flat and contains only torch.Tensor objects
+def flatten_state_dict(state_dict, parent_key='', sep='.'):
+ items = []
+ items_new = []
+
+ for k, v in state_dict.items():
+ new_key = f"{parent_key}{sep}{k}" if parent_key else k
+ if isinstance(v, torch.Tensor):
+ items.append((new_key, v))
+ elif isinstance(v, dict):
+ items.extend(flatten_state_dict(v, new_key, sep=sep).items())
+ return dict(items)
+
+ size_total_mb = 0
+
+ for key, value in list(items):
+ # keep only what we need for inference
+ if not key.startswith('state_dict.feature_extractor.encodec.quantizer.') and \
+ not key.startswith('state_dict.backbone.') and \
+ not key.startswith('state_dict.head.out'):
+ print('Skipping key: ', key)
+ continue
+
+ new_key = key
+
+ new_key = new_key.replace('state_dict.', '')
+ new_key = new_key.replace('pos_net', 'posnet')
+
+ # check if matches "backbone.posnet.%d.bias" or "backbone.posnet.%d.weight"
+ if new_key.startswith("backbone.posnet."):
+ match = re.match(r"backbone\.posnet\.(\d+)\.(bias|weight)", new_key)
+ if match:
+ new_key = f"backbone.posnet.{match.group(1)}.norm.{match.group(2)}"
+
+ # "feature_extractor.encodec.quantizer.vq.layers.0._codebook.embed" -> "backbone.embedding.weight"
+ if new_key == "feature_extractor.encodec.quantizer.vq.layers.0._codebook.embed":
+ new_key = "backbone.embedding.weight"
+
+ # these are the only rows used
+ # ref: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/wav_tokenizer/audio_codec.py#L100
+ if new_key.endswith("norm.scale.weight"):
+ new_key = new_key.replace("norm.scale.weight", "norm.weight")
+ value = value[0]
+
+ if new_key.endswith("norm.shift.weight"):
+ new_key = new_key.replace("norm.shift.weight", "norm.bias")
+ value = value[0]
+
+ if new_key.endswith("gamma"):
+ new_key = new_key.replace("gamma", "gamma.weight")
+
+ # convert from 1D [768] to 2D [768, 1] so that ggml_add can broadcast the bias
+ if (new_key.endswith("norm.weight") or new_key.endswith("norm1.weight") or new_key.endswith("norm2.weight") or new_key.endswith(".bias")) and (new_key.startswith("backbone.posnet") or new_key.startswith("backbone.embed.bias")):
+ value = value.unsqueeze(1)
+
+ if new_key.endswith("dwconv.bias"):
+ value = value.unsqueeze(1)
+
+ size_mb = value.element_size() * value.nelement() / (1024 * 1024)
+ print(f"{size_mb:8.2f} MB - {new_key}: {value.shape}")
+
+ size_total_mb += size_mb
+
+ #print(key, '->', new_key, ': ', value)
+ #print(key, '->', new_key)
+
+ items_new.append((new_key, value))
+
+ print(f"Total size: {size_total_mb:8.2f} MB")
+
+ return dict(items_new)
+
+flattened_state_dict = flatten_state_dict(state_dict)
+
+
+# Convert the model to the safetensors format
+output_path = path_dst + '/model.safetensors'
+save_file(flattened_state_dict, output_path)
+
+print(f"Model has been successfully converted and saved to {output_path}")
+
+# Calculate the total size of the .safetensors file
+total_size = os.path.getsize(output_path)
+
+# Create the weight map
+weight_map = {
+ "model.safetensors": ["*"] # Assuming all weights are in one file
+}
+
+# Create metadata for the index.json file
+metadata = {
+ "total_size": total_size,
+ "weight_map": weight_map
+}
+
+# Save the metadata to index.json
+index_path = path_dst + '/index.json'
+with open(index_path, 'w') as f:
+ json.dump(metadata, f, indent=4)
+
+print(f"Metadata has been saved to {index_path}")
+
+config = {
+ "architectures": [
+ "WavTokenizerDec"
+ ],
+ "hidden_size": 1282,
+ "n_embd_features": 512,
+ "n_ff": 2304,
+ "vocab_size": 4096,
+ "n_head": 1,
+ "layer_norm_epsilon": 1e-6,
+ "group_norm_epsilon": 1e-6,
+ "group_norm_groups": 32,
+ "max_position_embeddings": 8192, # ?
+ "n_layer": 12,
+ "posnet": {
+ "n_embd": 768,
+ "n_layer": 6
+ },
+ "convnext": {
+ "n_embd": 768,
+ "n_layer": 12
+ },
+}
+
+with open(path_dst + '/config.json', 'w') as f:
+ json.dump(config, f, indent=4)
+
+print(f"Config has been saved to {path_dst + 'config.json'}")