3434from roboflow .util .annotations import amend_data_yaml
3535from roboflow .util .general import write_line
3636from roboflow .util .model_processor import process
37- from roboflow .util .versions import get_wrong_dependencies_versions , normalize_yolo_model_type
37+ from roboflow .util .versions import get_model_format , get_wrong_dependencies_versions , normalize_yolo_model_type
3838
3939if TYPE_CHECKING :
4040 import numpy as np
@@ -244,7 +244,7 @@ def download(self, model_format=None, location=None, overwrite: bool = False):
244244
245245 return Dataset (self .name , self .version , model_format , os .path .abspath (location ))
246246
247- def export (self , model_format = None ):
247+ def export (self , model_format = None ) -> bool | None :
248248 """
249249 Ask the Roboflow API to generate a version's dataset in a given format so that it can be downloaded via the `download()` method.
250250
@@ -254,7 +254,7 @@ def export(self, model_format=None):
254254 model_format (str): A format to use for downloading
255255
256256 Returns:
257- True
257+ True if the export was successful, RuntimeError if the export failed
258258
259259 Raises:
260260 RuntimeError: If the Roboflow API returns an error with a helpful JSON body
@@ -316,18 +316,7 @@ def train(self, speed=None, model_type=None, checkpoint=None, plot_in_notebook=F
316316
317317 self .__wait_if_generating ()
318318
319- train_model_format = "yolov5pytorch"
320-
321- if self .type == TYPE_CLASSICATION :
322- train_model_format = "folder"
323-
324- if self .type == TYPE_INSTANCE_SEGMENTATION :
325- train_model_format = "yolov5pytorch"
326-
327- if self .type == TYPE_SEMANTIC_SEGMENTATION :
328- train_model_format = "png-mask-semantic"
329-
330- # if classification
319+ train_model_format = get_model_format (model_type )
331320 if train_model_format not in self .exports :
332321 self .export (train_model_format )
333322
0 commit comments