1+ from typing import Optional
2+
13import asyncio
24import functools
35import logging
46import os
57import platform
68import sys
7- from typing import Optional
89
910import click
1011import lockfile
1314from lockfile .pidlockfile import PIDLockFile , read_pid_from_pidfile , remove_existing_pidfile
1415from pydantic import AnyHttpUrl
1516
16- from giskard .cli_utils import common_options
1717from giskard .cli_utils import (
18+ common_options ,
1819 create_pid_file_path ,
20+ follow_file ,
21+ get_log_path ,
1922 remove_stale_pid_file ,
2023 run_daemon ,
21- get_log_path ,
2224 tail ,
23- follow_file ,
2425 validate_url ,
2526)
2627from giskard .path_utils import run_dir
2728from giskard .settings import settings
28- from giskard .utils .analytics_collector import anonymize , analytics
29+ from giskard .utils .analytics_collector import analytics , anonymize
2930
3031logger = logging .getLogger (__name__ )
3132
@@ -85,7 +86,13 @@ def wrapper(*args, **kwargs):
8586 envvar = "GSK_HF_TOKEN" ,
8687 help = "Access token for Giskard hosted in a private Hugging Face Spaces" ,
8788)
88- def start_command (url : AnyHttpUrl , is_server , api_key , is_daemon , hf_token ):
89+ @click .option (
90+ "--parallelism" ,
91+ "nb_workers" ,
92+ default = None ,
93+ help = "Number of processes to use for parallelism (None for number of cpu)" ,
94+ )
95+ def start_command (url : AnyHttpUrl , is_server , api_key , is_daemon , hf_token , nb_workers ):
8996 """\b
9097 Start ML Worker.
9198
@@ -102,7 +109,7 @@ def start_command(url: AnyHttpUrl, is_server, api_key, is_daemon, hf_token):
102109 )
103110 api_key = initialize_api_key (api_key , is_server )
104111 hf_token = initialize_hf_token (hf_token , is_server )
105- _start_command (is_server , url , api_key , is_daemon , hf_token )
112+ _start_command (is_server , url , api_key , is_daemon , hf_token , nb_workers )
106113
107114
108115def initialize_api_key (api_key , is_server ):
@@ -126,7 +133,7 @@ def initialize_hf_token(hf_token, is_server):
126133 return hf_token
127134
128135
129- def _start_command (is_server , url : AnyHttpUrl , api_key , is_daemon , hf_token = None ):
136+ def _start_command (is_server , url : AnyHttpUrl , api_key , is_daemon , hf_token = None , nb_workers = None ):
130137 from giskard .ml_worker .ml_worker import MLWorker
131138
132139 start_msg = "Starting ML Worker"
@@ -154,7 +161,7 @@ def _start_command(is_server, url: AnyHttpUrl, api_key, is_daemon, hf_token=None
154161 run_daemon (is_server , url , api_key , hf_token )
155162 else :
156163 ml_worker = MLWorker (is_server , url , api_key , hf_token )
157- asyncio .get_event_loop ().run_until_complete (ml_worker .start ())
164+ asyncio .get_event_loop ().run_until_complete (ml_worker .start (nb_workers ))
158165 except KeyboardInterrupt :
159166 logger .info ("Exiting" )
160167 if ml_worker :
0 commit comments