Commit c5ca615

Generate response search index when key is present

1 parent 58b96b6

1 file changed: chatterbot/trainers.py
Lines changed: 61 additions & 43 deletions
@@ -1,12 +1,13 @@
 import os
-import sys
 import csv
 import time
 import glob
 import json
 import tarfile
+from typing import List, Union
 from tqdm import tqdm
 from dateutil import parser as date_parser
+from chatterbot.chatterbot import ChatBot
 from chatterbot.conversation import Statement


@@ -20,7 +21,7 @@ class Trainer(object):
     the environment variable if it is set.
     """

-    def __init__(self, chatbot, **kwargs):
+    def __init__(self, chatbot: ChatBot, **kwargs):
         self.chatbot = chatbot

         environment_default = bool(int(os.environ.get('CHATTERBOT_SHOW_TRAINING_PROGRESS', True)))
@@ -30,7 +31,7 @@ def __init__(self, chatbot, **kwargs):
             environment_default
         )

-    def get_preprocessed_statement(self, input_statement):
+    def get_preprocessed_statement(self, input_statement: Statement) -> Statement:
         """
         Preprocess the input statement.
         """
@@ -58,7 +59,7 @@ def __init__(self, message=None):
             )
             super().__init__(message or default)

-    def _generate_export_data(self):
+    def _generate_export_data(self) -> list:
         result = []
         for statement in self.chatbot.storage.filter():
             if statement.in_response_to:
@@ -82,7 +83,7 @@ class ListTrainer(Trainer):
     where the list represents a conversation.
     """

-    def train(self, conversation: list):
+    def train(self, conversation: List[str]):
         """
         Train the chat bot based on the provided list of
         statements that represents a single conversation.
@@ -95,7 +96,6 @@ def train(self, conversation: list):
         # Run the pipeline in bulk to improve performance
         documents = self.chatbot.tagger.as_nlp_pipeline(conversation)

-        # for text in enumerate(conversation):
         for document in tqdm(documents, desc='List Trainer', disable=self.disable_progress):
             statement_search_text = document._.search_index

@@ -123,7 +123,7 @@ class ChatterBotCorpusTrainer(Trainer):
     ChatterBot dialog corpus.
     """

-    def train(self, *corpus_paths):
+    def train(self, *corpus_paths: Union[str, List[str]]):
         from chatterbot.corpus import load_corpus, list_corpus_files

         data_file_paths = []
@@ -178,7 +178,17 @@ class GenericFileTrainer(Trainer):
     or directory of those file types.
     """

-    def __init__(self, chatbot, **kwargs):
+    # NOTE: If the value is an integer, this will be the
+    # column index instead of the key or header
+    DEFAULT_STATEMENT_TO_HEADER_MAPPING = {
+        'text': 'text',
+        'conversation': 'conversation',
+        'created_at': 'created_at',
+        'persona': 'persona',
+        'tags': 'tags'
+    }
+
+    def __init__(self, chatbot: ChatBot, **kwargs):
         """
         data_path: str The path to the data file or directory.
         field_map: dict A dictionary containing the column name to header mapping.
@@ -187,22 +197,12 @@ def __init__(self, chatbot, **kwargs):

         self.file_extension = None

-        # NOTE: If the key is an integer, this be the
-        # column index instead of the key or header
-        DEFAULT_STATEMENT_TO_HEADER_MAPPING = {
-            'text': 'text',
-            'conversation': 'conversation',
-            'created_at': 'created_at',
-            'persona': 'persona',
-            'tags': 'tags'
-        }
-
         self.field_map = kwargs.get(
             'field_map',
-            DEFAULT_STATEMENT_TO_HEADER_MAPPING
+            self.DEFAULT_STATEMENT_TO_HEADER_MAPPING
         )

-    def _get_file_list(self, data_path, limit):
+    def _get_file_list(self, data_path: str, limit: Union[int, None]):
         """
         Get a list of files to read from the data set.
         """
@@ -302,6 +302,20 @@ def train(self, data_path: str, limit=None):
                 f'Current mapping: {self.field_map}'
             )

+        response_to_search_index_mapping = {}
+
+        if 'in_response_to' in self.field_map.keys():
+            # Generate the search_in_response_to value for the in_response_to fields
+            response_documents = self.chatbot.tagger.as_nlp_pipeline([
+                (
+                    row[self.field_map['in_response_to']]
+                ) for row in data if len(row) > 0 and row[self.field_map['in_response_to']] is not None
+            ])
+
+            # (Process the response values the same way as the text values)
+            for document in response_documents:
+                response_to_search_index_mapping[document.text] = document._.search_index
+
         for document, context in documents:
             statement = Statement(
                 text=document.text,
@@ -314,14 +328,19 @@ def train(self, data_path: str, limit=None):
                 statement.created_at = date_parser.parse(context['created_at'])

             statement.search_text = document._.search_index
-            statement.search_in_response_to = previous_statement_search_text

             # Use the in_response_to attribute for the previous statement if
             # one is defined, otherwise use the last statement which was created
             if 'in_response_to' in self.field_map.keys():
                 statement.in_response_to = context.get(self.field_map['in_response_to'], None)
+                statement.search_in_response_to = response_to_search_index_mapping.get(
+                    context.get(self.field_map['in_response_to'], None), ''
+                )
             else:
+                # List-type data such as CSVs with no response specified can use
+                # the previous statement as the in_response_to value
                 statement.in_response_to = previous_statement_text
+                statement.search_in_response_to = previous_statement_search_text

             for preprocessor in self.chatbot.preprocessors:
                 statement = preprocessor(statement)
@@ -345,7 +364,6 @@ def train(self, data_path: str, limit=None):
                 )
             )

-
 class CsvFileTrainer(GenericFileTrainer):
     """
     .. note::
@@ -358,11 +376,11 @@ class CsvFileTrainer(GenericFileTrainer):
     parameter is set to 'tsv'.

     :param str file_extension: The file extension to look for when searching for files (defaults to 'csv').
-    :param str field_map: A dictionary containing the database column name to header mapping.
+    :param dict field_map: A dictionary containing the database column name to header mapping.
         Values can be either the header name (str) or the column index (int).
     """

-    def __init__(self, chatbot, **kwargs):
+    def __init__(self, chatbot: ChatBot, **kwargs):
         super().__init__(chatbot, **kwargs)

         self.file_extension = kwargs.get('file_extension', 'csv')
@@ -376,26 +394,26 @@ class JsonFileTrainer(GenericFileTrainer):
     Allow chatbots to be trained with data from a JSON file or
     directory of JSON files.

-    :param str field_map: A dictionary containing the database column name to header mapping.
+    :param dict field_map: A dictionary containing the database column name to header mapping.
     """

-    def __init__(self, chatbot, **kwargs):
+    DEFAULT_STATEMENT_TO_KEY_MAPPING = {
+        'text': 'text',
+        'conversation': 'conversation',
+        'created_at': 'created_at',
+        'in_response_to': 'in_response_to',
+        'persona': 'persona',
+        'tags': 'tags'
+    }
+
+    def __init__(self, chatbot: ChatBot, **kwargs):
         super().__init__(chatbot, **kwargs)

         self.file_extension = 'json'

-        DEFAULT_STATEMENT_TO_KEY_MAPPING = {
-            'text': 'text',
-            'conversation': 'conversation',
-            'created_at': 'created_at',
-            'in_response_to': 'in_response_to',
-            'persona': 'persona',
-            'tags': 'tags'
-        }
-
         self.field_map = kwargs.get(
             'field_map',
-            DEFAULT_STATEMENT_TO_KEY_MAPPING
+            self.DEFAULT_STATEMENT_TO_KEY_MAPPING
         )

@@ -412,7 +430,7 @@ class UbuntuCorpusTrainer(CsvFileTrainer):
     :param str ubuntu_corpus_data_directory: The directory where the Ubuntu corpus data is already located, or where it should be downloaded and extracted.
     """

-    def __init__(self, chatbot, **kwargs):
+    def __init__(self, chatbot: ChatBot, **kwargs):
         super().__init__(chatbot, **kwargs)
         home_directory = os.path.expanduser('~')

@@ -434,7 +452,7 @@ def __init__(self, chatbot, **kwargs):
             'persona': 1,
         }

-    def is_downloaded(self, file_path):
+    def is_downloaded(self, file_path: str):
         """
         Check if the data file is already downloaded.
         """
@@ -444,7 +462,7 @@ def is_downloaded(self, file_path):

         return False

-    def is_extracted(self, file_path):
+    def is_extracted(self, file_path: str):
         """
         Check if the data file is already extracted.
         """
@@ -454,7 +472,7 @@ def is_extracted(self, file_path):
             return True
         return False

-    def download(self, url, show_status=True):
+    def download(self, url: str, show_status=True):
         """
         Download a file from the given url.
         Show a progress indicator for the download status.
@@ -493,7 +511,7 @@ def download(self, url, show_status=True):
         print('Download location: %s' % file_path)
         return file_path

-    def extract(self, file_path):
+    def extract(self, file_path: str):
         """
         Extract a tar file at the specified file path.
         """
@@ -533,7 +551,7 @@ def safe_extract(tar, path='.', members=None, *, numeric_owner=False):

         return True

-    def _get_file_list(self, data_path, limit):
+    def _get_file_list(self, data_path: str, limit: Union[int, None]):
         """
         Get a list of files to read from the data set.
         """
@@ -564,7 +582,7 @@ def _get_file_list(self, data_path, limit):

         yield file_path

-    def train(self, data_download_url, limit=None):
+    def train(self, data_download_url: str, limit: Union[int, None] = None):
         """
         :param str data_download_url: The URL to download the Ubuntu dialog corpus from.
         :param int limit: The maximum number of files to train from.
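
For illustration, a minimal usage sketch of the new behavior (not part of the commit itself): JsonFileTrainer's default key mapping includes 'in_response_to', so training now runs each response value through the NLP pipeline in bulk and stores the resulting search index on statement.search_in_response_to, rather than reusing the previous statement's search text. The bot name and data path below are hypothetical.

from chatterbot import ChatBot
from chatterbot.trainers import JsonFileTrainer

chatbot = ChatBot('Example Bot')

# 'in_response_to' is in JsonFileTrainer's default key mapping, so the
# response search index generation added in this commit takes effect
trainer = JsonFileTrainer(chatbot)

# Hypothetical path to a directory of JSON conversation files
trainer.train('./conversations/')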

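A second sketch, assuming a directory of TSV files addressed by column position (the path, limit, and column layout are hypothetical, though the indices mirror the UbuntuCorpusTrainer field map above): since the default mappings are now class attributes, they can be inspected without instantiating a trainer, and integer field_map values select columns by index, per the NOTE on DEFAULT_STATEMENT_TO_HEADER_MAPPING.

from chatterbot import ChatBot
from chatterbot.trainers import CsvFileTrainer

# The default header mapping is now reachable at class level
assert CsvFileTrainer.DEFAULT_STATEMENT_TO_HEADER_MAPPING['text'] == 'text'

chatbot = ChatBot('Example Bot')

# Integer values are treated as column indices rather than header names
trainer = CsvFileTrainer(
    chatbot,
    file_extension='tsv',
    field_map={
        'text': 3,
        'created_at': 0,
        'persona': 1,
    }
)

trainer.train('./dialogs/', limit=100)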