Coverage for /home/ubuntu/hidebound/python/hidebound/core/connection.py: 100%
69 statements
« prev ^ index » next coverage.py v7.5.4, created at 2024-07-05 23:50 +0000
1from typing import Any # noqa F401
3from schematics import Model
4from schematics.types import (
5 BaseType, BooleanType, IntType, ListType, ModelType, StringType, URLType
6)
7import dask
8import dask_gateway as dgw
9import dask.distributed as ddist
11import hidebound.core.validators as vd
12# ------------------------------------------------------------------------------
class DaskConnectionConfig(Model):
    r'''
    A class for validating DaskConnection configurations.

    Attributes:
        cluster_type (str, optional): Dask cluster type. Options include:
            local, gateway. Default: local.
        num_partitions (int, optional): Number of partitions each DataFrame is
            to be split into. Default: 1.
        local_num_workers (int, optional): Number of workers to run on local
            cluster. Default: 1.
        local_threads_per_worker (int, optional): Number of threads to run per
            worker local cluster. Default: 1.
        local_multiprocessing (bool, optional): Whether to use multiprocessing
            for local cluster. Default: True.
        gateway_address (str, optional): Dask Gateway server address. Default:
            'http://proxy-public/services/dask-gateway'.
        gateway_proxy_address (str, optional): Dask Gateway scheduler proxy
            server address, e.g.
            'gateway://traefik-daskhub-dask-gateway.core:80'. Default: None.
        gateway_public_address (str, optional): The address to the gateway
            server, as accessible from a web browser, e.g.
            'https://dask-gateway/services/dask-gateway/'. Default: None.
        gateway_auth_type (str, optional): Dask Gateway authentication type.
            Options include: basic, jupyterhub. Default: basic.
        gateway_api_token (str, optional): Authentication API token.
            Default: None.
        gateway_api_user (str, optional): Basic authentication user name.
            Default: None.
        gateway_cluster_options (list, optional): Dask Gateway cluster options.
            Each item has field, label, default, options and option_type
            keys. Default: [].
        gateway_min_workers (int, optional): Minimum number of Dask Gateway
            workers. Default: 1.
        gateway_max_workers (int, optional): Maximum number of Dask Gateway
            workers. Must be greater than or equal to gateway_min_workers.
            Default: 8.
        gateway_shutdown_on_close (bool, optional): Whether to shutdown cluster
            upon close. Default: True.
        gateway_timeout (int, optional): Dask Gateway connection timeout in
            seconds. Default: 30.
    '''
    cluster_type = StringType(
        required=True,
        default='local',
        validators=[lambda x: vd.is_in(x, ['local', 'gateway'])]
    )  # type: StringType
    num_partitions = IntType(
        required=True, default=1, validators=[lambda x: vd.is_gte(x, 1)]
    )  # type: IntType
    local_num_workers = IntType(
        required=True, default=1, validators=[lambda x: vd.is_gte(x, 1)]
    )  # type: IntType
    local_threads_per_worker = IntType(
        required=True, default=1, validators=[lambda x: vd.is_gte(x, 1)]
    )  # type: IntType
    local_multiprocessing = BooleanType(
        required=True, default=True
    )  # type: BooleanType
    gateway_address = URLType(
        required=True,
        fqdn=False,
        default='http://proxy-public/services/dask-gateway',
    )  # type: URLType
    # no defaults: these serialize as None unless explicitly configured
    gateway_proxy_address = StringType(serialize_when_none=True)  # type: StringType
    gateway_public_address = URLType(serialize_when_none=True, fqdn=False)  # type: URLType
    gateway_auth_type = StringType(
        required=True,
        default='basic',
        validators=[lambda x: vd.is_in(x, ['basic', 'jupyterhub'])]
    )  # type: StringType
    gateway_api_token = StringType()  # type: StringType
    gateway_api_user = StringType()  # type: StringType
    gateway_min_workers = IntType(
        required=True, default=1, validators=[lambda x: vd.is_gte(x, 1)]
    )  # type: IntType
    gateway_max_workers = IntType(
        required=True, default=8, validators=[lambda x: vd.is_gte(x, 1)]
    )  # type: IntType
    gateway_shutdown_on_close = BooleanType(
        required=True, default=True
    )  # type: BooleanType
    gateway_timeout = IntType(
        required=True, default=30, validators=[lambda x: vd.is_gte(x, 1)]
    )  # type: IntType

    class ClusterOption(Model):
        '''
        A single Dask Gateway cluster option specification.
        '''
        field = StringType(required=True)  # type: StringType
        label = StringType(required=True)  # type: StringType
        default = BaseType(required=True)  # type: BaseType
        # callable default so each instance gets a fresh list
        options = ListType(BaseType, required=True, default=list)  # type: ListType
        option_type = StringType(
            required=True, validators=[vd.is_cluster_option_type]
        )  # type: StringType

    # callable default so each instance gets a fresh list
    gateway_cluster_options = ListType(
        ModelType(ClusterOption), required=False, default=list
    )  # type: ListType

    def validate_gateway_max_workers(self, data, value):
        # type: (dict, Any) -> Any
        '''
        Model-level validator ensuring gateway_max_workers is not less than
        gateway_min_workers, since both are handed to cluster.adapt as its
        minimum and maximum.

        Args:
            data (dict): Model data being validated.
            value (int): gateway_max_workers value.

        Raises:
            ValidationError: If gateway_max_workers < gateway_min_workers.

        Returns:
            int: Unchanged gateway_max_workers value.
        '''
        from schematics.exceptions import ValidationError

        min_workers = data.get('gateway_min_workers')
        # only compare when both values survived field-level validation
        if isinstance(min_workers, int) and isinstance(value, int):
            if value < min_workers:
                msg = 'gateway_max_workers ({}) must be greater than or '
                msg += 'equal to gateway_min_workers ({}).'
                raise ValidationError(msg.format(value, min_workers))
        return value
108# ------------------------------------------------------------------------------
111# TODO: refactor so that cluster is generated upon init
class DaskConnection:
    '''
    Context manager that creates a Dask cluster upon enter and closes it upon
    exit. The cluster is a distributed LocalCluster when cluster_type is
    local, or a Dask Gateway cluster when cluster_type is gateway.
    '''
    def __init__(self, config):
        # type: (dict) -> None
        '''
        Instantiates a DaskConnection.

        Args:
            config (dict): DaskConnection config.

        Raises:
            DataError: If config is invalid.
        '''
        config = DaskConnectionConfig(config)
        config.validate()
        self.config = config.to_native()
        self.cluster = None  # type: Any

    @property
    def local_config(self):
        # type: () -> dict
        '''
        Returns:
            dict: Keyword arguments for a local cluster.
        '''
        return dict(
            host='0.0.0.0',
            dashboard_address='0.0.0.0:8087',
            n_workers=self.config['local_num_workers'],
            threads_per_worker=self.config['local_threads_per_worker'],
            processes=self.config['local_multiprocessing'],
        )

    @property
    def gateway_config(self):
        # type: () -> dict
        '''
        Returns:
            dict: Keyword arguments for a Dask Gateway cluster, including
                auth and (if configured) cluster_options.
        '''
        # create gateway config
        output = dict(
            address=self.config['gateway_address'],
            proxy_address=self.config['gateway_proxy_address'],
            public_address=self.config['gateway_public_address'],
            shutdown_on_close=self.config['gateway_shutdown_on_close'],
        )

        auth_type = self.config['gateway_auth_type']
        if auth_type == 'basic':
            # basic auth uses the API token as the password
            output['auth'] = dgw.auth.BasicAuth(
                username=self.config['gateway_api_user'],
                password=self.config['gateway_api_token'],
            )
        elif auth_type == 'jupyterhub':
            output['auth'] = dgw.JupyterHubAuth(
                api_token=self.config['gateway_api_token']
            )

        # set cluster options; truthiness also tolerates an explicit None
        opts = self.config['gateway_cluster_options']
        if opts:
            specs = []
            for opt in opts:
                spec = dict(
                    field=opt['field'],
                    label=opt['label'],
                    default=opt['default'],
                    spec={'type': opt['option_type']},
                )
                # only select-type options carry a list of choices
                if opt['option_type'] == 'select':
                    spec['spec']['options'] = opt['options']
                specs.append(spec)
            # NOTE: Options._from_spec is a private dask_gateway API and may
            # change between dask_gateway releases.
            output['cluster_options'] = dgw.options.Options._from_spec(specs)

        return output

    @property
    def cluster_type(self):
        # type: () -> str
        '''
        Returns:
            str: Cluster type.
        '''
        return self.config['cluster_type']

    @property
    def num_partitions(self):
        # type: () -> int
        '''
        Returns:
            int: Number of partitions.
        '''
        return self.config['num_partitions']

    def __enter__(self):
        # type: () -> DaskConnection
        '''
        Creates Dask cluster and assigns it to self.cluster.

        For gateway clusters, this also sets the global dask config value
        distributed.comm.timeouts.connect to gateway_timeout.

        Returns:
            DaskConnection: self.
        '''
        if self.cluster_type == 'local':
            self.cluster = ddist.LocalCluster(**self.local_config)
        elif self.cluster_type == 'gateway':  # pragma: no cover
            dask.config.set({
                'distributed.comm.timeouts.connect': self.config['gateway_timeout']
            })
            cluster = dgw.GatewayCluster(**self.gateway_config)
            # __exit__ is not called when __enter__ raises, so close the
            # freshly created cluster here rather than leak it
            try:
                cluster.adapt(
                    minimum=self.config['gateway_min_workers'],
                    maximum=self.config['gateway_max_workers'],
                )
            except Exception:
                cluster.close()
                raise
            self.cluster = cluster
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # type: (Any, Any, Any) -> None
        '''
        Closes Dask cluster, if one was created.

        Args:
            exc_type (object): Required by python.
            exc_val (object): Required by python.
            exc_tb (object): Required by python.
        '''
        if self.cluster is not None:
            self.cluster.close()