Switch Model Command (#126)
This commit is contained in:
@@ -25,7 +25,7 @@ export const ollama = new Ollama({
|
||||
const messageHistory: Queue<UserMessage> = new Queue<UserMessage>
|
||||
|
||||
// register all events
|
||||
registerEvents(client, Events, messageHistory, Keys, ollama)
|
||||
registerEvents(client, Events, messageHistory, ollama)
|
||||
|
||||
// Try to log in the client
|
||||
await client.login(Keys.clientToken)
|
||||
|
||||
@@ -8,6 +8,7 @@ import { Capacity } from './capacity.js'
|
||||
import { PrivateThreadCreate } from './threadPrivateCreate.js'
|
||||
import { ClearUserChannelHistory } from './cleanUserChannelHistory.js'
|
||||
import { PullModel } from './pullModel.js'
|
||||
import { SwitchModel } from './switchModel.js'
|
||||
|
||||
export default [
|
||||
ThreadCreate,
|
||||
@@ -18,5 +19,6 @@ export default [
|
||||
Shutoff,
|
||||
Capacity,
|
||||
ClearUserChannelHistory,
|
||||
PullModel
|
||||
PullModel,
|
||||
SwitchModel
|
||||
] as SlashCommand[]
|
||||
@@ -20,6 +20,7 @@ export const PullModel: SlashCommand = {
|
||||
run: async (client: Client, interaction: CommandInteraction) => {
|
||||
// defer reply to avoid timeout
|
||||
await interaction.deferReply()
|
||||
const modelInput: string = interaction.options.get('model-to-pull')!!.value as string
|
||||
|
||||
// fetch channel and message
|
||||
const channel = await client.channels.fetch(interaction.channelId)
|
||||
@@ -28,19 +29,19 @@ export const PullModel: SlashCommand = {
|
||||
try {
|
||||
// call ollama to pull desired model
|
||||
await ollama.pull({
|
||||
model: interaction.options.get('model-to-pull')!!.value as string
|
||||
model: modelInput
|
||||
})
|
||||
} catch (error) {
|
||||
// could not resolve pull or model unfound
|
||||
interaction.editReply({
|
||||
content: `Could not pull/locate the **${interaction.options.get('model-to-pull')!!.value}** model within the [Ollama Model Library](https://ollama.com/library).\n\nPlease check the model library and try again.`
|
||||
content: `Could not pull/locate the **${modelInput}** model within the [Ollama Model Library](https://ollama.com/library).\n\nPlease check the model library and try again.`
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// successful pull
|
||||
interaction.editReply({
|
||||
content: `Successfully added **${interaction.options.get('model-to-pull')!!.value}** into your local model library.`
|
||||
content: `Successfully added **${modelInput}** into your local model library.`
|
||||
})
|
||||
}
|
||||
}
|
||||
75
src/commands/switchModel.ts
Normal file
75
src/commands/switchModel.ts
Normal file
@@ -0,0 +1,75 @@
|
||||
import { ApplicationCommandOptionType, ChannelType, Client, CommandInteraction } from "discord.js";
|
||||
import { SlashCommand } from "../utils/commands.js";
|
||||
import { ollama } from "../client.js";
|
||||
import { ModelResponse } from "ollama";
|
||||
import { openConfig } from "../utils/index.js";
|
||||
|
||||
export const SwitchModel: SlashCommand = {
|
||||
name: 'switch-model',
|
||||
description: 'switches current model to preferred model to use.',
|
||||
|
||||
// set available user options to pass to the command
|
||||
options: [
|
||||
{
|
||||
name: 'model-to-use',
|
||||
description: 'the name of the model to use',
|
||||
type: ApplicationCommandOptionType.String,
|
||||
required: true
|
||||
}
|
||||
],
|
||||
|
||||
// Switch user preferred model if available in local library
|
||||
run: async (client: Client, interaction: CommandInteraction) => {
|
||||
await interaction.deferReply()
|
||||
|
||||
const modelInput: string = interaction.options.get('model-to-use')!!.value as string
|
||||
|
||||
// fetch channel and message
|
||||
const channel = await client.channels.fetch(interaction.channelId)
|
||||
if (!channel || channel.type !== (ChannelType.PrivateThread && ChannelType.PublicThread && ChannelType.GuildText)) return
|
||||
|
||||
try {
|
||||
// Phase 1: Set the model
|
||||
let switchSuccess = false
|
||||
await ollama.list()
|
||||
.then(response => {
|
||||
for (const model in response.models) {
|
||||
const currentModel: ModelResponse = response.models[model]
|
||||
if (currentModel.name.startsWith(modelInput)) {
|
||||
openConfig(`${interaction.user.username}-config.json`, interaction.commandName, modelInput)
|
||||
|
||||
// successful switch
|
||||
interaction.editReply({
|
||||
content: `Successfully switched to **${modelInput}** as the preferred model for ${interaction.user.username}.`
|
||||
})
|
||||
switchSuccess = true
|
||||
}
|
||||
}
|
||||
})
|
||||
if (switchSuccess) return
|
||||
|
||||
// Phase 2: Try to get it regardless
|
||||
interaction.editReply({
|
||||
content: `Could not find **${modelInput}** in local model library, trying to pull it now...\n\nThis could take a few moments... Please be patient!`
|
||||
})
|
||||
|
||||
await ollama.pull({
|
||||
model: modelInput
|
||||
})
|
||||
|
||||
// set model now that it exists
|
||||
openConfig(`${interaction.user.username}-config.json`, interaction.commandName, modelInput)
|
||||
|
||||
// We got the model!
|
||||
interaction.editReply({
|
||||
content: `Successfully added and set **${modelInput}** as your preferred model.`
|
||||
})
|
||||
} catch (error) {
|
||||
// could not resolve user model switch
|
||||
interaction.editReply({
|
||||
content: `Unable to switch user preferred model to **${modelInput}**.\n\n${error}\n\nPossible solution is to run \`/pull-model ${modelInput}\` and try again.`
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -8,7 +8,7 @@ import { getChannelInfo, getServerConfig, getUserConfig, openChannelInfo, openCo
|
||||
*
|
||||
* @param message the message received from the channel
|
||||
*/
|
||||
export default event(Events.MessageCreate, async ({ log, msgHist, tokens, ollama, client }, message) => {
|
||||
export default event(Events.MessageCreate, async ({ log, msgHist, ollama, client }, message) => {
|
||||
const clientId = client.user!!.id
|
||||
const cleanedMessage = clean(message.content, clientId)
|
||||
log(`Message \"${cleanedMessage}\" from ${message.author.tag} in channel/thread ${message.channelId}.`)
|
||||
@@ -49,7 +49,7 @@ export default event(Events.MessageCreate, async ({ log, msgHist, tokens, ollama
|
||||
getUserConfig(`${message.author.username}-config.json`, (config) => {
|
||||
if (config === undefined) {
|
||||
openConfig(`${message.author.username}-config.json`, 'message-style', false)
|
||||
reject(new Error('No User Preferences is set up.\n\nCreating preferences file with \`message-style\` set as \`false\` for regular messages.\nPlease try chatting again.'))
|
||||
reject(new Error('No User Preferences is set up.\n\nCreating preferences file with \`message-style\` set as \`false\` for regular message style.\nPlease try chatting again.'))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -65,6 +65,9 @@ export default event(Events.MessageCreate, async ({ log, msgHist, tokens, ollama
|
||||
|
||||
// set stream state
|
||||
shouldStream = config.options['message-stream'] as boolean || false
|
||||
|
||||
if (typeof config.options['switch-model'] !== 'string')
|
||||
reject(new Error(`No Model was set. Please set a model by running \`/switch-model <model of choice>\`.\n\nIf you do not have any models. Run \`/pull-model <model name>\`.`))
|
||||
|
||||
resolve(config)
|
||||
})
|
||||
@@ -105,6 +108,7 @@ export default event(Events.MessageCreate, async ({ log, msgHist, tokens, ollama
|
||||
|
||||
// get message attachment if exists
|
||||
const messageAttachment: string[] = await getAttachmentData(message.attachments.first())
|
||||
const model: string = userConfig.options['switch-model']
|
||||
|
||||
// set up new queue
|
||||
msgHist.setQueue(chatMessages)
|
||||
@@ -121,9 +125,9 @@ export default event(Events.MessageCreate, async ({ log, msgHist, tokens, ollama
|
||||
|
||||
// undefined or false, use normal, otherwise use embed
|
||||
if (userConfig.options['message-style'])
|
||||
response = await embedMessage(message, ollama, tokens, msgHist, shouldStream)
|
||||
response = await embedMessage(message, ollama, model, msgHist, shouldStream)
|
||||
else
|
||||
response = await normalMessage(message, ollama, tokens, msgHist, shouldStream)
|
||||
response = await normalMessage(message, ollama, model, msgHist, shouldStream)
|
||||
|
||||
// If something bad happened, remove user query and stop
|
||||
if (response == undefined) { msgHist.pop(); return }
|
||||
|
||||
@@ -2,9 +2,8 @@ import { getEnvVar } from './utils/index.js'
|
||||
|
||||
export const Keys = {
|
||||
clientToken: getEnvVar('CLIENT_TOKEN'),
|
||||
model: getEnvVar('MODEL'),
|
||||
ipAddress: getEnvVar('OLLAMA_IP'),
|
||||
portAddress: getEnvVar('OLLAMA_PORT'),
|
||||
ipAddress: getEnvVar('OLLAMA_IP', '127.0.0.1'), // default ollama ip if none
|
||||
portAddress: getEnvVar('OLLAMA_PORT', '11434'), // default ollama port if none
|
||||
} as const // readonly keys
|
||||
|
||||
export default Keys
|
||||
@@ -3,7 +3,8 @@ import { UserMessage } from './index.js'
|
||||
export interface UserConfiguration {
|
||||
'message-stream'?: boolean,
|
||||
'message-style'?: boolean,
|
||||
'modify-capacity': number
|
||||
'modify-capacity': number,
|
||||
'switch-model': string
|
||||
}
|
||||
|
||||
export interface ServerConfiguration {
|
||||
|
||||
@@ -8,14 +8,6 @@ export { Events } from 'discord.js'
|
||||
export type LogMethod = (...args: unknown[]) => void
|
||||
export type EventKeys = keyof ClientEvents // only wants keys of ClientEvents object
|
||||
|
||||
/**
|
||||
* Tokens to run the bot as intended
|
||||
* @param model chosen model for the ollama to utilize
|
||||
*/
|
||||
export type Tokens = {
|
||||
model: string,
|
||||
}
|
||||
|
||||
/**
|
||||
* Parameters to run the chat query
|
||||
* @param model the model to run
|
||||
@@ -44,7 +36,6 @@ export interface EventProps {
|
||||
client: Client
|
||||
log: LogMethod
|
||||
msgHist: Queue<UserMessage>
|
||||
tokens: Tokens,
|
||||
ollama: Ollama
|
||||
}
|
||||
export type EventCallback<T extends EventKeys> = (
|
||||
@@ -67,14 +58,12 @@ export function event<T extends EventKeys>(key: T, callback: EventCallback<T>):
|
||||
* @param client initialized bot client
|
||||
* @param events all the exported events from the index.ts in the events dir
|
||||
* @param msgHist The message history of the bot
|
||||
* @param tokens the passed in environment tokens for the service
|
||||
* @param ollama the initialized ollama instance
|
||||
*/
|
||||
export function registerEvents(
|
||||
client: Client,
|
||||
events: Event[],
|
||||
msgHist: Queue<UserMessage>,
|
||||
tokens: Tokens,
|
||||
ollama: Ollama
|
||||
): void {
|
||||
for (const { key, callback } of events) {
|
||||
@@ -84,7 +73,7 @@ export function registerEvents(
|
||||
|
||||
// Handle Errors, call callback, log errors as needed
|
||||
try {
|
||||
callback({ client, log, msgHist, tokens, ollama }, ...args)
|
||||
callback({ client, log, msgHist, ollama }, ...args)
|
||||
} catch (error) {
|
||||
log('[Uncaught Error]', error)
|
||||
}
|
||||
|
||||
@@ -7,15 +7,13 @@ import { AbortableAsyncIterator } from 'ollama/src/utils.js'
|
||||
/**
|
||||
* Method to send replies as normal text on discord like any other user
|
||||
* @param message message sent by the user
|
||||
* @param tokens tokens to run query
|
||||
* @param model name of model to run query
|
||||
* @param msgHist message history between user and model
|
||||
*/
|
||||
export async function embedMessage(
|
||||
message: Message,
|
||||
ollama: Ollama,
|
||||
tokens: {
|
||||
model: string
|
||||
},
|
||||
model: string,
|
||||
msgHist: Queue<UserMessage>,
|
||||
stream: boolean
|
||||
): Promise<string> {
|
||||
@@ -34,7 +32,7 @@ export async function embedMessage(
|
||||
|
||||
// create params
|
||||
const params: ChatParams = {
|
||||
model: tokens.model,
|
||||
model: model,
|
||||
ollama: ollama,
|
||||
msgHist: msgHist.getItems()
|
||||
}
|
||||
|
||||
@@ -7,15 +7,13 @@ import { AbortableAsyncIterator } from 'ollama/src/utils.js'
|
||||
/**
|
||||
* Method to send replies as normal text on discord like any other user
|
||||
* @param message message sent by the user
|
||||
* @param tokens tokens to run query
|
||||
* @param model name of model to run query
|
||||
* @param msgHist message history between user and model
|
||||
*/
|
||||
export async function normalMessage(
|
||||
message: Message,
|
||||
ollama: Ollama,
|
||||
tokens: {
|
||||
model: string
|
||||
},
|
||||
model: string,
|
||||
msgHist: Queue<UserMessage>,
|
||||
stream: boolean
|
||||
): Promise<string> {
|
||||
@@ -26,7 +24,7 @@ export async function normalMessage(
|
||||
await message.channel.send('Generating Response . . .').then(async sentMessage => {
|
||||
try {
|
||||
const params: ChatParams = {
|
||||
model: tokens.model,
|
||||
model: model,
|
||||
ollama: ollama,
|
||||
msgHist: msgHist.getItems()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user