Skip to content

Commit fd53c79

Browse files
authored
Merge pull request #113 from FlowTestAI/anthropic-model
feat: added backend support for anthropic claude via aws bedrock
2 parents 81590ad + 92a0bab commit fd53c79

15 files changed

Lines changed: 3163 additions & 1591 deletions

File tree

.changeset/lovely-beers-promise.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'flowtestai': minor
3+
---
4+
5+
Add support for Anthropic Claude hosted on AWS Bedrock

packages/flowtest-electron/package.json

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,15 @@
2727
},
2828
"dependencies": {
2929
"@apidevtools/swagger-parser": "^10.1.0",
30+
"@aws-crypto/sha256-js": "^5.2.0",
31+
"@aws-sdk/client-bedrock": "^3.583.0",
32+
"@aws-sdk/client-bedrock-runtime": "^3.583.0",
33+
"@aws-sdk/credential-provider-node": "^3.583.0",
34+
"@aws-sdk/types": "^3.577.0",
35+
"@smithy/eventstream-codec": "^3.0.0",
36+
"@smithy/protocol-http": "^4.0.0",
37+
"@smithy/signature-v4": "^3.0.0",
38+
"@smithy/util-utf8": "^3.0.0",
3039
"axios": "^1.6.7",
3140
"chokidar": "^3.6.0",
3241
"dotenv": "^16.4.5",
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
const BedrockClaudeGenerate = require('./models/bedrock_claude');
2+
const OpenAIGenerate = require('./models/openai');
3+
4+
/**
 * Orchestrates AI-assisted flow generation: converts an OpenAPI collection
 * into LLM tool ("function calling") definitions and dispatches the user's
 * instruction to the selected model backend (OpenAI or Claude on Bedrock).
 */
class FlowtestAI {
  /**
   * Generate function-call selections for a natural-language instruction.
   * @param {Object} collection - Parsed OpenAPI document (must contain `paths`).
   * @param {string} user_instruction - The user's natural-language request.
   * @param {Object} model - Backend descriptor: `{ name: 'OPENAI' | 'BEDROCK_CLAUDE', apiKey }`.
   * @returns {Promise<Array|null>} Function calls chosen by the backend.
   * @throws {Error} If `model.name` is not a supported backend.
   */
  async generate(collection, user_instruction, model) {
    // Reject unsupported backends up front so we don't convert the whole
    // collection for nothing.
    if (model.name !== 'OPENAI' && model.name !== 'BEDROCK_CLAUDE') {
      throw new Error(`Model ${model.name} not supported`);
    }

    const available_functions = await this.get_available_functions(collection);

    if (model.name === 'OPENAI') {
      const openai = new OpenAIGenerate();
      const functions = await openai.filter_functions(available_functions, user_instruction, model.apiKey);
      return await openai.process_user_instruction(functions, user_instruction, model.apiKey);
    }

    // model.name === 'BEDROCK_CLAUDE'
    const bedrock_claude = new BedrockClaudeGenerate(model.apiKey);
    const functions = await bedrock_claude.filter_functions(available_functions, user_instruction);
    return await bedrock_claude.process_user_instruction(functions, user_instruction);
  }

  /**
   * Convert every operation in the OpenAPI document into an OpenAI-style
   * tool definition: `{ type: 'function', function: { name, description, parameters } }`.
   * The JSON request-body schema (if any) is exposed under
   * `parameters.properties.requestBody`, and operation parameters under
   * `parameters.properties.parameters`.
   * @param {Object} collection - Parsed OpenAPI document with a `paths` map.
   * @returns {Promise<Array<Object>>} One tool definition per operation.
   */
  async get_available_functions(collection) {
    const functions = [];

    for (const methods of Object.values(collection['paths'])) {
      for (const spec of Object.values(methods)) {
        const function_name = spec['operationId'];
        const desc = spec['description'] || spec['summary'] || '';

        const schema = { type: 'object', properties: {} };

        // JSON request-body schema, when the operation declares one.
        const req_body = spec['requestBody']?.['content']?.['application/json']?.['schema'];
        if (req_body != null) {
          schema['properties']['requestBody'] = req_body;
        }

        // Collect declared parameters that carry a schema.
        const params = spec['parameters'] ?? [];
        if (params.length > 0) {
          const param_properties = {};
          for (const param of params) {
            if (param['schema']) {
              param_properties[param['name']] = param['schema'];
            }
          }
          schema['properties']['parameters'] = {
            type: 'object',
            properties: param_properties,
          };
        }

        const f = {
          type: 'function',
          function: { name: function_name, description: desc, parameters: schema },
        };

        // A self-referencing schema cannot be serialized for the LLM;
        // fall back to an empty parameter schema in that case.
        if (this.isCyclic(f)) {
          functions.push({
            type: 'function',
            function: { name: function_name, description: desc, parameters: {} },
          });
        } else {
          functions.push(f);
        }
      }
    }

    return functions;
  }

  /**
   * Depth-first check for circular references in a plain object graph.
   * Shared (diamond-shaped, non-cyclic) references are allowed; only objects
   * that appear on their own ancestor path are flagged.
   * @param {*} obj - Value to inspect.
   * @returns {boolean} True if a cycle is reachable from `obj`.
   */
  isCyclic(obj) {
    const onPath = new Set(); // objects on the current DFS path
    let detected = false;

    const detect = (node) => {
      // Primitives and null can never participate in a cycle.
      if (node === null || typeof node !== 'object') {
        return;
      }
      if (onPath.has(node)) {
        // Node is already an ancestor of itself: cycle found.
        detected = true;
        return;
      }
      onPath.add(node);
      for (const key in node) {
        if (Object.prototype.hasOwnProperty.call(node, key)) {
          detect(node[key]);
        }
      }
      onPath.delete(node);
    };

    detect(obj);
    return detected;
  }
}
121+
122+
module.exports = FlowtestAI;
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
const { BedrockChat } = require('@langchain/community/chat_models/bedrock');
2+
const { HumanMessage, SystemMessage, BaseMessage } = require('@langchain/core/messages');
3+
const { BedrockEmbeddings } = require('@langchain/community/embeddings/bedrock');
4+
const { MemoryVectorStore } = require('langchain/vectorstores/memory');
5+
6+
/**
 * Claude-on-AWS-Bedrock backend for FlowtestAI.
 *
 * Uses Bedrock embeddings + an in-memory vector store to shortlist candidate
 * functions, then prompts Claude to pick the function calls that satisfy the
 * user's instruction, encoded as <functioncall> JSON snippets.
 */
class BedrockClaudeGenerate {
  /**
   * @param {Object} creds - AWS credentials passed to both Bedrock clients.
   * @param {Object} [options] - Optional overrides; all default to the
   *   previously hard-coded values, so existing callers are unaffected.
   * @param {string} [options.region='us-west-2'] - AWS region for both clients.
   * @param {string} [options.model='anthropic.claude-3-sonnet-20240229-v1:0'] - Bedrock chat model id.
   * @param {string} [options.embeddingModel='amazon.titan-embed-text-v2:0'] - Bedrock embeddings model id.
   */
  constructor(
    creds,
    {
      region = 'us-west-2',
      model = 'anthropic.claude-3-sonnet-20240229-v1:0',
      embeddingModel = 'amazon.titan-embed-text-v2:0',
    } = {},
  ) {
    this.model = new BedrockChat({
      model,
      region,
      // endpointUrl: "custom.amazonaws.com",
      credentials: creds,
      modelKwargs: {
        anthropic_version: 'bedrock-2023-05-31',
      },
    });

    this.embeddings = new BedrockEmbeddings({
      region,
      credentials: creds,
      model: embeddingModel,
    });
  }

  /**
   * Shortlist the functions most relevant to the instruction via embedding
   * similarity search (parameters are stripped before embedding to keep the
   * documents small).
   * @param {Array<Object>} functions - Tool definitions from FlowtestAI.
   * @param {string} instruction - The user's natural-language request.
   * @returns {Promise<Array<Object>>} Matching subset of `functions`.
   */
  async filter_functions(functions, instruction) {
    const documents = functions.map((f) => {
      const { parameters, ...fDescription } = f.function;
      return JSON.stringify(fDescription);
    });

    const vectorStore = await MemoryVectorStore.fromTexts(documents, [], this.embeddings);
    // 128 (max no of functions accepted by openAI function calling)
    const retrievedDocuments = await vectorStore.similaritySearch(instruction, 10);
    const selectedFunctions = [];
    retrievedDocuments.forEach((document) => {
      const pDocument = JSON.parse(document.pageContent);
      const findF = functions.find(
        (f) => f.function.name === pDocument.name && f.function.description === pDocument.description,
      );
      if (findF) {
        selectedFunctions.push(findF);
      }
    });
    return selectedFunctions;
  }

  /**
   * Ask Claude which of the shortlisted functions fulfill the instruction.
   * @param {Array<Object>} functions - Shortlisted tool definitions.
   * @param {string} instruction - The user's natural-language request.
   * @returns {Promise<Array<Object>|null>} Parsed `{name}` objects, or null
   *   when the model emitted no <functioncall> blocks.
   */
  async process_user_instruction(functions, instruction) {
    // Define the function call format shown to the model.
    const fn = `{"name": "function_name"}`;

    // Prepare the function string for the system prompt.
    const fnStr = functions.map((f) => JSON.stringify(f)).join('\n');

    // Define the system prompt.
    const systemPrompt = `
You are a helpful assistant with access to the following functions:

${fnStr}

To use these functions respond with, only output function names, ignore arguments needed by those functions:

<multiplefunctions>
    <functioncall> ${fn} </functioncall>
    <functioncall> ${fn} </functioncall>
    ...
</multiplefunctions>

Edge cases you must handle:
- If there are multiple functions that can fulfill user request, list them all.
- If there are no functions that match the user request, you will respond politely that you cannot help.
- If the user has not provided all information to execute the function call, choose the best possible set of values. Only, respond with the information requested and nothing else.
- If asked something that cannot be determined with the user's request details, respond that it is not possible to fulfill the request and explain why.
`;

    const messages = [new SystemMessage({ content: systemPrompt }), new HumanMessage({ content: instruction })];

    // Invoke the language model and extract the function calls it chose.
    const completion = await this.model.invoke(messages);
    const content = completion.content.trim();

    return this.extractFunctionCalls(content);
  }

  /**
   * Parse <functioncall> JSON snippets out of a model completion.
   * @param {string|Object} completion - Raw completion text, or an object
   *   with a `.content` string.
   * @returns {Array<Object>|null} Parsed call objects; null when the
   *   completion contains no function-call markup. Snippets that fail to
   *   parse as JSON are skipped.
   */
  extractFunctionCalls(completion) {
    const content = typeof completion === 'string' ? completion : completion.content;

    // Multiple functions lookup.
    const mfnPattern = /<multiplefunctions>(.*?)<\/multiplefunctions>/s;
    const mfnMatch = content.match(mfnPattern);

    // Single function lookup.
    const singlePattern = /<functioncall>(.*?)<\/functioncall>/s;
    const singleMatch = content.match(singlePattern);

    const functions = [];

    if (!mfnMatch && !singleMatch) {
      // No function calls found.
      return null;
    } else if (mfnMatch) {
      // Multiple function calls found.
      const multiplefn = mfnMatch[1];
      const fnMatches = [...multiplefn.matchAll(/<functioncall>(.*?)<\/functioncall>/gs)];
      for (const fnMatch of fnMatches) {
        // Strip escaping the model sometimes adds around quotes.
        const fnText = fnMatch[1].replace(/\\/g, '');
        try {
          functions.push(JSON.parse(fnText));
        } catch {
          // Ignore invalid JSON.
        }
      }
    } else {
      // Single function call found.
      const fnText = singleMatch[1].replace(/\\/g, '');
      try {
        functions.push(JSON.parse(fnText));
      } catch {
        // Ignore invalid JSON.
      }
    }
    return functions;
  }
}
131+
132+
module.exports = BedrockClaudeGenerate;
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
const OpenAI = require('openai');
2+
const { MemoryVectorStore } = require('langchain/vectorstores/memory');
3+
const { OpenAIEmbeddings } = require('@langchain/openai');
4+
5+
// System prompt steering the model toward tool calls.
const SYSTEM_MESSAGE = `You are a helpful assistant. \
Respond to the following prompt by using function_call and then summarize actions. \
If a user request is ambiguous, choose the best response possible.`;

// Maximum number of function calls allowed to prevent infinite or lengthy loops
const MAX_CALLS = 10;

/**
 * OpenAI backend for FlowtestAI.
 *
 * Shortlists candidate functions with embedding similarity search, then
 * drives GPT-4 tool calling in a loop, simulating a successful result for
 * each tool call so the model can chain further calls.
 */
class OpenAIGenerate {
  /**
   * Shortlist the functions most relevant to the instruction via embedding
   * similarity search (parameters are stripped before embedding to keep the
   * documents small).
   * @param {Array<Object>} functions - Tool definitions from FlowtestAI.
   * @param {string} instruction - The user's natural-language request.
   * @param {string} apiKey - OpenAI API key.
   * @returns {Promise<Array<Object>>} Matching subset of `functions`.
   */
  async filter_functions(functions, instruction, apiKey) {
    const documents = functions.map((f) => {
      const { parameters, ...fDescription } = f.function;
      return JSON.stringify(fDescription);
    });

    const vectorStore = await MemoryVectorStore.fromTexts(
      documents,
      [],
      new OpenAIEmbeddings({
        openAIApiKey: apiKey,
      }),
    );

    // 128 (max no of functions accepted by openAI function calling)
    const retrievedDocuments = await vectorStore.similaritySearch(instruction, 10);
    const selectedFunctions = [];
    retrievedDocuments.forEach((document) => {
      const pDocument = JSON.parse(document.pageContent);
      const findF = functions.find(
        (f) => f.function.name === pDocument.name && f.function.description === pDocument.description,
      );
      if (findF) {
        selectedFunctions.push(findF);
      }
    });

    return selectedFunctions;
  }

  /**
   * One chat-completions round trip with tool calling enabled.
   * @param {Array<Object>} functions - Tool definitions to expose.
   * @param {Array<Object>} messages - Conversation so far.
   * @param {string} apiKey - OpenAI API key.
   * @returns {Promise<Object>} Raw chat completion response.
   */
  async get_openai_response(functions, messages, apiKey) {
    const openai = new OpenAI({
      apiKey,
    });

    return await openai.chat.completions.create({
      model: 'gpt-4', //gpt-3.5-turbo-16k-0613
      tools: functions,
      tool_choice: 'auto', // "auto" means the model can pick between generating a message or calling a function.
      temperature: 0,
      messages: messages,
    });
  }

  /**
   * Loop with the model, collecting tool calls until it produces a plain
   * message or MAX_CALLS is reached. Each tool call is acknowledged with a
   * simulated "success" so the model can chain further calls.
   * @param {Array<Object>} functions - Shortlisted tool definitions.
   * @param {string} instruction - The user's natural-language request.
   * @param {string} apiKey - OpenAI API key.
   * @returns {Promise<Array<Object>>} The collected `tool_call.function` objects.
   */
  async process_user_instruction(functions, instruction, apiKey) {
    const result = [];
    let num_calls = 0;
    const messages = [
      { content: SYSTEM_MESSAGE, role: 'system' },
      { content: instruction, role: 'user' },
    ];

    while (num_calls < MAX_CALLS) {
      const response = await this.get_openai_response(functions, messages, apiKey);
      const message = response.choices[0].message;

      if (!message.tool_calls) {
        // Model chose to answer with plain text: we're done.
        console.log('Message: ');
        console.log(message.content);
        break;
      }

      messages.push(message);
      for (const tool_call of message.tool_calls) {
        console.log('Function call #: ', num_calls + 1);
        console.log(JSON.stringify(tool_call));

        // We'll simply add a message to simulate successful function call.
        messages.push({
          role: 'tool',
          content: 'success',
          tool_call_id: tool_call.id,
        });
        result.push(tool_call.function);

        num_calls += 1;
      }
    }

    if (num_calls >= MAX_CALLS) {
      console.log('Reached max chained function calls: ', MAX_CALLS);
    }

    return result;
  }
}
100+
101+
module.exports = OpenAIGenerate;

packages/flowtest-electron/src/ipc/collection.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ const createFile = require('../utils/filemanager/createfile');
1414
const updateFile = require('../utils/filemanager/updatefile');
1515
const deleteFile = require('../utils/filemanager/deletefile');
1616
const readFile = require('../utils/filemanager/readfile');
17-
const FlowtestAI = require('../utils/flowtestai');
17+
const FlowtestAI = require('../ai/flowtestai');
1818
const { stringify, parse } = require('flatted');
1919
const { deserialize, serialize } = require('../utils/flowparser/parser');
2020

0 commit comments

Comments
 (0)