-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathchatbot.py
More file actions
67 lines (47 loc) · 1.91 KB
/
chatbot.py
File metadata and controls
67 lines (47 loc) · 1.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# Chatbot with LLMs
#
# Interactive command-line chatbot built on facebook/blenderbot-400M-distill,
# chosen because it has an open-source license and runs relatively fast.
# The model weights are downloaded on the first run and loaded from the local
# Hugging Face cache on subsequent runs.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

MODEL_NAME = "facebook/blenderbot-400M-distill"


def load_model(model_name: str = MODEL_NAME):
    """Load the seq2seq model and its tokenizer.

    Args:
        model_name: Hugging Face model identifier.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer


def generate_reply(model, tokenizer, conversation_history, input_text):
    """Generate one model reply given the chat history and a new user prompt.

    Args:
        model: the loaded seq2seq model.
        tokenizer: the matching tokenizer.
        conversation_history: list of alternating user/bot utterances so far.
        input_text: the latest user prompt.

    Returns:
        The decoded model response with special tokens stripped.
    """
    # Encode the running history as one newline-joined string, paired with
    # the new prompt as the second segment of the encoder input.
    history_string = "\n".join(conversation_history)
    # NOTE(review): BlenderBot has a small context window; truncation=True
    # keeps long conversations from raising once the history outgrows it.
    inputs = tokenizer.encode_plus(
        history_string, input_text, return_tensors="pt", truncation=True
    )
    outputs = model.generate(**inputs)
    return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()


def main():
    """Run the interactive chat loop until EOF or keyboard interrupt."""
    model, tokenizer = load_model()
    # Conversation history: alternating user prompt / bot response strings.
    conversation_history = []
    while True:
        try:
            input_text = input("> ")
        except (EOFError, KeyboardInterrupt):
            # Exit cleanly on Ctrl-D / Ctrl-C instead of crashing.
            break
        response = generate_reply(model, tokenizer, conversation_history, input_text)
        print(response)
        # Remember both sides of the exchange for the next turn.
        conversation_history.append(input_text)
        conversation_history.append(response)


if __name__ == "__main__":
    main()