@@ -226,4 +226,114 @@ messages:

 		require.Contains(t, out.String(), reply)
 	})
+
+	t.Run("cli flags override params set in the prompt.yaml file", func(t *testing.T) {
+		// Begin setup:
+		const yamlBody = `
+name: Example Prompt
+description: Example description
+model: openai/example-model
+modelParameters:
+  maxTokens: 300
+  temperature: 0.8
+  topP: 0.9
+messages:
+  - role: system
+    content: System message
+  - role: user
+    content: User message
+`
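+		// Write the YAML to a temp file so the command can load it via --file.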
+		tmp, err := os.CreateTemp(t.TempDir(), "*.prompt.yaml")
+		require.NoError(t, err)
+		_, err = tmp.WriteString(yamlBody)
+		require.NoError(t, err)
+		require.NoError(t, tmp.Close())
+
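+		// Stub the Azure Models client so the test never hits a real endpoint.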
+		client := azuremodels.NewMockClient()
+		modelSummary := &azuremodels.ModelSummary{
+			Name:      "example-model",
+			Publisher: "openai",
+			Task:      "chat-completion",
+		}
+		modelSummary2 := &azuremodels.ModelSummary{
+			Name:      "example-model-4o-mini-plus",
+			Publisher: "openai",
+			Task:      "chat-completion",
+		}
+
+		client.MockListModels = func(ctx context.Context) ([]*azuremodels.ModelSummary, error) {
+			return []*azuremodels.ModelSummary{modelSummary, modelSummary2}, nil
+		}
+
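+		// Capture each outgoing request so the assertions below can inspect
+		// exactly which model parameters the command sent.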
+		var capturedReq azuremodels.ChatCompletionOptions
+		reply := "hello"
+		chatCompletion := azuremodels.ChatCompletion{
+			Choices: []azuremodels.ChatChoice{{
+				Message: &azuremodels.ChatChoiceMessage{
+					Content: util.Ptr(reply),
+					Role:    util.Ptr(string(azuremodels.ChatMessageRoleAssistant)),
+				},
+			}},
+		}
+
+		client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) {
+			capturedReq = opt
+			return &azuremodels.ChatCompletionResponse{
+				Reader: sse.NewMockEventReader([]azuremodels.ChatCompletion{chatCompletion}),
+			}, nil
+		}
+
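+		// Wire the command to the mocked client; all output goes to an in-memory buffer.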
+		out := new(bytes.Buffer)
+		cfg := command.NewConfig(out, out, client, true, 100)
+		runCmd := NewRunCommand(cfg)
+
+		// End setup.
+		// ---
+		// We're now ready to make assertions.
+
+		// Test case 1: with no flags, the model params come from the YAML file.
+		runCmd.SetArgs([]string{
+			"--file", tmp.Name(),
+		})
+
+		_, err = runCmd.ExecuteC()
+		require.NoError(t, err)
+
+		require.Equal(t, "openai/example-model", capturedReq.Model)
+		require.Equal(t, 300, *capturedReq.MaxTokens)
+		require.Equal(t, 0.8, *capturedReq.Temperature)
+		require.Equal(t, 0.9, *capturedReq.TopP)
+
+		require.Equal(t, "System message", *capturedReq.Messages[0].Content)
+		require.Equal(t, "User message", *capturedReq.Messages[1].Content)
+
+		// Test case 2: values from flags override the params from the YAML file.
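+		// Recreate the command so flag state from the first run doesn't carry over.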
+		runCmd = NewRunCommand(cfg)
+		runCmd.SetArgs([]string{
+			"openai/example-model-4o-mini-plus",
+			"--file", tmp.Name(),
+			"--max-tokens", "150",
+			"--temperature", "0.1",
+			"--top-p", "0.3",
+		})
+
+		_, err = runCmd.ExecuteC()
+		require.NoError(t, err)
+
+		require.Equal(t, "openai/example-model-4o-mini-plus", capturedReq.Model)
+		require.Equal(t, 150, *capturedReq.MaxTokens)
+		require.Equal(t, 0.1, *capturedReq.Temperature)
+		require.Equal(t, 0.3, *capturedReq.TopP)
+
+		require.Equal(t, "System message", *capturedReq.Messages[0].Content)
+		require.Equal(t, "User message", *capturedReq.Messages[1].Content)
+	})
 }