blob: bb666159c98d4150210c7fde64aa854058642d77 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
|
import { modelsStore } from '$lib/stores/models.svelte';
import { isRouterMode } from '$lib/stores/server.svelte';
import { toast } from 'svelte-sonner';
/**
 * Configuration accepted by `useModelChangeValidation`.
 */
interface UseModelChangeValidationOptions {
	/**
	 * Supplies the modalities a candidate model has to support.
	 * - ChatForm: `() => usedModalities()` (all messages).
	 * - ChatMessageAssistant: `() => getModalitiesUpToMessage(messageId)` (messages before).
	 */
	getRequiredModalities: () => ModelModalities;

	/**
	 * Invoked with the model name once validation has succeeded (optional).
	 * - ChatForm: left undefined, the model is simply selected.
	 * - ChatMessageAssistant: `(modelName) => onRegenerate(modelName)`.
	 */
	onSuccess?: (modelName: string) => void;

	/**
	 * Invoked with the previously selected model id to roll back when
	 * validation fails (optional).
	 * - ChatForm: `(previousId) => selectModelById(previousId)`.
	 * - ChatMessageAssistant: left undefined, no rollback is needed.
	 */
	onValidationFailure?: (previousModelId: string | null) => Promise<void>;
}
/**
 * Builds a model-change handler that checks a model's modalities before
 * selecting it.
 *
 * @param options - Validation hooks: the required-modalities getter plus the
 *   optional success and rollback callbacks.
 * @returns An object exposing `handleModelChange`.
 */
export function useModelChangeValidation(options: UseModelChangeValidationOptions) {
	const { getRequiredModalities, onSuccess, onValidationFailure } = options;

	// Selection captured at the start of a change; only tracked when a rollback
	// callback was supplied.
	let previousSelectedModelId: string | null = null;
	const isRouter = $derived(isRouterMode());

	/**
	 * Runs the rollback callback, if one was provided and a previous selection
	 * was recorded. A failing rollback is logged instead of propagated so that
	 * it cannot trigger a second error toast / second rollback or make
	 * `handleModelChange` reject.
	 */
	async function rollbackSelection(): Promise<void> {
		if (!onValidationFailure || !previousSelectedModelId) {
			return;
		}

		try {
			await onValidationFailure(previousSelectedModelId);
		} catch (error) {
			console.error('Failed to roll back model selection:', error);
		}
	}

	/**
	 * Validates the given model against the required modalities and selects it
	 * when compatible.
	 *
	 * Steps:
	 * 1. Remember the current selection (only when a rollback callback exists).
	 * 2. In router mode, load the model if it is not loaded yet; a load failure
	 *    shows an error toast and resolves `false` without running the rollback.
	 * 3. Fetch the model props; when they report modalities, compare them with
	 *    the required `vision` / `audio` modalities. On a mismatch an error
	 *    toast is shown, a model loaded in step 2 is unloaded again, the
	 *    rollback runs, and the result is `false`.
	 * 4. Otherwise select the model by id, call `onSuccess` and resolve `true`.
	 *
	 * Any other error thrown along the way is logged, reported with a toast,
	 * followed by the rollback, and results in `false`.
	 *
	 * @param modelId - Id passed to `modelsStore.selectModelById`.
	 * @param modelName - Name used for loading, props lookup, messages and `onSuccess`.
	 * @returns `true` when the model was selected, `false` otherwise.
	 */
	async function handleModelChange(modelId: string, modelName: string): Promise<boolean> {
		try {
			// Store previous selection for potential rollback
			if (onValidationFailure) {
				previousSelectedModelId = modelsStore.selectedModelId;
			}

			// Load model if not already loaded (router mode only)
			let hasLoadedModel = false;
			const isModelLoadedBefore = modelsStore.isModelLoaded(modelName);

			if (isRouter && !isModelLoadedBefore) {
				try {
					await modelsStore.loadModel(modelName);
					hasLoadedModel = true;
				} catch {
					toast.error(`Failed to load model "${modelName}"`);
					return false;
				}
			}

			// Fetch model props to validate modalities
			const props = await modelsStore.fetchModelProps(modelName);

			// Without reported modalities there is nothing to compare, so the
			// model is accepted as-is.
			if (props?.modalities) {
				const requiredModalities = getRequiredModalities();

				// Collect every required modality the model does not provide
				const missingModalities: string[] = [];
				if (requiredModalities.vision && !props.modalities.vision) {
					missingModalities.push('vision');
				}
				if (requiredModalities.audio && !props.modalities.audio) {
					missingModalities.push('audio');
				}

				if (missingModalities.length > 0) {
					toast.error(
						`Model "${modelName}" doesn't support required modalities: ${missingModalities.join(', ')}. Please select a different model.`
					);

					// Unload the model if we just loaded it; best effort only
					if (isRouter && hasLoadedModel) {
						try {
							await modelsStore.unloadModel(modelName);
						} catch (error) {
							console.error('Failed to unload incompatible model:', error);
						}
					}

					// Guarded: a failing rollback must not fall through to the
					// generic error handler below.
					await rollbackSelection();

					return false;
				}
			}

			// Select the model (validation passed)
			await modelsStore.selectModelById(modelId);

			// Execute success callback if provided
			if (onSuccess) {
				onSuccess(modelName);
			}

			return true;
		} catch (error) {
			console.error('Failed to change model:', error);
			toast.error('Failed to validate model capabilities');

			// Guarded: keeps the promise resolving to `false` even if the
			// rollback itself fails.
			await rollbackSelection();

			return false;
		}
	}

	return {
		handleModelChange
	};
}
|