Skip to content

Commit 16482d9

Browse files
garmancschleiden
authored and committed
add --org flag to run and eval
1 parent 8286775 commit 16482d9

8 files changed

Lines changed: 46 additions & 23 deletions

File tree

cmd/eval/eval.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ type EvaluationResult struct {
4848
Details string `json:"details,omitempty"`
4949
}
5050

51+
type Organization struct {
52+
Name string `json:"name"`
53+
}
54+
5155
var FailedTests = errors.New("❌ Some tests failed.")
5256

5357
// NewEvalCommand returns a new command to evaluate prompts against models
@@ -66,7 +70,7 @@ func NewEvalCommand(cfg *command.Config) *cobra.Command {
6670
6771
Example prompt.yml structure:
6872
name: My Evaluation
69-
model: gpt-4o
73+
model: openai/gpt-4o
7074
testData:
7175
- input: "Hello world"
7276
expected: "Hello there"
@@ -94,6 +98,9 @@ func NewEvalCommand(cfg *command.Config) *cobra.Command {
9498
return err
9599
}
96100

101+
// Get the org flag
102+
org, _ := cmd.Flags().GetString("org")
103+
97104
// Load the evaluation prompt file
98105
evalFile, err := loadEvaluationPromptFile(promptFilePath)
99106
if err != nil {
@@ -106,6 +113,7 @@ func NewEvalCommand(cfg *command.Config) *cobra.Command {
106113
client: cfg.Client,
107114
evalFile: evalFile,
108115
jsonOutput: jsonOutput,
116+
org: org,
109117
}
110118

111119
err = handler.runEvaluation(cmd.Context())
@@ -120,6 +128,7 @@ func NewEvalCommand(cfg *command.Config) *cobra.Command {
120128
}
121129

122130
cmd.Flags().Bool("json", false, "Output results in JSON format")
131+
cmd.Flags().String("org", "", "Organization to attribute usage to (omitting will attribute usage to the current actor)")
123132
return cmd
124133
}
125134

@@ -128,6 +137,7 @@ type evalCommandHandler struct {
128137
client azuremodels.Client
129138
evalFile *prompt.File
130139
jsonOutput bool
140+
org string
131141
}
132142

133143
func loadEvaluationPromptFile(filePath string) (*prompt.File, error) {
@@ -321,7 +331,7 @@ func (h *evalCommandHandler) templateString(templateStr string, data map[string]
321331
func (h *evalCommandHandler) callModel(ctx context.Context, messages []azuremodels.ChatMessage) (string, error) {
322332
req := h.evalFile.BuildChatCompletionOptions(messages)
323333

324-
resp, err := h.client.GetChatCompletionStream(ctx, req)
334+
resp, err := h.client.GetChatCompletionStream(ctx, req, h.org)
325335
if err != nil {
326336
return "", err
327337
}
@@ -460,7 +470,7 @@ func (h *evalCommandHandler) runLLMEvaluator(ctx context.Context, name string, e
460470
Stream: false,
461471
}
462472

463-
resp, err := h.client.GetChatCompletionStream(ctx, req)
473+
resp, err := h.client.GetChatCompletionStream(ctx, req, h.org)
464474
if err != nil {
465475
return EvaluationResult{}, fmt.Errorf("failed to call evaluation model: %w", err)
466476
}

cmd/run/run.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
216216
Args: cobra.ArbitraryArgs,
217217
RunE: func(cmd *cobra.Command, args []string) error {
218218
filePath, _ := cmd.Flags().GetString("file")
219+
org, _ := cmd.Flags().GetString("org")
219220
var pf *prompt.File
220221
if filePath != "" {
221222
var err error
@@ -357,7 +358,7 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
357358
//nolint:gocritic,revive // TODO
358359
defer sp.Stop()
359360

360-
reader, err := cmdHandler.getChatCompletionStreamReader(req)
361+
reader, err := cmdHandler.getChatCompletionStreamReader(req, org)
361362
if err != nil {
362363
return err
363364
}
@@ -408,6 +409,7 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
408409
cmd.Flags().String("temperature", "", "Controls randomness in the response, use lower to be more deterministic.")
409410
cmd.Flags().String("top-p", "", "Controls text diversity by selecting the most probable words until a set probability is reached.")
410411
cmd.Flags().String("system-prompt", "", "Prompt the system.")
412+
cmd.Flags().String("org", "", "Organization to attribute usage to (omitting will attribute usage to the current actor)")
411413

412414
return cmd
413415
}
@@ -522,8 +524,8 @@ func validateModelName(modelName string, models []*azuremodels.ModelSummary) (st
522524
return modelName, nil
523525
}
524526

525-
func (h *runCommandHandler) getChatCompletionStreamReader(req azuremodels.ChatCompletionOptions) (sse.Reader[azuremodels.ChatCompletion], error) {
526-
resp, err := h.client.GetChatCompletionStream(h.ctx, req)
527+
func (h *runCommandHandler) getChatCompletionStreamReader(req azuremodels.ChatCompletionOptions, org string) (sse.Reader[azuremodels.ChatCompletion], error) {
528+
resp, err := h.client.GetChatCompletionStream(h.ctx, req, org)
527529
if err != nil {
528530
return nil, err
529531
}

internal/azuremodels/azure_client.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ func NewAzureClient(httpClient *http.Client, authToken string, cfg *AzureClientC
4040
}
4141

4242
// GetChatCompletionStream returns a stream of chat completions using the given options.
43-
func (c *AzureClient) GetChatCompletionStream(ctx context.Context, req ChatCompletionOptions) (*ChatCompletionResponse, error) {
43+
func (c *AzureClient) GetChatCompletionStream(ctx context.Context, req ChatCompletionOptions, org string) (*ChatCompletionResponse, error) {
4444
// Check for o1 models, which don't support streaming
4545
if req.Model == "o1-mini" || req.Model == "o1-preview" || req.Model == "o1" {
4646
req.Stream = false
@@ -55,7 +55,14 @@ func (c *AzureClient) GetChatCompletionStream(ctx context.Context, req ChatCompl
5555

5656
body := bytes.NewReader(bodyBytes)
5757

58-
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.InferenceURL, body)
58+
var inferenceURL string
59+
if org != "" {
60+
inferenceURL = fmt.Sprintf("%s/orgs/%s/%s", c.cfg.InferenceRoot, org, c.cfg.InferencePath)
61+
} else {
62+
inferenceURL = c.cfg.InferenceRoot + "/" + c.cfg.InferencePath
63+
}
64+
65+
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, inferenceURL, body)
5966
if err != nil {
6067
return nil, err
6168
}

internal/azuremodels/azure_client_config.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,25 @@
11
package azuremodels
22

33
const (
4-
defaultInferenceURL = "https://models.github.ai/inference/chat/completions"
4+
defaultInferenceRoot = "https://models.github.ai"
5+
defaultInferencePath = "inference/chat/completions"
56
defaultAzureAiStudioURL = "https://api.catalog.azureml.ms"
67
defaultModelsURL = defaultAzureAiStudioURL + "/asset-gallery/v1.0/models"
78
)
89

910
// AzureClientConfig represents configurable settings for the Azure client.
1011
type AzureClientConfig struct {
11-
InferenceURL string
12+
InferenceRoot string
13+
InferencePath string
1214
AzureAiStudioURL string
1315
ModelsURL string
1416
}
1517

1618
// NewDefaultAzureClientConfig returns a new AzureClientConfig with default values for API URLs.
1719
func NewDefaultAzureClientConfig() *AzureClientConfig {
1820
return &AzureClientConfig{
19-
InferenceURL: defaultInferenceURL,
21+
InferenceRoot: defaultInferenceRoot,
22+
InferencePath: defaultInferencePath,
2023
AzureAiStudioURL: defaultAzureAiStudioURL,
2124
ModelsURL: defaultModelsURL,
2225
}

internal/azuremodels/client.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import "context"
55
// Client represents a client for interacting with an API about models.
66
type Client interface {
77
// GetChatCompletionStream returns a stream of chat completions using the given options.
8-
GetChatCompletionStream(context.Context, ChatCompletionOptions) (*ChatCompletionResponse, error)
8+
GetChatCompletionStream(context.Context, ChatCompletionOptions, string) (*ChatCompletionResponse, error)
99
// GetModelDetails returns the details of the specified model in a particular registry.
1010
GetModelDetails(ctx context.Context, registry, modelName, version string) (*ModelDetails, error)
1111
// ListModels returns a list of available models.

internal/azuremodels/mock_client.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,15 @@ import (
77

88
// MockClient provides a client for interacting with the Azure models API in tests.
99
type MockClient struct {
10-
MockGetChatCompletionStream func(context.Context, ChatCompletionOptions) (*ChatCompletionResponse, error)
10+
MockGetChatCompletionStream func(context.Context, ChatCompletionOptions, string) (*ChatCompletionResponse, error)
1111
MockGetModelDetails func(context.Context, string, string, string) (*ModelDetails, error)
1212
MockListModels func(context.Context) ([]*ModelSummary, error)
1313
}
1414

1515
// NewMockClient returns a new mock client for stubbing out interactions with the models API.
1616
func NewMockClient() *MockClient {
1717
return &MockClient{
18-
MockGetChatCompletionStream: func(context.Context, ChatCompletionOptions) (*ChatCompletionResponse, error) {
18+
MockGetChatCompletionStream: func(context.Context, ChatCompletionOptions, string) (*ChatCompletionResponse, error) {
1919
return nil, errors.New("GetChatCompletionStream not implemented")
2020
},
2121
MockGetModelDetails: func(context.Context, string, string, string) (*ModelDetails, error) {
@@ -28,8 +28,8 @@ func NewMockClient() *MockClient {
2828
}
2929

3030
// GetChatCompletionStream calls the mocked function for getting a stream of chat completions for the given request.
31-
func (c *MockClient) GetChatCompletionStream(ctx context.Context, opt ChatCompletionOptions) (*ChatCompletionResponse, error) {
32-
return c.MockGetChatCompletionStream(ctx, opt)
31+
func (c *MockClient) GetChatCompletionStream(ctx context.Context, opt ChatCompletionOptions, org string) (*ChatCompletionResponse, error) {
32+
return c.MockGetChatCompletionStream(ctx, opt, org)
3333
}
3434

3535
// GetModelDetails calls the mocked function for getting the details of the specified model in a particular registry.

internal/azuremodels/types.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,13 @@ type ChatMessage struct {
2626

2727
// ChatCompletionOptions represents available options for a chat completion request.
2828
type ChatCompletionOptions struct {
29-
MaxTokens *int `json:"max_tokens,omitempty"`
30-
Messages []ChatMessage `json:"messages"`
31-
Model string `json:"model"`
32-
Stream bool `json:"stream,omitempty"`
33-
Temperature *float64 `json:"temperature,omitempty"`
34-
TopP *float64 `json:"top_p,omitempty"`
29+
MaxTokens *int `json:"max_tokens,omitempty"`
30+
Messages []ChatMessage `json:"messages"`
31+
Model string `json:"model"`
32+
Stream bool `json:"stream,omitempty"`
33+
Temperature *float64 `json:"temperature,omitempty"`
34+
TopP *float64 `json:"top_p,omitempty"`
35+
Organization *string `json:"organization,omitempty"`
3536
}
3637

3738
// ChatChoiceMessage is a message from a choice in a chat conversation.

internal/azuremodels/unauthenticated_client.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ func NewUnauthenticatedClient() *UnauthenticatedClient {
1515
}
1616

1717
// GetChatCompletionStream returns an error because this functionality requires authentication.
18-
func (c *UnauthenticatedClient) GetChatCompletionStream(ctx context.Context, opt ChatCompletionOptions) (*ChatCompletionResponse, error) {
18+
func (c *UnauthenticatedClient) GetChatCompletionStream(ctx context.Context, opt ChatCompletionOptions, org string) (*ChatCompletionResponse, error) {
1919
return nil, errors.New("not authenticated")
2020
}
2121

0 commit comments

Comments (0)