Skip to content

Commit

Permalink
internal/llm: add ContentGenerator.SetTemperature
Browse files Browse the repository at this point in the history
Add a method to ContentGenerator that allows changing the default
temperature.

Use it in labeleval to see if it reduces variability.

(It does, but it doesn't eliminate it.)

Change-Id: I85dd3ad16af1133627b48fd469ccdf7983f8aeea
Reviewed-on: https://go-review.googlesource.com/c/oscar/+/635284
Reviewed-by: Tatiana Bradley <[email protected]>
LUCI-TryBot-Result: Go LUCI <[email protected]>
  • Loading branch information
jba committed Dec 13, 2024
1 parent 6dd16b5 commit 58a1d42
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 2 deletions.
7 changes: 6 additions & 1 deletion internal/devtools/cmd/labeleval/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,13 @@ Issues:

func newGeminiClient(ctx context.Context, lg *slog.Logger) (*gemini.Client, error) {
sdb := secret.Netrc()
return gemini.NewClient(ctx, lg, sdb, http.DefaultClient,
c, err := gemini.NewClient(ctx, lg, sdb, http.DefaultClient,
gemini.DefaultEmbeddingModel, gemini.DefaultGenerativeModel)
if err != nil {
return nil, err
}
c.SetTemperature(0)
return c, nil
}

func readJSONFile(filename string, p any) error {
Expand Down
17 changes: 16 additions & 1 deletion internal/gcp/gemini/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ type Client struct {
slog *slog.Logger
genai *genai.Client
embeddingModel, generativeModel string
temperature float32 // negative means use default
}

const (
Expand Down Expand Up @@ -89,7 +90,13 @@ func NewClient(ctx context.Context, lg *slog.Logger, sdb secret.DB, hc *http.Cli
return nil, err
}

return &Client{slog: lg, genai: ai, embeddingModel: embeddingModel, generativeModel: generativeModel}, nil
return &Client{
slog: lg,
genai: ai,
embeddingModel: embeddingModel,
generativeModel: generativeModel,
temperature: -1,
}, nil
}

// withKey returns a new http.Client that is the same as hc
Expand Down Expand Up @@ -150,6 +157,11 @@ func (c *Client) Model() string {
return c.generativeModel
}

// SetTemperature sets the temperature of the client's generative model.
func (c *Client) SetTemperature(t float32) {
c.temperature = t
}

// GenerateContent returns the model's response for the prompt parts,
// implementing [llm.ContentGenerator.GenerateContent].
func (c *Client) GenerateContent(ctx context.Context, schema *llm.Schema, promptParts []llm.Part) (string, error) {
Expand Down Expand Up @@ -198,6 +210,9 @@ func (c *Client) model(mimeType string, schema *genai.Schema) *genai.GenerativeM
model.SetCandidateCount(1)
model.ResponseMIMEType = mimeType
model.ResponseSchema = schema
if c.temperature >= 0 {
model.SetTemperature(c.temperature)
}
return model
}

Expand Down
2 changes: 2 additions & 0 deletions internal/llm/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,4 +105,6 @@ type ContentGenerator interface {
// and one or more prompt parts.
// If the JSON schema is nil, GenerateContent outputs a plain text response.
GenerateContent(ctx context.Context, schema *Schema, parts []Part) (string, error)
// SetTemperature changes the temperature of the model.
SetTemperature(float32)
}
6 changes: 6 additions & 0 deletions internal/llm/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ type echo struct{}
// Implements [ContentGenerator.Model].
func (echo) Model() string { return "echo" }

// Implements [ContentGenerator.SetTemperature] as a no-op.
func (echo) SetTemperature(float32) {}

// GenerateContent echoes the prompts.
// If the schema is non-nil, the output is wrapped as a JSON object with a
// single value "prompt", ignoring the actual schema contents (for testing).
Expand Down Expand Up @@ -159,6 +162,9 @@ func (g *generator) Model() string {
return g.model
}

// SetTemperature implements [ContentGenerator.SetTemperature] as a no-op.
func (g *generator) SetTemperature(float32) {}

// GenerateContent implements [ContentGenerator.GenerateContent].
func (g *generator) GenerateContent(ctx context.Context, schema *Schema, promptParts []Part) (string, error) {
if g.generateContent == nil {
Expand Down

0 comments on commit 58a1d42

Please sign in to comment.