-
Notifications
You must be signed in to change notification settings - Fork 1k
Expand file tree
/
Copy pathutils.py
More file actions
80 lines (68 loc) · 2.86 KB
/
utils.py
File metadata and controls
80 lines (68 loc) · 2.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from langchain_nvidia_ai_endpoints import ChatNVIDIA
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
from binary_score_models import GradeAnswer,GradeDocuments,GradeHallucinations
import os
from dotenv import load_dotenv
load_dotenv()
import re
import json
def clean_text(text):
    """Strip <think>...</think> reasoning spans (and any orphan tags) from LLM output.

    LLM responses from reasoning models may embed chain-of-thought inside
    <think> blocks; these must be removed before JSON parsing downstream.
    """
    # Drop complete <think>...</think> spans; non-greedy so multiple blocks
    # are each removed, DOTALL so spans may cross newlines.
    cleaned = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
    # Drop any unmatched leftover tags the regex could not pair up.
    for tag in ('<think>', '</think>'):
        cleaned = cleaned.replace(tag, '')
    return cleaned
class Nodeoutputs:
    """Builds the LangChain runnables used by the graph nodes.

    From a JSON prompts file, assembles: a RAG answer chain (``rag_chain``),
    a question rewriter (``question_rewriter``), and three JSON-emitting
    graders (``retrieval_grader``, ``hallucination_grader``,
    ``answer_grader``) backed by an NVIDIA-hosted chat model.
    """

    def __init__(self, api_key, model, prompts_file):
        """Configure the LLM, load prompts from *prompts_file*, and build all chains.

        Args:
            api_key: NVIDIA API key; also exported to the environment because
                some langchain_nvidia_ai_endpoints code paths read it from there.
            model: Model identifier passed to ChatNVIDIA.
            prompts_file: Path to a JSON file mapping prompt names to templates.
        """
        os.environ["NVIDIA_API_KEY"] = api_key
        self.llm = ChatNVIDIA(api_key=api_key, model=model)
        self.prompts = self.load_prompts(prompts_file)
        self.setup_prompts()

    def load_prompts(self, prompts_file):
        """Load and return the prompt templates dict from a JSON file.

        Raises:
            FileNotFoundError: if *prompts_file* does not exist.
            json.JSONDecodeError: if the file is not valid JSON.
        """
        # Explicit encoding: the platform default (e.g. cp1252 on Windows)
        # can fail on UTF-8 prompt files.
        with open(prompts_file, 'r', encoding='utf-8') as file:
            return json.load(file)

    def _json_grader(self, system_key, human_key):
        """Build a grader chain: prompt | llm | text | strip <think> | parse JSON."""
        grader_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", self.prompts[system_key]),
                ("human", self.prompts[human_key]),
            ]
        )
        # clean_text removes <think> reasoning blocks so JsonOutputParser
        # sees only the JSON payload.
        return grader_prompt | self.llm | StrOutputParser() | clean_text | JsonOutputParser()

    def setup_prompts(self):
        """Assemble all runnable chains from the loaded prompt templates."""
        # QA chain uses a "user" message role (the others use "human").
        self.prompt = ChatPromptTemplate.from_messages(
            [
                ("system", self.prompts["qa_system_prompt"]),
                ("user", self.prompts["qa_user_prompt"]),
            ]
        )
        self.rag_chain = self.prompt | self.llm | StrOutputParser()
        re_write_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", self.prompts["re_write_system"]),
                ("human", self.prompts["re_write_human"]),
            ]
        )
        self.question_rewriter = re_write_prompt | self.llm | StrOutputParser()
        # The three graders share one chain shape; only the prompts differ.
        self.retrieval_grader = self._json_grader("grade_system", "grade_human")
        self.hallucination_grader = self._json_grader("hallucination_system", "hallucination_human")
        self.answer_grader = self._json_grader("answer_system", "answer_human")

    def format_docs(self, docs):
        """Join each document's ``page_content`` into one blank-line-separated string."""
        return "\n\n".join(doc.page_content for doc in docs)
# Usage
# Access the API key from environment variables
# NOTE(review): os.getenv returns None if API_KEY is unset — presumably the
# .env file loaded above provides it; verify before deploying.
api_key = os.getenv('API_KEY')
model = "nvidia/llama-3.3-nemotron-super-49b-v1.5"
# Prompt templates file, resolved relative to the current working directory.
prompts_file = "prompt.json"
# Module-level instantiation: importing this module builds all chains
# immediately (and requires prompt.json and a valid API key to be present).
automation = Nodeoutputs(api_key, model, prompts_file)