Skip to content

Commit 3bb8a18

Browse files
committed
Refactor runSingleTestWithContext to simplify message handling; replace template variable replacement with regex and improve role assignment logic
1 parent df5d94b commit 3bb8a18

1 file changed

Lines changed: 26 additions & 33 deletions

File tree

cmd/generate/pipeline.go

Lines changed: 26 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@ package generate
33
import (
44
"encoding/json"
55
"fmt"
6+
"regexp"
67
"strings"
78

89
"github.com/github/gh-models/internal/azuremodels"
9-
"github.com/github/gh-models/pkg/prompt"
1010
"github.com/github/gh-models/pkg/util"
1111
)
1212

@@ -317,72 +317,63 @@ Generate exactly %d diverse test cases:`, nTests,
317317
}
318318

319319
// runSingleTestWithContext runs a single test against a model with context
320-
func (h *generateCommandHandler) runSingleTestWithContext(input, modelName string, context *PromptPexContext) (string, error) {
320+
func (h *generateCommandHandler) runSingleTestWithContext(input string, modelName string, context *PromptPexContext) (string, error) {
321321
// Use the context if provided, otherwise use the stored context
322-
var messages []prompt.Message
323-
if context != nil {
324-
messages = context.Prompt.Messages
325-
} else {
326-
// Fallback to basic sentiment analysis prompt
327-
systemContent := "You are a sentiment analysis expert. Classify the sentiment of the given text."
328-
userContent := "Classify the sentiment of this text as positive, negative, or neutral: {{text}}\n\nRespond with only the sentiment word."
329-
messages = []prompt.Message{
330-
{Role: "system", Content: systemContent},
331-
{Role: "user", Content: userContent},
332-
}
333-
}
322+
messages := context.Prompt.Messages
334323

335324
// Build OpenAI messages from our messages format
336-
var openaiMessages []azuremodels.ChatMessage
337-
for _, msg := range messages {
325+
re := regexp.MustCompile(`\{\{\s*text\s*\}\}`)
326+
openaiMessages := make([]azuremodels.ChatMessage, 0, len(messages))
327+
for i, msg := range messages {
338328
// Replace template variables in content
339-
var content string
340-
if msg.Content != "" {
341-
content = strings.ReplaceAll(msg.Content, "{{text}}", input)
329+
content := msg.Content
330+
if content != "" {
331+
content = re.ReplaceAllString(content, input)
342332
}
343333

344334
// Convert role format
345335
var role azuremodels.ChatMessageRole
346-
if msg.Role == "A" || msg.Role == "assistant" {
336+
switch msg.Role {
337+
case "assistant":
347338
role = azuremodels.ChatMessageRoleAssistant
348-
} else if msg.Role == "system" {
339+
case "system":
349340
role = azuremodels.ChatMessageRoleSystem
350-
} else {
341+
case "user":
351342
role = azuremodels.ChatMessageRoleUser
343+
default:
344+
return "", fmt.Errorf("unknown role: %s", msg.Role)
352345
}
353346

354-
openaiMessages = append(openaiMessages, azuremodels.ChatMessage{
347+
openaiMessages[i] = azuremodels.ChatMessage{
355348
Role: role,
356349
Content: &content,
357-
})
350+
}
358351
}
359352

360353
options := azuremodels.ChatCompletionOptions{
361-
Model: "openai/gpt-4o-mini", // GitHub Models compatible model
354+
Model: modelName,
362355
Messages: openaiMessages,
363356
Temperature: util.Ptr(0.0),
364357
}
365358

366-
response, err := h.client.GetChatCompletionStream(h.ctx, options, h.org)
367-
if err != nil {
368-
return "", err
369-
}
370-
completion, err := response.Reader.Read()
359+
result, err := h.callModelWithRetry("tests", options)
371360
if err != nil {
372-
return "", err
361+
return "", fmt.Errorf("failed to run test input: %w", err)
373362
}
374-
result := *completion.Choices[0].Message.Content
375363

376364
return result, nil
377365
}
378366

379367
// generateGroundtruth generates groundtruth outputs using the specified model
380368
func (h *generateCommandHandler) generateGroundtruth(context *PromptPexContext) error {
369+
h.WriteStartBox("Groundtruth")
370+
381371
groundtruthModel := h.options.Models.Groundtruth
372+
382373
h.cfg.WriteToOut("Groundtruth")
383374

384375
for i := range context.Tests {
385-
test := &context.Tests[i]
376+
test := context.Tests[i]
386377

387378
// Generate groundtruth output
388379
output, err := h.runSingleTestWithContext(test.TestInput, *groundtruthModel, context)
@@ -395,6 +386,8 @@ func (h *generateCommandHandler) generateGroundtruth(context *PromptPexContext)
395386
test.GroundtruthModel = groundtruthModel
396387
}
397388

389+
h.WriteEndBox("")
390+
398391
return nil
399392
}
400393

0 commit comments

Comments (0)