Skip to content

Commit fe2d58f

Browse files
authored
Merge pull request #140 from FlowTestAI/support-google-gemini
feat: add support for google gemini
2 parents ae4196a + ad0ee95 commit fe2d58f

File tree

9 files changed

+355
-26
lines changed

9 files changed

+355
-26
lines changed

.changeset/early-colts-approve.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'flowtestai-app': minor
3+
---
4+
5+
Add support for google geminin function calling in generating flow

packages/flowtest-electron/package.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@
3333
"@aws-sdk/client-bedrock-runtime": "^3.583.0",
3434
"@aws-sdk/credential-provider-node": "^3.583.0",
3535
"@aws-sdk/types": "^3.577.0",
36+
"@google/generative-ai": "^0.16.0",
3637
"@langchain/community": "^0.2.19",
38+
"@langchain/google-genai": "^0.0.25",
3739
"@smithy/eventstream-codec": "^3.0.0",
3840
"@smithy/protocol-http": "^4.0.0",
3941
"@smithy/signature-v4": "^3.0.0",

packages/flowtest-electron/src/ai/flowtestai.js

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
const BedrockClaudeGenerate = require('./models/bedrock_claude');
2+
const GeminiGenerate = require('./models/gemini');
23
const OpenAIGenerate = require('./models/openai');
34

45
class FlowtestAI {
@@ -13,6 +14,11 @@ class FlowtestAI {
1314
const bedrock_claude = new BedrockClaudeGenerate(model.apiKey);
1415
const functions = await bedrock_claude.filter_functions(available_functions, user_instruction);
1516
return await bedrock_claude.process_user_instruction(functions, user_instruction);
17+
} else if (model.name === 'GEMINI') {
18+
const available_functions = await this.get_available_functions(collection);
19+
const gemini = new GeminiGenerate(model.apiKey);
20+
const functions = await gemini.filter_functions(available_functions, user_instruction);
21+
return await gemini.process_user_instruction(functions, user_instruction);
1622
} else {
1723
throw Error(`Model ${model.name} not supported`);
1824
}
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
const { GoogleGenerativeAI } = require('@google/generative-ai');
2+
const { GoogleGenerativeAIEmbeddings } = require('@langchain/google-genai');
3+
const { TaskType } = require('@google/generative-ai');
4+
const { MemoryVectorStore } = require('langchain/vectorstores/memory');
5+
6+
class GeminiGenerate {
7+
constructor(apiKey) {
8+
this.genAI = new GoogleGenerativeAI(apiKey);
9+
10+
this.embeddings = new GoogleGenerativeAIEmbeddings({
11+
apiKey,
12+
model: 'text-embedding-004', // 768 dimensions
13+
taskType: TaskType.RETRIEVAL_DOCUMENT,
14+
title: 'Document title',
15+
});
16+
}
17+
18+
async filter_functions(functions, instruction) {
19+
const documents = functions.map((f) => {
20+
const { parameters, ...fDescription } = f.function;
21+
return JSON.stringify(fDescription);
22+
});
23+
24+
const vectorStore = await MemoryVectorStore.fromTexts(documents, [], this.embeddings);
25+
26+
// 128 (max no of functions accepted by openAI function calling)
27+
const retrievedDocuments = await vectorStore.similaritySearch(instruction, 10);
28+
var selectedFunctions = [];
29+
retrievedDocuments.forEach((document) => {
30+
const pDocument = JSON.parse(document.pageContent);
31+
const findF = functions.find(
32+
(f) => f.function.name === pDocument.name && f.function.description === pDocument.description,
33+
);
34+
if (findF) {
35+
selectedFunctions = selectedFunctions.concat(findF);
36+
}
37+
});
38+
39+
return selectedFunctions;
40+
}
41+
42+
async process_user_instruction(functions, instruction) {
43+
//console.log(functions.map((f) => f.function.name));
44+
// Define the function call format
45+
const fn = `{"name": "function_name"}`;
46+
47+
// Prepare the function string for the system prompt
48+
const fnStr = functions.map((f) => JSON.stringify(f)).join('\n');
49+
50+
// Define the system prompt
51+
const systemPrompt = `
52+
You are a helpful assistant with access to the following functions:
53+
54+
${fnStr}
55+
56+
To use these functions respond with, only output function names, ignore arguments needed by those functions:
57+
58+
<multiplefunctions>
59+
<functioncall> ${fn} </functioncall>
60+
<functioncall> ${fn} </functioncall>
61+
...
62+
</multiplefunctions>
63+
64+
Edge cases you must handle:
65+
- If there are multiple functions that can fullfill user request, list them all.
66+
- If there are no functions that match the user request, you will respond politely that you cannot help.
67+
- If the user has not provided all information to execute the function call, choose the best possible set of values. Only, respond with the information requested and nothing else.
68+
- If asked something that cannot be determined with the user's request details, respond that it is not possible to fulfill the request and explain why.
69+
`;
70+
71+
const model = this.genAI.getGenerativeModel({
72+
model: 'gemini-1.5-pro-latest',
73+
systemInstruction: {
74+
role: 'system',
75+
parts: [{ text: systemPrompt }],
76+
},
77+
});
78+
79+
// Prepare the messages for the language model
80+
81+
const request = {
82+
contents: [{ role: 'user', parts: [{ text: instruction }] }],
83+
};
84+
85+
// Invoke the language model and get the completion
86+
const completion = await model.generateContent(request);
87+
88+
const content = completion.response.candidates[0].content.parts[0].text.trim();
89+
90+
// Extract function calls from the completion
91+
const extractedFunctions = this.extractFunctionCalls(content);
92+
93+
return extractedFunctions;
94+
}
95+
96+
extractFunctionCalls(completion) {
97+
let content = typeof completion === 'string' ? completion : completion.content;
98+
99+
// Multiple functions lookup
100+
const mfnPattern = /<multiplefunctions>(.*?)<\/multiplefunctions>/s;
101+
const mfnMatch = content.match(mfnPattern);
102+
103+
// Single function lookup
104+
const singlePattern = /<functioncall>(.*?)<\/functioncall>/s;
105+
const singleMatch = content.match(singlePattern);
106+
107+
let functions = [];
108+
109+
if (!mfnMatch && !singleMatch) {
110+
// No function calls found
111+
return null;
112+
} else if (mfnMatch) {
113+
// Multiple function calls found
114+
const multiplefn = mfnMatch[1];
115+
const fnMatches = [...multiplefn.matchAll(/<functioncall>(.*?)<\/functioncall>/gs)];
116+
for (let fnMatch of fnMatches) {
117+
const fnText = fnMatch[1].replace(/\\/g, '');
118+
try {
119+
functions.push(JSON.parse(fnText));
120+
} catch {
121+
// Ignore invalid JSON
122+
}
123+
}
124+
} else {
125+
// Single function call found
126+
const fnText = singleMatch[1].replace(/\\/g, '');
127+
try {
128+
functions.push(JSON.parse(fnText));
129+
} catch {
130+
// Ignore invalid JSON
131+
}
132+
}
133+
return functions;
134+
}
135+
}
136+
137+
module.exports = GeminiGenerate;

packages/flowtest-electron/tests/utils/flowtest-ai.test.js

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,23 @@ describe('generate', () => {
4444
const nodeNames = result.map((node) => node.name);
4545
expect(nodeNames).toEqual(['addPet', 'getPetById', 'findPetsByStatus']);
4646
}, 60000);
47+
48+
it('should generate functions using gemini', async () => {
49+
const f = new FlowtestAI();
50+
const USER_INSTRUCTION =
51+
'Add a new pet to the store. \
52+
Then get the created pet. \
53+
Then get pet with status as available.';
54+
//const testYaml = fs.readFileSync('tests/test.yaml', { encoding: 'utf8', flag: 'r' });
55+
let api = await SwaggerParser.validate('tests/test.yaml');
56+
console.log('API name: %s, Version: %s', api.info.title, api.info.version);
57+
const resolvedSpec = (await JsonRefs.resolveRefs(api)).resolved;
58+
59+
let result = await f.generate(resolvedSpec, USER_INSTRUCTION, {
60+
name: 'GEMINI',
61+
apiKey: '',
62+
});
63+
const nodeNames = result.map((node) => node.name);
64+
expect(nodeNames).toEqual(['addPet', 'getPetById', 'findPetsByStatus']);
65+
}, 60000);
4766
});

0 commit comments

Comments
 (0)