@@ -210,7 +210,7 @@ def _get_file_list(self, data_path, limit):
210210 if self .file_extension is None :
211211 raise self .TrainerInitializationException (
212212 'The file_extension attribute must be set before calling train().'
213- )
213+ )
214214
215215 # List all csv or json files in the specified directory
216216 if os .path .isdir (data_path ):
@@ -226,7 +226,7 @@ def _get_file_list(self, data_path, limit):
226226
227227 yield file_path
228228 else :
229- return [ data_path ]
229+ yield data_path
230230
231231 def train (self , data_path : str , limit = None ):
232232 """
@@ -254,7 +254,9 @@ def train(self, data_path: str, limit=None):
254254
255255 statements_to_create = []
256256
257- with open (data_file , 'r' , encoding = 'utf-8' ) as file :
257+ file_abspath = os .path .abspath (data_file )
258+
259+ with open (file_abspath , 'r' , encoding = 'utf-8' ) as file :
258260
259261 if self .file_extension == 'json' :
260262 data = json .load (file )
@@ -281,17 +283,24 @@ def train(self, data_path: str, limit=None):
281283
282284 text_row = self .field_map ['text' ]
283285
284- documents = self .chatbot .tagger .as_nlp_pipeline ([
285- (
286- row [text_row ],
287- {
288- # Include any defined metadata columns
289- key : row [value ]
290- for key , value in self .field_map .items ()
291- if key != text_row
292- }
293- ) for row in data if len (row ) > 0
294- ])
286+ try :
287+ documents = self .chatbot .tagger .as_nlp_pipeline ([
288+ (
289+ row [text_row ],
290+ {
291+ # Include any defined metadata columns
292+ key : row [value ]
293+ for key , value in self .field_map .items ()
294+ if key != text_row
295+ }
296+ ) for row in data if len (row ) > 0
297+ ])
298+ except KeyError as e :
299+ raise KeyError (
300+ f'{ e } . Please check the field_map parameter used to initialize '
301+ f'the training class and remove this value if it is not needed. '
302+ f'Current mapping: { self .field_map } '
303+ )
295304
296305 for document , context in documents :
297306 statement = Statement (
0 commit comments