1use std::{fmt::Display, ops::Bound, path::PathBuf};
24
25use derive_more::From;
26use futures::FutureExt;
27use hotshot_types::traits::node_implementation::NodeType;
28use serde::{Deserialize, Serialize};
29use snafu::{ResultExt, Snafu};
30use tide_disco::{api::ApiError, method::ReadState, Api, RequestError, StatusCode};
31use vbs::version::StaticVersionType;
32
33use crate::{api::load_api, availability::QueryableHeader, Header, QueryError};
34
35pub(crate) mod data_source;
36pub(crate) mod query_data;
37pub use data_source::*;
38pub use query_data::*;
39
40#[derive(Debug)]
41pub struct Options {
42 pub api_path: Option<PathBuf>,
43
44 pub extensions: Vec<toml::Value>,
49
50 pub window_limit: usize,
52}
53
54impl Default for Options {
55 fn default() -> Self {
56 Self {
57 api_path: None,
58 extensions: vec![],
59 window_limit: 500,
60 }
61 }
62}
63
64#[derive(Clone, Debug, From, Snafu, Deserialize, Serialize)]
65#[snafu(visibility(pub))]
66pub enum Error {
67 Request {
68 source: RequestError,
69 },
70 #[snafu(display("{source}"))]
71 Query {
72 source: QueryError,
73 },
74 #[snafu(display("error fetching VID share for block {block}: {source}"))]
75 #[from(ignore)]
76 QueryVid {
77 source: QueryError,
78 block: String,
79 },
80 #[snafu(display(
81 "error fetching window starting from {start} and ending at time {end}: {source}"
82 ))]
83 #[from(ignore)]
84 QueryWindow {
85 source: QueryError,
86 start: String,
87 end: u64,
88 },
89 #[snafu(display("error {status}: {message}"))]
90 Custom {
91 message: String,
92 status: StatusCode,
93 },
94}
95
96impl Error {
97 pub fn internal<M: Display>(message: M) -> Self {
98 Self::Custom {
99 message: message.to_string(),
100 status: StatusCode::INTERNAL_SERVER_ERROR,
101 }
102 }
103
104 pub fn status(&self) -> StatusCode {
105 match self {
106 Self::Request { .. } => StatusCode::BAD_REQUEST,
107 Self::Query { source, .. }
108 | Self::QueryVid { source, .. }
109 | Self::QueryWindow { source, .. } => source.status(),
110 Self::Custom { status, .. } => *status,
111 }
112 }
113}
114
115pub fn define_api<State, Types, Ver: StaticVersionType + 'static>(
116 options: &Options,
117 _: Ver,
118 api_ver: semver::Version,
119) -> Result<Api<State, Error, Ver>, ApiError>
120where
121 Types: NodeType,
122 Header<Types>: QueryableHeader<Types>,
123 State: 'static + Send + Sync + ReadState,
124 <State as ReadState>::State: NodeDataSource<Types> + Send + Sync,
125{
126 let mut api = load_api::<State, Error, Ver>(
127 options.api_path.as_ref(),
128 include_str!("../api/node.toml"),
129 options.extensions.clone(),
130 )?;
131 let window_limit = options.window_limit;
132 api.with_version(api_ver)
133 .get("block_height", |_req, state| {
134 async move { state.block_height().await.context(QuerySnafu) }.boxed()
135 })?
136 .get("count_transactions", |req, state| {
137 async move {
138 let from: Bound<usize> = match req.opt_integer_param("from")? {
139 Some(from) => Bound::Included(from),
140 None => Bound::Unbounded,
141 };
142 let to = match req.opt_integer_param("to")? {
143 Some(to) => Bound::Included(to),
144 None => Bound::Unbounded,
145 };
146
147 let ns = req.opt_integer_param::<_, i64>("namespace")?;
148
149 Ok(state
150 .count_transactions_in_range((from, to), ns.map(Into::into))
151 .await?)
152 }
153 .boxed()
154 })?
155 .get("payload_size", |req, state| {
156 async move {
157 let from: Bound<usize> = match req.opt_integer_param("from")? {
158 Some(from) => Bound::Included(from),
159 None => Bound::Unbounded,
160 };
161 let to = match req.opt_integer_param("to")? {
162 Some(to) => Bound::Included(to),
163 None => Bound::Unbounded,
164 };
165
166 let ns = req.opt_integer_param::<_, i64>("namespace")?;
167
168 Ok(state
169 .payload_size_in_range((from, to), ns.map(Into::into))
170 .await?)
171 }
172 .boxed()
173 })?
174 .get("get_vid_share", |req, state| {
175 async move {
176 let id = if let Some(height) = req.opt_integer_param("height")? {
177 BlockId::Number(height)
178 } else if let Some(hash) = req.opt_blob_param("hash")? {
179 BlockId::Hash(hash)
180 } else {
181 BlockId::PayloadHash(req.blob_param("payload-hash")?)
182 };
183 state.vid_share(id).await.context(QueryVidSnafu {
184 block: id.to_string(),
185 })
186 }
187 .boxed()
188 })?
189 .get("sync_status", |_req, state| {
190 async move { state.sync_status().await.context(QuerySnafu) }.boxed()
191 })?
192 .get("get_header_window", move |req, state| {
193 async move {
194 let start = if let Some(height) = req.opt_integer_param("height")? {
195 WindowStart::Height(height)
196 } else if let Some(hash) = req.opt_blob_param("hash")? {
197 WindowStart::Hash(hash)
198 } else {
199 WindowStart::Time(req.integer_param("start")?)
200 };
201 let end = req.integer_param("end")?;
202 state
203 .get_header_window(start, end, window_limit)
204 .await
205 .context(QueryWindowSnafu {
206 start: format!("{start:?}"),
207 end,
208 })
209 }
210 .boxed()
211 })?
212 .get("get_limits", move |_req, _state| {
213 async move { Ok(Limits { window_limit }) }.boxed()
214 })?;
215 Ok(api)
216}
217
// Integration tests for the `node` API: each test serves the module over HTTP (backed by a
// mock data source) and exercises it through a `surf_disco` client.
#[cfg(test)]
mod test {
    use std::time::Duration;

    use async_lock::RwLock;
    use committable::Committable;
    use futures::{FutureExt, StreamExt};
    use hotshot_types::{
        data::{VidDisperseShare, VidShare},
        event::{EventType, LeafInfo},
        traits::{
            block_contents::{BlockHeader, BlockPayload},
            EncodeBytes,
        },
    };
    use portpicker::pick_unused_port;
    use surf_disco::Client;
    use tempfile::TempDir;
    use tide_disco::{App, Error as _};
    use tokio::time::sleep;
    use toml::toml;

    use super::*;
    use crate::{
        data_source::ExtensibleDataSource,
        task::BackgroundTask,
        testing::{
            consensus::{MockDataSource, MockNetwork, MockSqlDataSource},
            mocks::{mock_transaction, MockBase, MockTypes, MockVersions},
            setup_test,
        },
        ApiState, Error, Header,
    };

    // End-to-end check of the built-in routes: limits, block height, aggregate totals,
    // VID shares (by height, header hash and payload hash), header windows and sync status.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_api() {
        setup_test();

        // A non-default limit, so the `limits` route can be checked against it.
        let window_limit = 78;

        let mut network = MockNetwork::<MockDataSource, MockVersions>::init().await;
        // Subscribe to consensus events before starting, so decide events are observed.
        let mut events = network.handle().event_stream();
        network.start().await;

        let port = pick_unused_port().unwrap();
        let mut app = App::<_, Error>::with_state(ApiState::from(network.data_source()));
        app.register_module(
            "node",
            define_api(
                &Options {
                    window_limit,
                    ..Default::default()
                },
                MockBase::instance(),
                "1.0.0".parse().unwrap(),
            )
            .unwrap(),
        )
        .unwrap();
        network.spawn(
            "server",
            app.serve(format!("0.0.0.0:{port}"), MockBase::instance()),
        );

        let client = Client::<Error, MockBase>::new(
            format!("http://localhost:{port}/node").parse().unwrap(),
        );
        assert!(client.connect(Some(Duration::from_secs(60))).await);

        // The limits route echoes the configured window limit.
        assert_eq!(
            client.get::<Limits>("limits").send().await.unwrap(),
            Limits { window_limit }
        );

        // Poll until the chain height exceeds the number of nodes.
        let block_height = loop {
            let block_height = client.get::<usize>("block-height").send().await.unwrap();
            if block_height > network.num_nodes() {
                break block_height;
            }
            sleep(Duration::from_secs(1)).await;
        };

        // No transactions were submitted, so both unbounded aggregates are zero.
        assert_eq!(
            client
                .get::<u64>("transactions/count")
                .send()
                .await
                .unwrap(),
            0
        );
        assert_eq!(
            client
                .get::<u64>("payloads/total-size")
                .send()
                .await
                .unwrap(),
            0
        );

        // Headers of every decided leaf seen below, reused for the header-window checks.
        let mut headers = vec![];

        tracing::info!(block_height, "checking VID shares");
        'outer: while let Some(event) = events.next().await {
            let EventType::Decide { leaf_chain, .. } = event.event else {
                continue;
            };
            // Walk each decided chain in reverse order.
            for LeafInfo {
                leaf, vid_share, ..
            } in leaf_chain.iter().rev()
            {
                headers.push(leaf.block_header().clone());
                // Stop once we reach the height observed above; the header at that height
                // is still recorded, but its share is not checked.
                if leaf.block_header().block_number >= block_height as u64 {
                    break 'outer;
                }
                tracing::info!(height = leaf.block_header().block_number, "checking share");

                let share = client
                    .get::<VidShare>(&format!("vid/share/{}", leaf.block_header().block_number))
                    .send()
                    .await
                    .unwrap();
                // When the event carries a share, the API must serve the same one.
                if let Some(vid_share) = vid_share.as_ref() {
                    let VidDisperseShare::V0(new_share) = vid_share else {
                        panic!("VID share is not V0");
                    };
                    assert_eq!(share, VidShare::V0(new_share.share.clone()));
                }

                // The same share is served when the block is addressed by header hash...
                assert_eq!(
                    share,
                    client
                        .get(&format!("vid/share/hash/{}", leaf.block_header().commit()))
                        .send()
                        .await
                        .unwrap()
                );
                // ...and by payload commitment.
                assert_eq!(
                    share,
                    client
                        .get(&format!(
                            "vid/share/payload-hash/{}",
                            leaf.block_header().payload_commitment
                        ))
                        .send()
                        .await
                        .unwrap()
                );
            }
        }

        // NOTE(review): presumably gives the server's data source time to catch up with
        // the decided headers before querying windows — confirm.
        sleep(Duration::from_secs(2)).await;
        let first_header = &headers[0];
        let last_header = &headers.last().unwrap();
        // A time window spanning the first to the last recorded header (end exclusive of
        // `last_header.timestamp + 1`) must contain both endpoints.
        let window: TimeWindowQueryData<Header<MockTypes>> = client
            .get(&format!(
                "header/window/{}/{}",
                first_header.timestamp,
                last_header.timestamp + 1
            ))
            .send()
            .await
            .unwrap();
        assert!(window.window.contains(first_header));
        assert!(window.window.contains(last_header));
        assert!(window.next.is_some());

        // The same window addressed by starting height 0. NOTE(review): this assumes the
        // first recorded header is at height 0 — confirm.
        assert_eq!(
            window,
            client
                .get(&format!(
                    "header/window/from/0/{}",
                    last_header.timestamp + 1
                ))
                .send()
                .await
                .unwrap()
        );
        // The same window addressed by the first header's hash.
        assert_eq!(
            window,
            client
                .get(&format!(
                    "header/window/from/hash/{}/{}",
                    first_header.commit(),
                    last_header.timestamp + 1
                ))
                .send()
                .await
                .unwrap()
        );

        // Nothing should be missing from the data source at this point.
        let sync_status = client
            .get::<SyncStatus>("sync-status")
            .send()
            .await
            .unwrap();
        assert_eq!(sync_status.missing_blocks, 0);
        assert_eq!(sync_status.missing_leaves, 0);

        network.shut_down().await;
    }

    // Checks the aggregate routes (transaction counts and payload sizes) over different
    // block ranges, using two transactions whose inclusion heights and payload sizes are
    // recorded as they are sequenced.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_aggregate_ranges() {
        setup_test();

        let mut network = MockNetwork::<MockSqlDataSource, MockVersions>::init().await;
        let mut events = network.handle().event_stream();
        network.start().await;

        let port = pick_unused_port().unwrap();
        let mut app = App::<_, Error>::with_state(ApiState::from(network.data_source()));
        app.register_module(
            "node",
            define_api(
                &Default::default(),
                MockBase::instance(),
                "1.0.0".parse().unwrap(),
            )
            .unwrap(),
        )
        .unwrap();
        network.spawn(
            "server",
            app.serve(format!("0.0.0.0:{port}"), MockBase::instance()),
        );

        // Unlike `test_api`, this client is rooted at the server, so paths carry "node/".
        let client =
            Client::<Error, MockBase>::new(format!("http://localhost:{port}").parse().unwrap());
        assert!(client.connect(Some(Duration::from_secs(60))).await);

        // Height of the block containing each transaction, and that block's encoded
        // payload size.
        let mut tx_heights = vec![];
        let mut tx_sizes = vec![];
        for i in [1, 2] {
            let txn = mock_transaction(vec![0; i]);
            let hash = txn.commit();

            network.submit_transaction(txn).await;

            // Wait for a decided leaf whose payload includes the transaction.
            let leaf = 'outer: loop {
                let EventType::Decide { leaf_chain, .. } = events.next().await.unwrap().event
                else {
                    continue;
                };
                for info in leaf_chain.iter().rev() {
                    let leaf = &info.leaf;
                    if BlockPayload::<MockTypes>::transaction_commitments(
                        &leaf.block_payload().unwrap(),
                        BlockHeader::<MockTypes>::metadata(leaf.block_header()),
                    )
                    .contains(&hash)
                    {
                        break 'outer leaf.clone();
                    }
                }

                tracing::info!("waiting for tx {i}");
                sleep(Duration::from_secs(1)).await;
            };
            tx_heights.push(leaf.height());
            tx_sizes.push(leaf.block_payload().unwrap().encode().len());
        }
        tracing::info!(?tx_heights, ?tx_sizes, "transactions sequenced");

        // Poll until the aggregate for the second transaction's height is available;
        // NOT_FOUND is treated as "not aggregated yet", any other error fails the test.
        while let Err(err) = client
            .get::<usize>(&format!("node/transactions/count/{}", tx_heights[1]))
            .send()
            .await
        {
            if err.status() == StatusCode::NOT_FOUND {
                tracing::info!(?tx_heights, "waiting for aggregator");
                sleep(Duration::from_secs(1)).await;
                continue;
            } else {
                panic!("unexpected error: {err:#}");
            }
        }

        // Single-segment routes: presumably the segment is the inclusive upper bound `to`
        // (see api/node.toml). Up to block 0 nothing has been submitted yet.
        assert_eq!(
            0,
            client
                .get::<usize>("node/transactions/count/0")
                .send()
                .await
                .unwrap()
        );
        assert_eq!(
            0,
            client
                .get::<usize>("node/payloads/size/0")
                .send()
                .await
                .unwrap()
        );

        // Up to the first transaction's height: exactly that transaction.
        assert_eq!(
            1,
            client
                .get::<usize>(&format!("node/transactions/count/{}", tx_heights[0]))
                .send()
                .await
                .unwrap()
        );
        assert_eq!(
            tx_sizes[0],
            client
                .get::<usize>(&format!("node/payloads/size/{}", tx_heights[0]))
                .send()
                .await
                .unwrap()
        );

        // From just after the first transaction up to the second: only the second one.
        assert_eq!(
            1,
            client
                .get::<usize>(&format!(
                    "node/transactions/count/{}/{}",
                    tx_heights[0] + 1,
                    tx_heights[1]
                ))
                .send()
                .await
                .unwrap()
        );
        assert_eq!(
            tx_sizes[1],
            client
                .get::<usize>(&format!(
                    "node/payloads/size/{}/{}",
                    tx_heights[0] + 1,
                    tx_heights[1]
                ))
                .send()
                .await
                .unwrap()
        );

        // Unbounded routes: both transactions.
        assert_eq!(
            2,
            client
                .get::<usize>("node/transactions/count",)
                .send()
                .await
                .unwrap()
        );
        assert_eq!(
            tx_sizes[0] + tx_sizes[1],
            client
                .get::<usize>("node/payloads/size",)
                .send()
                .await
                .unwrap()
        );

        network.shut_down().await;
    }

    // Checks that `Options::extensions` adds routes to the module, and that handlers
    // registered by the caller on the returned `Api` can read and write extension state.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_extensions() {
        setup_test();

        let dir = TempDir::with_prefix("test_node_extensions").unwrap();
        // The mock data source wrapped with an extra `u64` of extension state, starting at 0.
        let data_source = ExtensibleDataSource::new(
            MockDataSource::create(dir.path(), Default::default())
                .await
                .unwrap(),
            0,
        );

        // Two extra routes: POST /ext/:val sets the state, GET /ext reads it back.
        let extensions = toml! {
            [route.post_ext]
            PATH = ["/ext/:val"]
            METHOD = "POST"
            ":val" = "Integer"

            [route.get_ext]
            PATH = ["/ext"]
            METHOD = "GET"
        };

        let mut api =
            define_api::<RwLock<ExtensibleDataSource<MockDataSource, u64>>, MockTypes, MockBase>(
                &Options {
                    extensions: vec![extensions.into()],
                    ..Default::default()
                },
                MockBase::instance(),
                "1.0.0".parse().unwrap(),
            )
            .unwrap();
        // `define_api` only handles the built-in routes, so attach handlers for the
        // extension routes here.
        api.get("get_ext", |_, state| {
            async move { Ok(*state.as_ref()) }.boxed()
        })
        .unwrap()
        .post("post_ext", |req, state| {
            async move {
                *state.as_mut() = req.integer_param("val")?;
                Ok(())
            }
            .boxed()
        })
        .unwrap();

        let mut app = App::<_, Error>::with_state(RwLock::new(data_source));
        app.register_module("node", api).unwrap();

        let port = pick_unused_port().unwrap();
        // Bound to a named variable (not `_`) so the task handle lives for the rest of the
        // test; presumably dropping it would stop the server.
        let _server = BackgroundTask::spawn(
            "server",
            app.serve(format!("0.0.0.0:{port}"), MockBase::instance()),
        );

        let client = Client::<Error, MockBase>::new(
            format!("http://localhost:{port}/node").parse().unwrap(),
        );
        assert!(client.connect(Some(Duration::from_secs(60))).await);

        // Read the initial state, overwrite it through the POST route, read it again.
        assert_eq!(client.get::<u64>("ext").send().await.unwrap(), 0);
        client.post::<()>("ext/42").send().await.unwrap();
        assert_eq!(client.get::<u64>("ext").send().await.unwrap(), 42);

        // Built-in routes still work alongside the extensions.
        let sync_status: SyncStatus = client.get("sync-status").send().await.unwrap();
        assert!(sync_status.is_fully_synced(), "{sync_status:?}");
    }
}