Skip to content

Commit

Permalink
refactor: save role in config
Browse files Browse the repository at this point in the history
feat: support multi backend
  • Loading branch information
freesrz93 committed Aug 15, 2024
1 parent 1c3eaf2 commit 398a720
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 119 deletions.
29 changes: 13 additions & 16 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,18 @@ import (
"github.com/sashabaranov/go-openai"
)

var (
Client *C
)

func InitClient() {
cfg := openai.DefaultConfig(Config.APIKey)
cfg.BaseURL = Config.BaseURL
Client = &C{Client: openai.NewClientWithConfig(cfg)}
func NewClient(opt *BackendOption) *Client {
cfg := openai.DefaultConfig(opt.APIKey)
cfg.BaseURL = opt.BaseURL
return &Client{Client: openai.NewClientWithConfig(cfg)}
}

type C struct {
type Client struct {
*BackendOption
*openai.Client
}

func (c *C) Stream(s *Session, input string) error {
func (c *Client) Stream(s *Session, input string) error {
defer Pln()

s.Append(openai.ChatCompletionMessage{
Expand All @@ -33,13 +30,13 @@ func (c *C) Stream(s *Session, input string) error {
// https://platform.openai.com/docs/api-reference/chat
stream, err := c.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
Messages: s.Get(),
Model: Config.Model,
MaxTokens: Config.MaxTokens,
Temperature: Config.Temperature,
TopP: Config.TopP,
Model: c.Model,
MaxTokens: c.MaxTokens,
Temperature: c.Temperature,
TopP: c.TopP,
Stream: true,
PresencePenalty: Config.PresencePenalty,
FrequencyPenalty: Config.FrequencyPenalty,
PresencePenalty: c.PresencePenalty,
FrequencyPenalty: c.FrequencyPenalty,
})
if err != nil {
return err
Expand Down
44 changes: 32 additions & 12 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,16 @@ var (
)

type config struct {
DefaultBackend string `yaml:"default_backend"`
Editor string `yaml:"editor"`
EditorArg string `yaml:"editor_arg"`
Backends map[string]*BackendOption `yaml:"backends"`
Roles map[string]Role `yaml:"roles"`
}

type BackendOption struct {
Description string `yaml:"description"`
DefaultRole string `yaml:"default_role"`
BaseURL string `yaml:"url"`
APIKey string `yaml:"api_key"`
Model string `yaml:"model"`
Expand All @@ -23,8 +33,6 @@ type config struct {
TopP float32 `yaml:"top_p"`
FrequencyPenalty float32 `yaml:"frequency_penalty"`
PresencePenalty float32 `yaml:"presence_penalty"`
Editor string `yaml:"editor"`
EditorArg string `yaml:"editor_arg"`
}

func (c *config) String() string {
Expand All @@ -37,16 +45,28 @@ func (c *config) String() string {

func newDefault() *config {
return &config{
BaseURL: "https://api.openai.com/v1",
APIKey: "",
Model: "gpt-4o-mini",
MaxTokens: 4096,
Temperature: 0.5,
TopP: 1.0,
FrequencyPenalty: 0,
PresencePenalty: 0,
Editor: "code",
EditorArg: "%path",
DefaultBackend: backendOpenai,
Editor: "code",
EditorArg: "%path",
Backends: map[string]*BackendOption{
backendOpenai: {
Description: "",
DefaultRole: defaultRole,
BaseURL: "https://api.openai.com/v1",
APIKey: "",
Model: "gpt-4o-mini",
MaxTokens: 4096,
Temperature: 0.5,
TopP: 1.0,
FrequencyPenalty: 0,
PresencePenalty: 0,
},
},
Roles: map[string]Role{
defaultRole: {
Description: "",
Prompt: defaultPrompt,
}},
}
}

Expand Down
59 changes: 37 additions & 22 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package main

import (
"bufio"
"errors"
"fmt"
"io"
"os"
"strings"
Expand All @@ -21,6 +23,7 @@ var (
showVersion bool
sessionName string
roleName string
backendName string
newRole bool
)

Expand All @@ -46,8 +49,8 @@ func root(cmd *cobra.Command, args []string) {
Pln()
}
case listRoles:
for _, role := range ListRoles() {
P(role.String())
for name, r := range Config.Roles {
P(fmt.Sprintf("name: %s\ndescription: %s\nprompt: %s\n", name, r.Description, r.Prompt))
Pln()
}
case newRole:
Expand All @@ -61,17 +64,11 @@ func root(cmd *cobra.Command, args []string) {
}
}

func handleSession(cmd *cobra.Command, args []string) error {
r, err := GetRole(roleName)
if err != nil {
return err
}

s, err := GetSession(sessionName)
func handleSession(_ *cobra.Command, args []string) error {
c, s, err := initClient()
if err != nil {
return err
}
s.UseRole(r)

if showHistory {
P(s.String())
Expand All @@ -89,19 +86,43 @@ func handleSession(cmd *cobra.Command, args []string) error {
if interactive {
P(AIPrefix)
}
err = Client.Stream(s, input)
err = c.Stream(s, input)
if err != nil {
return err
}
}

if interactive {
interactiveMode(input, s)
interactiveMode(c, s, input)
return nil
}
return nil
}

func initClient() (*Client, *Session, error) {
if backendName == "" {
backendName = Config.DefaultBackend
}
opt, ok := Config.Backends[backendName]
if !ok {
return nil, nil, errors.New("backend not exist")
}
if roleName == "" {
roleName = opt.DefaultRole
}
r, err := GetRole(roleName)
if err != nil {
return nil, nil, err
}
client := NewClient(opt)
s, err := GetSession(sessionName)
if err != nil {
return nil, nil, err
}
s.UseRole(r)
return client, s, nil
}

func createRole() {
scanner := bufio.NewScanner(os.Stdin)
P("Please input role name:\n")
Expand Down Expand Up @@ -129,7 +150,7 @@ func createRole() {
P("Role " + name + " created!")
}

func interactiveMode(input string, s *Session) {
func interactiveMode(client *Client, s *Session, input string) {
scanner := bufio.NewScanner(os.Stdin)
P(UserPrefix)
for scanner.Scan() {
Expand All @@ -138,7 +159,7 @@ func interactiveMode(input string, s *Session) {
break
}
P(AIPrefix)
err := Client.Stream(s, input)
err := client.Stream(s, input)
if err != nil {
PFatal(err)
}
Expand All @@ -155,7 +176,8 @@ func init() {
rootCmd.PersistentFlags().BoolVarP(&interactive, "interactive", "i", false, "Use interactive mode. (default: false)")
rootCmd.PersistentFlags().BoolVarP(&showHistory, "history", "h", false, "Show session history. (default: false)")
rootCmd.PersistentFlags().StringVarP(&sessionName, "session", "s", tempSession, "Create or retrieve a session. If not set, create a temp session that won't be saved.")
rootCmd.PersistentFlags().StringVarP(&roleName, "role", "r", defaultRole, "Specify the role to be used. Only valid for a new or temp session.")
rootCmd.PersistentFlags().StringVarP(&roleName, "role", "r", "", "Specify the role to be used. Only valid for a new or temp session.")
rootCmd.PersistentFlags().StringVarP(&backendName, "backend", "b", "", "Specify the backend to be used.")
rootCmd.PersistentFlags().BoolVarP(&newRole, "new-role", "n", false, "Create a new role.")
rootCmd.PersistentFlags().BoolP("help", "", false, "Show command usage.")
rootCmd.PersistentFlags().BoolVarP(&showVersion, "version", "v", false, "Show app version.")
Expand All @@ -166,14 +188,7 @@ func init() {
if err := os.MkdirAll(SessionDir, os.ModePerm); err != nil {
PFatal(err)
}
if err := os.MkdirAll(RoleDir, os.ModePerm); err != nil {
PFatal(err)
}
LoadCfg()
InitClient()
if err := CreateDefaultRole(); err != nil {
PFatal(err)
}
}

func main() {
Expand Down
72 changes: 9 additions & 63 deletions role.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,11 @@ package main

import (
"errors"
"fmt"
"os"
"path/filepath"

"github.com/sashabaranov/go-openai"
"gopkg.in/yaml.v3"
)

var RoleDir = filepath.Join(CfgDir, roleDir)

func CreateDefaultRole() error {
_, err := GetRole(defaultRole)
if err != nil {
return CreateRole(defaultRole, "", defaultPrompt)
}
return nil
}

type Role struct {
Name string
Description string
Prompt string
}
Expand All @@ -33,61 +18,22 @@ func (r *Role) ToMsg() openai.ChatCompletionMessage {
}
}

func (r *Role) String() string {
return fmt.Sprintf("name: %s\ndescription: %s\nprompt: %s\n", r.Name, r.Description, r.Prompt)
}

func GetRole(name string) (*Role, error) {
path := filepath.Join(RoleDir, safeName(name))
b, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return nil, errors.New("role not exist")
}
return nil, err
}
var r Role
err = yaml.Unmarshal(b, &r)
if err != nil {
return nil, err
r, ok := Config.Roles[name]
if !ok {
return nil, errors.New("role not exist")
}
return &r, nil
}

func CreateRole(name, desc, prompt string) error {
path := filepath.Join(RoleDir, safeName(name))
r := &Role{
Name: name,
_, ok := Config.Roles[name]
if ok {
return errors.New("role already exist")
}
Config.Roles[name] = Role{
Description: desc,
Prompt: prompt,
}
b, err := yaml.Marshal(r)
if err != nil {
return err
}
return os.WriteFile(path, b, 0644)
}

func ListRoles() []Role {
entries, err := os.ReadDir(RoleDir)
if err != nil {
return nil
}
res := make([]Role, 0)
for _, e := range entries {
if e.IsDir() {
continue
}
bytes, err := os.ReadFile(filepath.Join(RoleDir, e.Name()))
if err != nil {
continue
}
var r Role
err = yaml.Unmarshal(bytes, &r)
if err != nil {
continue
}
res = append(res, r)
}
return res
return SaveCfg()
}
8 changes: 2 additions & 6 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ import (
const (
configFile = "config.yaml"
sessionDir = "sessions"
roleDir = "roles"
tempSession = "temp"
defaultRole = "default"
backendOpenai = "openai"
defaultPrompt = "You are a polymath. Your role is to synthesize accurate information from various domains while offering insightful analysis and explanations. When responding, strive for clarity and depth, and encourage further inquiry by providing context and related concepts."

AIPrefix = "Assistant: "
Expand All @@ -41,11 +41,7 @@ func Pln() {
P("\n")
}

func PErr(v any) {
P("error: " + fmt.Sprint(v))
}

func PFatal(v any) {
PErr(v)
P("error: " + fmt.Sprint(v))
os.Exit(1)
}

0 comments on commit 398a720

Please sign in to comment.