hotshot_query_service/fetching/provider/
any.rs1use std::{fmt::Debug, sync::Arc};
14
15use async_trait::async_trait;
16use derivative::Derivative;
17use hotshot_types::traits::node_implementation::NodeType;
18
19use super::{Provider, Request};
20use crate::{
21 availability::LeafQueryData,
22 data_source::AvailabilityProvider,
23 fetching::request::{LeafRequest, PayloadRequest, VidCommonRequest},
24 Payload, VidCommon,
25};
26
27trait DebugProvider<Types, T>: Provider<Types, T> + Debug
33where
34 Types: NodeType,
35 T: Request<Types>,
36{
37}
38
39impl<Types, T, P> DebugProvider<Types, T> for P
40where
41 Types: NodeType,
42 T: Request<Types>,
43 P: Provider<Types, T> + Debug,
44{
45}
46
47type PayloadProvider<Types> = Arc<dyn DebugProvider<Types, PayloadRequest>>;
48type LeafProvider<Types> = Arc<dyn DebugProvider<Types, LeafRequest<Types>>>;
49type VidCommonProvider<Types> = Arc<dyn DebugProvider<Types, VidCommonRequest>>;
50
51#[derive(Derivative)]
90#[derivative(Clone(bound = ""), Debug(bound = ""), Default(bound = ""))]
91pub struct AnyProvider<Types>
92where
93 Types: NodeType,
94{
95 payload_providers: Vec<PayloadProvider<Types>>,
96 leaf_providers: Vec<LeafProvider<Types>>,
97 vid_common_providers: Vec<VidCommonProvider<Types>>,
98}
99
100#[async_trait]
101impl<Types> Provider<Types, PayloadRequest> for AnyProvider<Types>
102where
103 Types: NodeType,
104{
105 async fn fetch(&self, req: PayloadRequest) -> Option<Payload<Types>> {
106 any_fetch(&self.payload_providers, req).await
107 }
108}
109
110#[async_trait]
111impl<Types> Provider<Types, LeafRequest<Types>> for AnyProvider<Types>
112where
113 Types: NodeType,
114{
115 async fn fetch(&self, req: LeafRequest<Types>) -> Option<LeafQueryData<Types>> {
116 any_fetch(&self.leaf_providers, req).await
117 }
118}
119
120#[async_trait]
121impl<Types> Provider<Types, VidCommonRequest> for AnyProvider<Types>
122where
123 Types: NodeType,
124{
125 async fn fetch(&self, req: VidCommonRequest) -> Option<VidCommon> {
126 any_fetch(&self.vid_common_providers, req).await
127 }
128}
129
130impl<Types> AnyProvider<Types>
131where
132 Types: NodeType,
133{
134 pub fn with_provider<P>(mut self, provider: P) -> Self
136 where
137 P: AvailabilityProvider<Types> + Debug + 'static,
138 {
139 let provider = Arc::new(provider);
140 self.payload_providers.push(provider.clone());
141 self.leaf_providers.push(provider.clone());
142 self.vid_common_providers.push(provider);
143 self
144 }
145
146 pub fn with_block_provider<P>(mut self, provider: P) -> Self
148 where
149 P: Provider<Types, PayloadRequest> + Debug + 'static,
150 {
151 self.payload_providers.push(Arc::new(provider));
152 self
153 }
154
155 pub fn with_leaf_provider<P>(mut self, provider: P) -> Self
157 where
158 P: Provider<Types, LeafRequest<Types>> + Debug + 'static,
159 {
160 self.leaf_providers.push(Arc::new(provider));
161 self
162 }
163
164 pub fn with_vid_common_provider<P>(mut self, provider: P) -> Self
166 where
167 P: Provider<Types, VidCommonRequest> + Debug + 'static,
168 {
169 self.vid_common_providers.push(Arc::new(provider));
170 self
171 }
172}
173
174async fn any_fetch<Types, P, T>(providers: &[Arc<P>], req: T) -> Option<T::Response>
175where
176 Types: NodeType,
177 P: Provider<Types, T> + Debug + ?Sized,
178 T: Request<Types>,
179{
180 for (i, p) in providers.iter().enumerate() {
187 match p.fetch(req).await {
188 Some(obj) => return Some(obj),
189 None => {
190 tracing::warn!(
191 "failed to fetch {req:?} from provider {i}/{}: {p:?}",
192 providers.len()
193 );
194 continue;
195 },
196 }
197 }
198
199 None
200}
201
#[cfg(all(test, not(target_os = "windows")))]
mod test {
    use futures::stream::StreamExt;
    use portpicker::pick_unused_port;
    use tide_disco::App;
    use vbs::version::StaticVersionType;

    use super::*;
    use crate::{
        availability::{define_api, AvailabilityDataSource, UpdateAvailabilityData},
        data_source::storage::sql::testing::TmpDb,
        fetching::provider::{NoFetching, QueryServiceProvider},
        task::BackgroundTask,
        testing::{
            consensus::{MockDataSource, MockNetwork},
            mocks::{MockBase, MockTypes, MockVersions},
            setup_test,
        },
        types::HeightIndexed,
        ApiState, Error,
    };

    /// The aggregate provider under test, specialized to the mock node types.
    type Provider = AnyProvider<MockTypes>;

    /// Objects must still be fetched when the first registered provider cannot supply them.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_fetch_first_provider_fails() {
        setup_test();

        // Mock network whose data source backs the query service started below.
        let mut network = MockNetwork::<MockDataSource, MockVersions>::init().await;

        // Serve the availability API from the network's data source on a free local port.
        let port = pick_unused_port().unwrap();
        let mut app = App::<_, Error>::with_state(ApiState::from(network.data_source()));
        app.register_module(
            "availability",
            define_api(
                &Default::default(),
                MockBase::instance(),
                "1.0.0".parse().unwrap(),
            )
            .unwrap(),
        )
        .unwrap();
        let listen_addr = format!("0.0.0.0:{port}");
        let serving = app.serve(listen_addr, MockBase::instance());
        // Bound to a named variable so the task handle lives until the end of the test.
        let _server = BackgroundTask::spawn("server", serving);

        // Data source under test: `NoFetching` is registered first, followed by a provider that
        // queries the server started above.
        let db = TmpDb::init().await;
        let query_service = QueryServiceProvider::new(
            format!("http://localhost:{port}").parse().unwrap(),
            MockBase::instance(),
        );
        let provider = Provider::default()
            .with_provider(NoFetching)
            .with_provider(query_service);
        let data_source = db.config().connect(provider.clone()).await.unwrap();

        // Start the network and take the first three leaves from the subscription.
        network.start().await;
        let leaf_stream = network.data_source().subscribe_leaves(1).await;
        let leaves = leaf_stream.take(3).collect::<Vec<_>>().await;
        let (test_leaf, test_payload) = (&leaves[0], &leaves[1]);

        // Only the last collected leaf is appended to the data source under test, so the two
        // earlier objects requested below are presumably absent locally and must be fetched.
        let newest = leaves.last().cloned().unwrap();
        data_source.append(newest.into()).await.unwrap();

        tracing::info!("requesting leaf from multiple providers");
        let leaf_height = test_leaf.height() as usize;
        // The first `.await` yields another awaitable (presumably a handle to a pending fetch);
        // the second one waits for the object itself.
        let leaf = data_source.get_leaf(leaf_height).await.await;
        assert_eq!(leaf, *test_leaf);

        tracing::info!("requesting payload from multiple providers");
        let payload_height = test_payload.height() as usize;
        let payload = data_source.get_payload(payload_height).await.await;
        assert_eq!(payload.height(), test_payload.height());
        assert_eq!(payload.block_hash(), test_payload.block_hash());
        assert_eq!(payload.hash(), test_payload.payload_hash());
    }
}