11import os
2- import sys
32import csv
43import time
54import glob
65import json
76import tarfile
7+ from typing import List , Union
88from tqdm import tqdm
99from dateutil import parser as date_parser
10+ from chatterbot .chatterbot import ChatBot
1011from chatterbot .conversation import Statement
1112
1213
@@ -20,7 +21,7 @@ class Trainer(object):
2021 the environment variable if it is set.
2122 """
2223
23- def __init__ (self , chatbot , ** kwargs ):
24+ def __init__ (self , chatbot : ChatBot , ** kwargs ):
2425 self .chatbot = chatbot
2526
2627 environment_default = bool (int (os .environ .get ('CHATTERBOT_SHOW_TRAINING_PROGRESS' , True )))
@@ -30,7 +31,7 @@ def __init__(self, chatbot, **kwargs):
3031 environment_default
3132 )
3233
33- def get_preprocessed_statement (self , input_statement ) :
34+ def get_preprocessed_statement (self , input_statement : Statement ) -> Statement :
3435 """
3536 Preprocess the input statement.
3637 """
@@ -58,7 +59,7 @@ def __init__(self, message=None):
5859 )
5960 super ().__init__ (message or default )
6061
61- def _generate_export_data (self ):
62+ def _generate_export_data (self ) -> list :
6263 result = []
6364 for statement in self .chatbot .storage .filter ():
6465 if statement .in_response_to :
@@ -82,7 +83,7 @@ class ListTrainer(Trainer):
8283 where the list represents a conversation.
8384 """
8485
85- def train (self , conversation : list ):
86+ def train (self , conversation : List [ str ] ):
8687 """
8788 Train the chat bot based on the provided list of
8889 statements that represents a single conversation.
@@ -95,7 +96,6 @@ def train(self, conversation: list):
9596 # Run the pipeline in bulk to improve performance
9697 documents = self .chatbot .tagger .as_nlp_pipeline (conversation )
9798
98- # for text in enumerate(conversation):
9999 for document in tqdm (documents , desc = 'List Trainer' , disable = self .disable_progress ):
100100 statement_search_text = document ._ .search_index
101101
@@ -123,7 +123,7 @@ class ChatterBotCorpusTrainer(Trainer):
123123 ChatterBot dialog corpus.
124124 """
125125
126- def train (self , * corpus_paths ):
126+ def train (self , * corpus_paths : Union [ str , List [ str ]] ):
127127 from chatterbot .corpus import load_corpus , list_corpus_files
128128
129129 data_file_paths = []
@@ -178,7 +178,17 @@ class GenericFileTrainer(Trainer):
178178 or directory of those file types.
179179 """
180180
181- def __init__ (self , chatbot , ** kwargs ):
181+ # NOTE: If the value is an integer, this will be the
182+ # column index instead of the key or header
183+ DEFAULT_STATEMENT_TO_HEADER_MAPPING = {
184+ 'text' : 'text' ,
185+ 'conversation' : 'conversation' ,
186+ 'created_at' : 'created_at' ,
187+ 'persona' : 'persona' ,
188+ 'tags' : 'tags'
189+ }
190+
191+ def __init__ (self , chatbot : ChatBot , ** kwargs ):
182192 """
183193 data_path: str The path to the data file or directory.
184194 field_map: dict A dictionary containing the column name to header mapping.
@@ -187,22 +197,12 @@ def __init__(self, chatbot, **kwargs):
187197
188198 self .file_extension = None
189199
190- # NOTE: If the key is an integer, this be the
191- # column index instead of the key or header
192- DEFAULT_STATEMENT_TO_HEADER_MAPPING = {
193- 'text' : 'text' ,
194- 'conversation' : 'conversation' ,
195- 'created_at' : 'created_at' ,
196- 'persona' : 'persona' ,
197- 'tags' : 'tags'
198- }
199-
200200 self .field_map = kwargs .get (
201201 'field_map' ,
202- DEFAULT_STATEMENT_TO_HEADER_MAPPING
202+ self . DEFAULT_STATEMENT_TO_HEADER_MAPPING
203203 )
204204
205- def _get_file_list (self , data_path , limit ):
205+ def _get_file_list (self , data_path : str , limit : Union [ int , None ] ):
206206 """
207207 Get a list of files to read from the data set.
208208 """
@@ -302,6 +302,20 @@ def train(self, data_path: str, limit=None):
302302 f'Current mapping: { self .field_map } '
303303 )
304304
305+ response_to_search_index_mapping = {}
306+
307+ if 'in_response_to' in self .field_map .keys ():
308+ # Generate the search_in_response_to value for the in_response_to fields
309+ response_documents = self .chatbot .tagger .as_nlp_pipeline ([
310+ (
311+ row [self .field_map ['in_response_to' ]]
312+ ) for row in data if len (row ) > 0 and row [self .field_map ['in_response_to' ]] is not None
313+ ])
314+
315+ # (Process the response values the same way as the text values)
316+ for document in response_documents :
317+ response_to_search_index_mapping [document .text ] = document ._ .search_index
318+
305319 for document , context in documents :
306320 statement = Statement (
307321 text = document .text ,
@@ -314,14 +328,19 @@ def train(self, data_path: str, limit=None):
314328 statement .created_at = date_parser .parse (context ['created_at' ])
315329
316330 statement .search_text = document ._ .search_index
317- statement .search_in_response_to = previous_statement_search_text
318331
319332 # Use the in_response_to attribute for the previous statement if
320333 # one is defined, otherwise use the last statement which was created
321334 if 'in_response_to' in self .field_map .keys ():
322335 statement .in_response_to = context .get (self .field_map ['in_response_to' ], None )
336+ statement .search_in_response_to = response_to_search_index_mapping .get (
337+ context .get (self .field_map ['in_response_to' ], None ), ''
338+ )
323339 else :
340+ # List-type data such as CSVs with no response specified can use
341+ # the previous statement as the in_response_to value
324342 statement .in_response_to = previous_statement_text
343+ statement .search_in_response_to = previous_statement_search_text
325344
326345 for preprocessor in self .chatbot .preprocessors :
327346 statement = preprocessor (statement )
@@ -345,7 +364,6 @@ def train(self, data_path: str, limit=None):
345364 )
346365 )
347366
348-
349367class CsvFileTrainer (GenericFileTrainer ):
350368 """
351369 .. note::
@@ -358,11 +376,11 @@ class CsvFileTrainer(GenericFileTrainer):
358376 parameter is set to 'tsv'.
359377
360378 :param str file_extension: The file extension to look for when searching for files (defaults to 'csv').
361- :param str field_map: A dictionary containing the database column name to header mapping.
379+ :param dict field_map: A dictionary containing the database column name to header mapping.
362380 Values can be either the header name (str) or the column index (int).
363381 """
364382
365- def __init__ (self , chatbot , ** kwargs ):
383+ def __init__ (self , chatbot : ChatBot , ** kwargs ):
366384 super ().__init__ (chatbot , ** kwargs )
367385
368386 self .file_extension = kwargs .get ('file_extension' , 'csv' )
@@ -376,26 +394,26 @@ class JsonFileTrainer(GenericFileTrainer):
376394 Allow chatbots to be trained with data from a JSON file or
377395 directory of JSON files.
378396
379- :param str field_map: A dictionary containing the database column name to header mapping.
397+ :param dict field_map: A dictionary containing the database column name to header mapping.
380398 """
381399
382- def __init__ (self , chatbot , ** kwargs ):
400+ DEFAULT_STATEMENT_TO_KEY_MAPPING = {
401+ 'text' : 'text' ,
402+ 'conversation' : 'conversation' ,
403+ 'created_at' : 'created_at' ,
404+ 'in_response_to' : 'in_response_to' ,
405+ 'persona' : 'persona' ,
406+ 'tags' : 'tags'
407+ }
408+
409+ def __init__ (self , chatbot : ChatBot , ** kwargs ):
383410 super ().__init__ (chatbot , ** kwargs )
384411
385412 self .file_extension = 'json'
386413
387- DEFAULT_STATEMENT_TO_KEY_MAPPING = {
388- 'text' : 'text' ,
389- 'conversation' : 'conversation' ,
390- 'created_at' : 'created_at' ,
391- 'in_response_to' : 'in_response_to' ,
392- 'persona' : 'persona' ,
393- 'tags' : 'tags'
394- }
395-
396414 self .field_map = kwargs .get (
397415 'field_map' ,
398- DEFAULT_STATEMENT_TO_KEY_MAPPING
416+ self . DEFAULT_STATEMENT_TO_KEY_MAPPING
399417 )
400418
401419
@@ -412,7 +430,7 @@ class UbuntuCorpusTrainer(CsvFileTrainer):
412430 :param str ubuntu_corpus_data_directory: The directory where the Ubuntu corpus data is already located, or where it should be downloaded and extracted.
413431 """
414432
415- def __init__ (self , chatbot , ** kwargs ):
433+ def __init__ (self , chatbot : ChatBot , ** kwargs ):
416434 super ().__init__ (chatbot , ** kwargs )
417435 home_directory = os .path .expanduser ('~' )
418436
@@ -434,7 +452,7 @@ def __init__(self, chatbot, **kwargs):
434452 'persona' : 1 ,
435453 }
436454
437- def is_downloaded (self , file_path ):
455+ def is_downloaded (self , file_path : str ):
438456 """
439457 Check if the data file is already downloaded.
440458 """
@@ -444,7 +462,7 @@ def is_downloaded(self, file_path):
444462
445463 return False
446464
447- def is_extracted (self , file_path ):
465+ def is_extracted (self , file_path : str ):
448466 """
449467 Check if the data file is already extracted.
450468 """
@@ -454,7 +472,7 @@ def is_extracted(self, file_path):
454472 return True
455473 return False
456474
457- def download (self , url , show_status = True ):
475+ def download (self , url : str , show_status = True ):
458476 """
459477 Download a file from the given url.
460478 Show a progress indicator for the download status.
@@ -493,7 +511,7 @@ def download(self, url, show_status=True):
493511 print ('Download location: %s' % file_path )
494512 return file_path
495513
496- def extract (self , file_path ):
514+ def extract (self , file_path : str ):
497515 """
498516 Extract a tar file at the specified file path.
499517 """
@@ -533,7 +551,7 @@ def safe_extract(tar, path='.', members=None, *, numeric_owner=False):
533551
534552 return True
535553
536- def _get_file_list (self , data_path , limit ):
554+ def _get_file_list (self , data_path : str , limit : Union [ int , None ] ):
537555 """
538556 Get a list of files to read from the data set.
539557 """
@@ -564,7 +582,7 @@ def _get_file_list(self, data_path, limit):
564582
565583 yield file_path
566584
567- def train (self , data_download_url , limit = None ):
585+ def train (self , data_download_url : str , limit : Union [ int , None ] = None ):
568586 """
569587 :param str data_download_url: The URL to download the Ubuntu dialog corpus from.
570588 :param int limit: The maximum number of files to train from.