To have the program automatically search Baidu Baike and display the result in the GUI whenever the user enters a question (for example "刘邦") and marks the returned answer as inaccurate, we need to make a few changes to the existing code. The complete code follows, including the modifications to the XihuaChatbotGUI class and the newly added functionality:
import os
import json
import jsonlines
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
import tkinter as tk
from tkinter import filedialog, messagebox, ttk
import logging
from difflib import SequenceMatcher
from datetime import datetime
import requests
from bs4 import BeautifulSoup

# Project root directory
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

# Logging configuration
LOGS_DIR = os.path.join(PROJECT_ROOT, 'logs')
os.makedirs(LOGS_DIR, exist_ok=True)


def setup_logging():
    log_file = os.path.join(LOGS_DIR, datetime.now().strftime('%Y-%m-%d_%H-%M-%S_羲和.txt'))
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler()
        ]
    )


setup_logging()


# Dataset class
class XihuaDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_data(file_path)

    def load_data(self, file_path):
        data = []
        if file_path.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                for i, item in enumerate(reader):
                    try:
                        data.append(item)
                    except jsonlines.jsonlines.InvalidLineError as e:
                        logging.warning(f"跳过无效行 {i + 1}: {e}")
        elif file_path.endswith('.json'):
            with open(file_path, 'r', encoding='utf-8') as f:
                try:
                    data = json.load(f)
                except json.JSONDecodeError as e:
                    logging.warning(f"跳过无效文件 {file_path}: {e}")
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        question = item['question']
        human_answer = item['human_answers'][0]
        chatgpt_answer = item['chatgpt_answers'][0]

        try:
            inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
            human_inputs = self.tokenizer(human_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
            chatgpt_inputs = self.tokenizer(chatgpt_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
        except Exception as e:
            logging.warning(f"跳过无效项 {idx}: {e}")
            # Fall back to the next item if tokenization fails
            return self.__getitem__((idx + 1) % len(self.data))

        return {
            'input_ids': inputs['input_ids'].squeeze(),
            'attention_mask': inputs['attention_mask'].squeeze(),
            'human_input_ids': human_inputs['input_ids'].squeeze(),
            'human_attention_mask': human_inputs['attention_mask'].squeeze(),
            'chatgpt_input_ids': chatgpt_inputs['input_ids'].squeeze(),
            'chatgpt_attention_mask': chatgpt_inputs['attention_mask'].squeeze(),
            'human_answer': human_answer,
            'chatgpt_answer': chatgpt_answer
        }


# Data loader factory
def get_data_loader(file_path, tokenizer, batch_size=8, max_length=128):
    dataset = XihuaDataset(file_path, tokenizer, max_length)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)


# Model definition
class XihuaModel(torch.nn.Module):
    def __init__(self, pretrained_model_name='F:/models/bert-base-chinese'):
        super(XihuaModel, self).__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)
        return logits


# Training function
def train(model, data_loader, optimizer, criterion, device, progress_var=None):
    model.train()
    total_loss = 0.0
    num_batches = len(data_loader)
    for batch_idx, batch in enumerate(data_loader):
        try:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            human_input_ids = batch['human_input_ids'].to(device)
            human_attention_mask = batch['human_attention_mask'].to(device)
            chatgpt_input_ids = batch['chatgpt_input_ids'].to(device)
            chatgpt_attention_mask = batch['chatgpt_attention_mask'].to(device)

            optimizer.zero_grad()
            human_logits = model(human_input_ids, human_attention_mask)
            chatgpt_logits = model(chatgpt_input_ids, chatgpt_attention_mask)

            # Human answers are the positive class, ChatGPT answers the negative class
            human_labels = torch.ones(human_logits.size(0), 1).to(device)
            chatgpt_labels = torch.zeros(chatgpt_logits.size(0), 1).to(device)

            loss = criterion(human_logits, human_labels) + criterion(chatgpt_logits, chatgpt_labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            if progress_var:
                progress_var.set((batch_idx + 1) / num_batches * 100)
        except Exception as e:
            logging.warning(f"跳过无效批次: {e}")
    return total_loss / len(data_loader)


# Main training routine
def main_train(retrain=False):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'使用设备: {device}')

    tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')
    model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(device)

    if retrain:
        model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')
        if os.path.exists(model_path):
            model.load_state_dict(torch.load(model_path, map_location=device))
            logging.info("加载现有模型")
        else:
            logging.info("没有找到现有模型,将使用预训练模型")

    optimizer = optim.Adam(model.parameters(), lr=1e-5)
    criterion = torch.nn.BCEWithLogitsLoss()
    train_data_loader = get_data_loader(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'), tokenizer, batch_size=8, max_length=128)

    num_epochs = 30
    for epoch in range(num_epochs):
        train_loss = train(model, train_data_loader, optimizer, criterion, device)
        logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.10f}')

    torch.save(model.state_dict(), os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'))
    logging.info("模型训练完成并保存")


# Web search helper (Baidu results page)
def search_baidu(query):
    url = f"https://www.baidu.com/s?wd={query}"
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
    }
    response = requests.get(url, headers=headers)
    soup = BeautifulSoup(response.text, 'html.parser')
    results = soup.find_all('div', class_='c-abstract')
    if results:
        return results[0].get_text().strip()
    return "没有找到相关信息"


# Baidu Baike search helper
def search_baidu_baike(query):
    url = f"https://baike.baidu.com/item/{query}"
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
    }
    response = requests.get(url, headers=headers)
    soup = BeautifulSoup(response.text, 'html.parser')
    # The page's meta description carries a short summary of the entry
    meta_description = soup.find('meta', attrs={'name': 'description'})
    if meta_description:
        return meta_description['content']
    return "没有找到相关信息"


# GUI
class XihuaChatbotGUI:
    def __init__(self, root):
        self.root = root
        self.root.title("羲和聊天机器人")
        self.tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(self.device)
        self.load_model()
        self.model.eval()

        # Load the training data so it can be reused when answering
        self.data = self.load_data(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'))

        # Conversation history
        self.history = []

        self.create_widgets()

    def create_widgets(self):
        # Styling
        style = ttk.Style()
        style.theme_use('clam')

        # Top frame
        top_frame = ttk.Frame(self.root)
        top_frame.pack(pady=10)

        self.question_label = ttk.Label(top_frame, text="问题:", font=("Arial", 12))
        self.question_label.grid(row=0, column=0, padx=10)

        self.question_entry = ttk.Entry(top_frame, width=50, font=("Arial", 12))
        self.question_entry.grid(row=0, column=1, padx=10)

        self.answer_button = ttk.Button(top_frame, text="获取回答", command=self.get_answer, style='TButton')
        self.answer_button.grid(row=0, column=2, padx=10)

        # Middle frame
        middle_frame = ttk.Frame(self.root)
        middle_frame.pack(pady=10)

        self.chat_text = tk.Text(middle_frame, height=20, width=100, font=("Arial", 12), wrap='word')
        self.chat_text.grid(row=0, column=0, padx=10, pady=10)
        self.chat_text.tag_configure("user", justify='right', foreground='blue')
        self.chat_text.tag_configure("xihua", justify='left', foreground='green')

        # Bottom frame
        bottom_frame = ttk.Frame(self.root)
        bottom_frame.pack(pady=10)

        self.correct_button = ttk.Button(bottom_frame, text="准确", command=self.mark_correct, style='TButton')
        self.correct_button.grid(row=0, column=0, padx=10)

        self.incorrect_button = ttk.Button(bottom_frame, text="不准确", command=self.mark_incorrect, style='TButton')
        self.incorrect_button.grid(row=0, column=1, padx=10)

        self.train_button = ttk.Button(bottom_frame, text="训练模型", command=self.train_model, style='TButton')
        self.train_button.grid(row=0, column=2, padx=10)

        self.retrain_button = ttk.Button(bottom_frame, text="重新训练模型", command=lambda: self.train_model(retrain=True), style='TButton')
        self.retrain_button.grid(row=0, column=3, padx=10)

        self.progress_var = tk.DoubleVar()
        self.progress_bar = ttk.Progressbar(bottom_frame, variable=self.progress_var, maximum=100, length=200, mode='determinate')
        self.progress_bar.grid(row=1, column=0, columnspan=4, pady=10)

        self.log_text = tk.Text(bottom_frame, height=10, width=70, font=("Arial", 12))
        self.log_text.grid(row=2, column=0, columnspan=4, pady=10)

        self.evaluate_button = ttk.Button(bottom_frame, text="评估模型", command=self.evaluate_model, style='TButton')
        self.evaluate_button.grid(row=3, column=0, padx=10, pady=10)

        self.history_button = ttk.Button(bottom_frame, text="查看历史记录", command=self.view_history, style='TButton')
        self.history_button.grid(row=3, column=1, padx=10, pady=10)

        self.save_history_button = ttk.Button(bottom_frame, text="保存历史记录", command=self.save_history, style='TButton')
        self.save_history_button.grid(row=3, column=2, padx=10, pady=10)

    def get_answer(self):
        question = self.question_entry.get()
        if not question:
            messagebox.showwarning("输入错误", "请输入问题")
            return

        inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=128)
        with torch.no_grad():
            input_ids = inputs['input_ids'].to(self.device)
            attention_mask = inputs['attention_mask'].to(self.device)
            logits = self.model(input_ids, attention_mask)

        # A positive logit is interpreted as a "Xihua" (human-style) answer
        if logits.item() > 0:
            answer_type = "羲和回答"
        else:
            answer_type = "零回答"

        specific_answer = self.get_specific_answer(question, answer_type)

        self.chat_text.insert(tk.END, f"用户: {question}\n", "user")
        self.chat_text.insert(tk.END, f"羲和: {specific_answer}\n", "xihua")

        # Append the exchange to the history
        self.history.append({
            'question': question,
            'answer_type': answer_type,
            'specific_answer': specific_answer,
            'accuracy': None  # not rated yet
        })

    def get_specific_answer(self, question, answer_type):
        # Fuzzy-match the most similar question in the training data
        best_match = None
        best_ratio = 0.0
        for item in self.data:
            ratio = SequenceMatcher(None, question, item['question']).ratio()
            if ratio > best_ratio:
                best_ratio = ratio
                best_match = item

        if best_match:
            if answer_type == "羲和回答":
                return best_match['human_answers'][0]
            else:
                return best_match['chatgpt_answers'][0]
        return "这个我也不清楚,你问问零吧"

    def load_data(self, file_path):
        data = []
        if file_path.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                for i, item in enumerate(reader):
                    try:
                        data.append(item)
                    except jsonlines.jsonlines.InvalidLineError as e:
                        logging.warning(f"跳过无效行 {i + 1}: {e}")
        elif file_path.endswith('.json'):
            with open(file_path, 'r', encoding='utf-8') as f:
                try:
                    data = json.load(f)
                except json.JSONDecodeError as e:
                    logging.warning(f"跳过无效文件 {file_path}: {e}")
        return data

    def load_model(self):
        model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')
        if os.path.exists(model_path):
            self.model.load_state_dict(torch.load(model_path, map_location=self.device))
            logging.info("加载现有模型")
        else:
            logging.info("没有找到现有模型,将使用预训练模型")

    def train_model(self, retrain=False):
        file_path = filedialog.askopenfilename(filetypes=[("JSONL files", "*.jsonl"), ("JSON files", "*.json")])
        if not file_path:
            messagebox.showwarning("文件选择错误", "请选择一个有效的数据文件")
            return

        try:
            dataset = XihuaDataset(file_path, self.tokenizer)
            data_loader = DataLoader(dataset, batch_size=8, shuffle=True)

            # Reload the previously trained weights when retraining
            if retrain:
                self.model.load_state_dict(torch.load(os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'), map_location=self.device))

            self.model.to(self.device)
            self.model.train()

            optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-5)
            criterion = torch.nn.BCEWithLogitsLoss()
            num_epochs = 30
            for epoch in range(num_epochs):
                train_loss = train(self.model, data_loader, optimizer, criterion, self.device, self.progress_var)
                logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.10f}')
                self.log_text.insert(tk.END, f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.10f}\n')
                self.log_text.see(tk.END)

            torch.save(self.model.state_dict(), os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'))
            logging.info("模型训练完成并保存")
            self.log_text.insert(tk.END, "模型训练完成并保存\n")
            self.log_text.see(tk.END)
            messagebox.showinfo("训练完成", "模型训练完成并保存")
        except Exception as e:
            logging.error(f"模型训练失败: {e}")
            self.log_text.insert(tk.END, f"模型训练失败: {e}\n")
            self.log_text.see(tk.END)
            messagebox.showerror("训练失败", f"模型训练失败: {e}")

    def evaluate_model(self):
        # Model evaluation logic could be added here
        messagebox.showinfo("评估结果", "模型评估功能暂未实现")

    def mark_correct(self):
        if self.history:
            self.history[-1]['accuracy'] = True
            messagebox.showinfo("评价成功", "您认为这次回答是准确的")

    def mark_incorrect(self):
        if self.history:
            self.history[-1]['accuracy'] = False
            # Fall back to Baidu Baike when the last answer is marked as inaccurate
            question = self.history[-1]['question']
            baike_answer = self.search_baidu_baike(question)
            self.chat_text.insert(tk.END, f"百度百科结果: {baike_answer}\n", "xihua")
            messagebox.showinfo("评价成功", "您认为这次回答是不准确的")

    def search_baidu_baike(self, query):
        return search_baidu_baike(query)

    def view_history(self):
        history_window = tk.Toplevel(self.root)
        history_window.title("历史记录")

        history_text = tk.Text(history_window, height=20, width=80, font=("Arial", 12))
        history_text.pack(padx=10, pady=10)

        for entry in self.history:
            history_text.insert(tk.END, f"问题: {entry['question']}\n")
            history_text.insert(tk.END, f"回答类型: {entry['answer_type']}\n")
            history_text.insert(tk.END, f"具体回答: {entry['specific_answer']}\n")
            if entry['accuracy'] is None:
                history_text.insert(tk.END, "评价: 未评价\n")
            elif entry['accuracy']:
                history_text.insert(tk.END, "评价: 准确\n")
            else:
                history_text.insert(tk.END, "评价: 不准确\n")
            history_text.insert(tk.END, "-" * 50 + "\n")

    def save_history(self):
        file_path = filedialog.asksaveasfilename(defaultextension=".json", filetypes=[("JSON files", "*.json")])
        if not file_path:
            return
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(self.history, f, ensure_ascii=False, indent=4)
        messagebox.showinfo("保存成功", "历史记录已保存到文件")


# Main entry point
if __name__ == "__main__":
    # Launch the GUI
    root = tk.Tk()
    app = XihuaChatbotGUI(root)
    root.mainloop()
Key changes:
Added a Baidu Baike search function: search_baidu_baike looks up information related to the question on Baidu Baike (a standalone usage sketch follows this list).
Modified the mark_incorrect method: when the user marks an answer as inaccurate, it calls search_baidu_baike to fetch the Baidu Baike result and displays it in the GUI's Text widget.
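For a quick manual check, the new helper can be exercised on its own before relying on it inside the GUI. The sketch below assumes the listing above is saved as main.py (see the file structure below), so importing it only runs the module-level setup such as logging; the file name check_baike.py is made up for illustration, and because the helper scrapes the page's meta description, the exact output depends on Baidu Baike's current layout.

# check_baike.py - quick manual test of the Baike helper (hypothetical file name)
from main import search_baidu_baike

if __name__ == "__main__":
    # Prints the entry's meta-description summary, or "没有找到相关信息" if none is found
    print(search_baidu_baike("刘邦"))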
File structure:
main.py: the main program file, containing all of the code above.
logs/: directory for log files.
models/: directory for the saved model weights.
data/: directory for the training data files (a sketch of the expected record format follows this list).
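For reference, XihuaDataset reads three fields from each record: question, human_answers, and chatgpt_answers, where the two answer fields are lists and only their first element is used. The snippet below is a minimal sketch that writes one record of that shape with illustrative content; the file name make_sample_data.py and the output path data/sample_data.jsonl are made up, while the file the training code actually loads is data/train_data.jsonl.

# make_sample_data.py - writes one record in the shape XihuaDataset expects (hypothetical file name)
import jsonlines

sample = {
    "question": "刘邦是谁?",
    "human_answers": ["刘邦是汉朝的开国皇帝,史称汉高祖。"],
    "chatgpt_answers": ["刘邦是西汉王朝的建立者,公元前202年称帝。"]
}

# Written to a sample file so the real data/train_data.jsonl is not overwritten
with jsonlines.open("data/sample_data.jsonl", mode="w") as writer:
    writer.write(sample)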
Steps to run:
Make sure all required libraries are installed, such as torch, transformers, requests, beautifulsoup4, and jsonlines (for example: pip install torch transformers requests beautifulsoup4 jsonlines).
Place the training data file in the data/ directory.
Run main.py to start the GUI (the model can also be trained without the GUI; see the sketch at the end).
With these changes, when the user enters a question in the GUI and marks the answer as inaccurate, the program automatically searches Baidu Baike for related information and displays it in the GUI.
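One last note: the listing defines main_train but never calls it, so training currently runs only through the GUI buttons. If you prefer to train from the command line instead, a minimal sketch (assuming the listing is saved as main.py and data/train_data.jsonl is in place; the file name train_headless.py is made up) is:

# train_headless.py - optional: run the training loop without the GUI (hypothetical file name)
import os

from main import PROJECT_ROOT, main_train

if __name__ == "__main__":
    # main_train saves its weights to models/xihua_model.pth, so make sure the directory exists
    os.makedirs(os.path.join(PROJECT_ROOT, "models"), exist_ok=True)
    main_train(retrain=False)  # trains for 30 epochs on data/train_data.jsonl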