Skip to content

Commit 44191d5

Browse files
committed
First step of having multiprocess in worker
1 parent bf1502f commit 44191d5

File tree

8 files changed

+255
-101
lines changed

8 files changed

+255
-101
lines changed

‎.gitignore‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
enron_with_categories
12
.history
23
.mypy_cache
34
docker-stack.yml

‎giskard/client/giskard_client.py‎

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ def __call__(self, r):
7575
class GiskardClient:
7676
def __init__(self, url: str, key: str, hf_token: str = None):
7777
self.host_url = url
78+
self.key = key
79+
self.hf_token = hf_token
7880
base_url = urljoin(url, "/api/v2/")
7981
self._session = sessions.BaseUrlSession(base_url=base_url)
8082
self._session.mount(base_url, ErrorHandlingAdapter())

‎giskard/commands/cli_worker.py‎

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1+
from typing import Optional
2+
13
import asyncio
24
import functools
35
import logging
46
import os
57
import platform
68
import sys
7-
from typing import Optional
89

910
import click
1011
import lockfile
@@ -13,19 +14,19 @@
1314
from lockfile.pidlockfile import PIDLockFile, read_pid_from_pidfile, remove_existing_pidfile
1415
from pydantic import AnyHttpUrl
1516

16-
from giskard.cli_utils import common_options
1717
from giskard.cli_utils import (
18+
common_options,
1819
create_pid_file_path,
20+
follow_file,
21+
get_log_path,
1922
remove_stale_pid_file,
2023
run_daemon,
21-
get_log_path,
2224
tail,
23-
follow_file,
2425
validate_url,
2526
)
2627
from giskard.path_utils import run_dir
2728
from giskard.settings import settings
28-
from giskard.utils.analytics_collector import anonymize, analytics
29+
from giskard.utils.analytics_collector import analytics, anonymize
2930

3031
logger = logging.getLogger(__name__)
3132

@@ -85,7 +86,13 @@ def wrapper(*args, **kwargs):
8586
envvar="GSK_HF_TOKEN",
8687
help="Access token for Giskard hosted in a private Hugging Face Spaces",
8788
)
88-
def start_command(url: AnyHttpUrl, is_server, api_key, is_daemon, hf_token):
89+
@click.option(
90+
"--parallelism",
91+
"nb_workers",
92+
default=None,
93+
help="Number of processes to use for parallelism (None for number of cpu)",
94+
)
95+
def start_command(url: AnyHttpUrl, is_server, api_key, is_daemon, hf_token, nb_workers):
8996
"""\b
9097
Start ML Worker.
9198
@@ -102,7 +109,7 @@ def start_command(url: AnyHttpUrl, is_server, api_key, is_daemon, hf_token):
102109
)
103110
api_key = initialize_api_key(api_key, is_server)
104111
hf_token = initialize_hf_token(hf_token, is_server)
105-
_start_command(is_server, url, api_key, is_daemon, hf_token)
112+
_start_command(is_server, url, api_key, is_daemon, hf_token, nb_workers)
106113

107114

108115
def initialize_api_key(api_key, is_server):
@@ -126,7 +133,7 @@ def initialize_hf_token(hf_token, is_server):
126133
return hf_token
127134

128135

129-
def _start_command(is_server, url: AnyHttpUrl, api_key, is_daemon, hf_token=None):
136+
def _start_command(is_server, url: AnyHttpUrl, api_key, is_daemon, hf_token=None, nb_workers=None):
130137
from giskard.ml_worker.ml_worker import MLWorker
131138

132139
start_msg = "Starting ML Worker"
@@ -154,7 +161,7 @@ def _start_command(is_server, url: AnyHttpUrl, api_key, is_daemon, hf_token=None
154161
run_daemon(is_server, url, api_key, hf_token)
155162
else:
156163
ml_worker = MLWorker(is_server, url, api_key, hf_token)
157-
asyncio.get_event_loop().run_until_complete(ml_worker.start())
164+
asyncio.get_event_loop().run_until_complete(ml_worker.start(nb_workers))
158165
except KeyboardInterrupt:
159166
logger.info("Exiting")
160167
if ml_worker:

‎giskard/ml_worker/ml_worker.py‎

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
1+
from typing import Optional
2+
13
import logging
24
import random
35
import secrets
4-
import stomp
56
import time
7+
8+
import stomp
69
from pydantic import AnyHttpUrl
7-
from websocket._exceptions import WebSocketException, WebSocketBadStatusException
10+
from websocket._exceptions import WebSocketBadStatusException, WebSocketException
811

912
import giskard
13+
from giskard.cli_utils import validate_url
1014
from giskard.client.giskard_client import GiskardClient
1115
from giskard.ml_worker.testing.registry.registry import load_plugins
1216
from giskard.settings import settings
13-
from giskard.cli_utils import validate_url
17+
from giskard.utils import shutdown_pool, start_pool
1418

1519
logger = logging.getLogger(__name__)
1620

@@ -140,9 +144,9 @@ def _connect_websocket_client(self, is_server=False):
140144
def is_remote_worker(self):
141145
return self.ml_worker_id is not INTERNAL_WORKER_ID
142146

143-
async def start(self):
147+
async def start(self, nb_workers: Optional[int] = None):
144148
load_plugins()
145-
149+
start_pool(nb_workers)
146150
if self.ws_conn:
147151
self.ws_stopping = False
148152
self.connect_websocket_client()
@@ -162,3 +166,4 @@ def stop(self):
162166
if self.ws_conn:
163167
self.ws_stopping = True
164168
self.ws_conn.disconnect()
169+
shutdown_pool()

‎giskard/ml_worker/websocket/__init__.py‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ class Catalog(WorkerReply):
7272

7373

7474
class DataRow(BaseModel):
75-
columns: Dict[str, str]
75+
columns: Dict[str, str] = Field(..., repr=False)
76+
7677

7778

7879
class DataFrame(BaseModel):

0 commit comments

Comments
 (0)