Coverage for /home/ubuntu/hidebound/python/hidebound/core/connection.py: 100%

69 statements  

« prev     ^ index     » next       coverage.py v7.5.4, created at 2024-07-05 23:50 +0000

1from typing import Any # noqa F401 

2 

3from schematics import Model 

4from schematics.types import ( 

5 BaseType, BooleanType, IntType, ListType, ModelType, StringType, URLType 

6) 

7import dask 

8import dask_gateway as dgw 

9import dask.distributed as ddist 

10 

11import hidebound.core.validators as vd 

12# ------------------------------------------------------------------------------ 

13 

14 

class DaskConnectionConfig(Model):
    r'''
    A class for validating DaskConnection configurations.

    Attributes:
        cluster_type (str, optional): Dask cluster type. Options include:
            local, gateway. Default: local.
        num_partitions (int, optional): Number of partitions each DataFrame is
            to be split into. Must be at least 1. Default: 1.
        local_num_workers (int, optional): Number of workers to run on local
            cluster. Must be at least 1. Default: 1.
        local_threads_per_worker (int, optional): Number of threads to run per
            worker on local cluster. Must be at least 1. Default: 1.
        local_multiprocessing (bool, optional): Whether to use multiprocessing
            for local cluster. Default: True.
        gateway_address (str, optional): Dask Gateway server address. Default:
            'http://proxy-public/services/dask-gateway'.
        gateway_proxy_address (str, optional): Dask Gateway scheduler proxy
            server address. No default is declared on this field.
            Default: None.
        gateway_public_address (str, optional): The address to the gateway
            server, as accessible from a web browser. No default is declared
            on this field. Default: None.
        gateway_auth_type (str, optional): Dask Gateway authentication type.
            Options include: basic, jupyterhub. Default: basic.
        gateway_api_token (str, optional): Authentication API token.
            Default: None.
        gateway_api_user (str, optional): Basic authentication user name.
            Default: None.
        gateway_cluster_options (list, optional): Dask Gateway cluster options.
            Each item is validated as a ClusterOption. Default: [].
        gateway_min_workers (int, optional): Minimum number of Dask Gateway
            workers. Must be at least 1. Default: 1.
        gateway_max_workers (int, optional): Maximum number of Dask Gateway
            workers. Must be at least 1. Default: 8.
        gateway_shutdown_on_close (bool, optional): Whether to shutdown cluster
            upon close. Default: True.
        gateway_timeout (int, optional): Dask Gateway connection timeout in
            seconds. Must be at least 1. Default: 30.
    '''
    cluster_type = StringType(
        required=True,
        default='local',
        validators=[lambda x: vd.is_in(x, ['local', 'gateway'])]
    )  # type: StringType
    num_partitions = IntType(
        required=True, default=1, validators=[lambda x: vd.is_gte(x, 1)]
    )  # type: IntType
    local_num_workers = IntType(
        required=True, default=1, validators=[lambda x: vd.is_gte(x, 1)]
    )  # type: IntType
    local_threads_per_worker = IntType(
        required=True, default=1, validators=[lambda x: vd.is_gte(x, 1)]
    )  # type: IntType
    local_multiprocessing = BooleanType(
        required=True, default=True
    )  # type: BooleanType
    gateway_address = URLType(
        required=True,
        fqdn=False,
        default='http://proxy-public/services/dask-gateway',
    )  # type: URLType
    # the two address fields below are emitted as None when left unset
    gateway_proxy_address = StringType(serialize_when_none=True)  # type: StringType
    gateway_public_address = URLType(serialize_when_none=True, fqdn=False)  # type: URLType
    gateway_auth_type = StringType(
        required=True,
        default='basic',
        validators=[lambda x: vd.is_in(x, ['basic', 'jupyterhub'])]
    )  # type: StringType
    gateway_api_token = StringType()  # type: StringType
    gateway_api_user = StringType()  # type: StringType
    gateway_min_workers = IntType(
        required=True, default=1, validators=[lambda x: vd.is_gte(x, 1)]
    )  # type: IntType
    gateway_max_workers = IntType(
        required=True, default=8, validators=[lambda x: vd.is_gte(x, 1)]
    )  # type: IntType
    gateway_shutdown_on_close = BooleanType(
        required=True, default=True
    )  # type: BooleanType
    gateway_timeout = IntType(
        required=True, default=30, validators=[lambda x: vd.is_gte(x, 1)]
    )  # type: IntType

    class ClusterOption(Model):
        '''
        A class for validating a single Dask Gateway cluster option.

        Attributes:
            field (str): Option field name.
            label (str): Option label.
            default (object): Option default value.
            options (list, optional): Option choices. Default: [].
            option_type (str): Option type. Validated with
                vd.is_cluster_option_type.
        '''
        field = StringType(required=True)  # type: StringType
        label = StringType(required=True)  # type: StringType
        default = BaseType(required=True)  # type: BaseType
        options = ListType(BaseType, required=True, default=[])  # type: ListType
        option_type = StringType(
            required=True, validators=[vd.is_cluster_option_type]
        )  # type: StringType
    gateway_cluster_options = ListType(
        ModelType(ClusterOption), required=False, default=[]
    )  # type: ListType

108# ------------------------------------------------------------------------------ 

109 

110 

111# TODO: refactor so that cluster is generated upon init 

class DaskConnection:
    '''
    Context manager which creates a Dask cluster from a DaskConnection config
    upon entry and closes it upon exit.
    '''
    def __init__(self, config):
        # type: (dict) -> None
        '''
        Instantiates a DaskConnection.

        Args:
            config (dict): DaskConnection config.

        Raises:
            DataError: If config is invalid.
        '''
        config = DaskConnectionConfig(config)
        config.validate()
        self.config = config.to_native()

        # cluster is only created within __enter__
        self.cluster = None  # type: Any

    @property
    def local_config(self):
        # type: () -> dict
        '''
        Returns:
            dict: Local cluster config.
        '''
        return dict(
            host='0.0.0.0',
            dashboard_address='0.0.0.0:8087',
            n_workers=self.config['local_num_workers'],
            threads_per_worker=self.config['local_threads_per_worker'],
            processes=self.config['local_multiprocessing'],
        )

    @property
    def gateway_config(self):
        # type: () -> dict
        '''
        Returns:
            dict: gateway cluster config.
        '''
        # create gateway config
        output = dict(
            address=self.config['gateway_address'],
            proxy_address=self.config['gateway_proxy_address'],
            public_address=self.config['gateway_public_address'],
            shutdown_on_close=self.config['gateway_shutdown_on_close'],
        )  # type: dict

        auth_type = self.config['gateway_auth_type']

        # set basic authentication
        if auth_type == 'basic':
            output['auth'] = dgw.auth.BasicAuth(
                username=self.config['gateway_api_user'],
                password=self.config['gateway_api_token'],
            )

        # set jupyterhub authentication
        elif auth_type == 'jupyterhub':
            output['auth'] = dgw.JupyterHubAuth(
                api_token=self.config['gateway_api_token']
            )

        # set cluster options
        # truthiness check skips both an empty list and None, gateway_cluster
        # options is not a required field
        opts = self.config['gateway_cluster_options']
        if opts:
            specs = []
            for opt in opts:
                spec = dict(
                    field=opt['field'],
                    label=opt['label'],
                    default=opt['default'],
                    spec={'type': opt['option_type']},
                )
                # only select options carry a list of choices
                if opt['option_type'] == 'select':
                    spec['spec']['options'] = opt['options']
                specs.append(spec)
            output['cluster_options'] = dgw.options.Options._from_spec(specs)

        return output

    @property
    def cluster_type(self):
        # type: () -> str
        '''
        Returns:
            str: Cluster type.
        '''
        return self.config['cluster_type']

    @property
    def num_partitions(self):
        # type: () -> int
        '''
        Returns:
            int: Number of partitions.
        '''
        return self.config['num_partitions']

    def __enter__(self):
        # type: () -> DaskConnection
        '''
        Creates Dask cluster and assigns it to self.cluster.

        If a gateway cluster is created but cannot be set to adaptive scaling,
        it is closed before the error is re-raised.

        Returns:
            DaskConnection: self.
        '''
        if self.cluster_type == 'local':
            self.cluster = ddist.LocalCluster(**self.local_config)
        elif self.cluster_type == 'gateway':  # pragma: no cover
            dask.config.set({
                'distributed.comm.timeouts.connect': self.config['gateway_timeout']
            })
            self.cluster = dgw.GatewayCluster(**self.gateway_config)  # pragma: no cover
            try:
                self.cluster.adapt(
                    minimum=self.config['gateway_min_workers'],
                    maximum=self.config['gateway_max_workers'],
                )
            except Exception:
                # __exit__ is not called when __enter__ raises, so the cluster
                # must be closed here or it is leaked
                self.cluster.close()
                self.cluster = None
                raise
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # type: (Any, Any, Any) -> None
        '''
        Closes Dask cluster, if one has been created.

        Args:
            exc_type (object): Required by python.
            exc_val (object): Required by python.
            exc_tb (object): Required by python.
        '''
        if self.cluster is not None:
            self.cluster.close()