1use std::{fmt::Display, ops::Bound, path::PathBuf};
24
25use derive_more::From;
26use futures::FutureExt;
27use hotshot_types::traits::node_implementation::NodeType;
28use serde::{Deserialize, Serialize};
29use snafu::{ResultExt, Snafu};
30use tide_disco::{api::ApiError, method::ReadState, Api, RequestError, StatusCode};
31use vbs::version::StaticVersionType;
32
33use crate::{api::load_api, availability::QueryableHeader, Header, QueryError};
34
35pub(crate) mod data_source;
36pub(crate) mod query_data;
37pub use data_source::*;
38pub use query_data::*;
39
40#[derive(Debug)]
41pub struct Options {
42 pub api_path: Option<PathBuf>,
43
44 pub extensions: Vec<toml::Value>,
49
50 pub window_limit: usize,
52}
53
54impl Default for Options {
55 fn default() -> Self {
56 Self {
57 api_path: None,
58 extensions: vec![],
59 window_limit: 500,
60 }
61 }
62}
63
64#[derive(Clone, Debug, From, Snafu, Deserialize, Serialize)]
65#[snafu(visibility(pub))]
66pub enum Error {
67 Request {
68 source: RequestError,
69 },
70 #[snafu(display("{source}"))]
71 Query {
72 source: QueryError,
73 },
74 #[snafu(display("error fetching VID share for block {block}: {source}"))]
75 #[from(ignore)]
76 QueryVid {
77 source: QueryError,
78 block: String,
79 },
80 #[snafu(display(
81 "error fetching window starting from {start} and ending at time {end}: {source}"
82 ))]
83 #[from(ignore)]
84 QueryWindow {
85 source: QueryError,
86 start: String,
87 end: u64,
88 },
89 #[snafu(display("error {status}: {message}"))]
90 Custom {
91 message: String,
92 status: StatusCode,
93 },
94}
95
96impl Error {
97 pub fn internal<M: Display>(message: M) -> Self {
98 Self::Custom {
99 message: message.to_string(),
100 status: StatusCode::INTERNAL_SERVER_ERROR,
101 }
102 }
103
104 pub fn status(&self) -> StatusCode {
105 match self {
106 Self::Request { .. } => StatusCode::BAD_REQUEST,
107 Self::Query { source, .. }
108 | Self::QueryVid { source, .. }
109 | Self::QueryWindow { source, .. } => source.status(),
110 Self::Custom { status, .. } => *status,
111 }
112 }
113}
114
115pub fn define_api<State, Types, Ver: StaticVersionType + 'static>(
116 options: &Options,
117 _: Ver,
118 api_ver: semver::Version,
119) -> Result<Api<State, Error, Ver>, ApiError>
120where
121 Types: NodeType,
122 Header<Types>: QueryableHeader<Types>,
123 State: 'static + Send + Sync + ReadState,
124 <State as ReadState>::State: NodeDataSource<Types> + Send + Sync,
125{
126 let mut api = load_api::<State, Error, Ver>(
127 options.api_path.as_ref(),
128 include_str!("../api/node.toml"),
129 options.extensions.clone(),
130 )?;
131 let window_limit = options.window_limit;
132 api.with_version(api_ver)
133 .get("block_height", |_req, state| {
134 async move { state.block_height().await.context(QuerySnafu) }.boxed()
135 })?
136 .get("count_transactions", |req, state| {
137 async move {
138 let from: Bound<usize> = match req.opt_integer_param("from")? {
139 Some(from) => Bound::Included(from),
140 None => Bound::Unbounded,
141 };
142 let to = match req.opt_integer_param("to")? {
143 Some(to) => Bound::Included(to),
144 None => Bound::Unbounded,
145 };
146
147 let ns = req.opt_integer_param::<_, i64>("namespace")?;
148
149 Ok(state
150 .count_transactions_in_range((from, to), ns.map(Into::into))
151 .await?)
152 }
153 .boxed()
154 })?
155 .get("payload_size", |req, state| {
156 async move {
157 let from: Bound<usize> = match req.opt_integer_param("from")? {
158 Some(from) => Bound::Included(from),
159 None => Bound::Unbounded,
160 };
161 let to = match req.opt_integer_param("to")? {
162 Some(to) => Bound::Included(to),
163 None => Bound::Unbounded,
164 };
165
166 let ns = req.opt_integer_param::<_, i64>("namespace")?;
167
168 Ok(state
169 .payload_size_in_range((from, to), ns.map(Into::into))
170 .await?)
171 }
172 .boxed()
173 })?
174 .get("get_vid_share", |req, state| {
175 async move {
176 let id = if let Some(height) = req.opt_integer_param("height")? {
177 BlockId::Number(height)
178 } else if let Some(hash) = req.opt_blob_param("hash")? {
179 BlockId::Hash(hash)
180 } else {
181 BlockId::PayloadHash(req.blob_param("payload-hash")?)
182 };
183 state.vid_share(id).await.context(QueryVidSnafu {
184 block: id.to_string(),
185 })
186 }
187 .boxed()
188 })?
189 .get("sync_status", |_req, state| {
190 async move { state.sync_status().await.context(QuerySnafu) }.boxed()
191 })?
192 .get("get_header_window", move |req, state| {
193 async move {
194 let start = if let Some(height) = req.opt_integer_param("height")? {
195 WindowStart::Height(height)
196 } else if let Some(hash) = req.opt_blob_param("hash")? {
197 WindowStart::Hash(hash)
198 } else {
199 WindowStart::Time(req.integer_param("start")?)
200 };
201 let end = req.integer_param("end")?;
202 state
203 .get_header_window(start, end, window_limit)
204 .await
205 .context(QueryWindowSnafu {
206 start: format!("{start:?}"),
207 end,
208 })
209 }
210 .boxed()
211 })?
212 .get("get_limits", move |_req, _state| {
213 async move { Ok(Limits { window_limit }) }.boxed()
214 })?;
215 Ok(api)
216}
217
#[cfg(test)]
mod test {
    use std::time::Duration;

    use async_lock::RwLock;
    use committable::Committable;
    use futures::{FutureExt, StreamExt};
    use hotshot_types::{
        data::{VidDisperseShare, VidShare},
        event::{EventType, LeafInfo},
        traits::{
            block_contents::{BlockHeader, BlockPayload},
            EncodeBytes,
        },
    };
    use portpicker::pick_unused_port;
    use surf_disco::Client;
    use tempfile::TempDir;
    use tide_disco::{App, Error as _};
    use tokio::time::sleep;
    use toml::toml;

    use super::*;
    use crate::{
        data_source::ExtensibleDataSource,
        task::BackgroundTask,
        testing::{
            consensus::{MockDataSource, MockNetwork, MockSqlDataSource},
            mocks::{mock_transaction, MockBase, MockTypes, MockVersions},
        },
        ApiState, Error, Header,
    };

    /// End-to-end test of the `node` API against a mock consensus network.
    ///
    /// Covers the limits endpoint, block height, the transaction/payload aggregates, VID
    /// share lookup (by height, header hash and payload hash), header windows (by time,
    /// height and hash) and sync status.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_api() {
        // Non-default value so the limits endpoint check below is meaningful.
        let window_limit = 78;

        let mut network = MockNetwork::<MockDataSource, MockVersions>::init().await;
        // The event stream is taken before the network starts (presumably so that early
        // decide events are not missed).
        let mut events = network.handle().event_stream();
        network.start().await;

        let port = pick_unused_port().unwrap();
        // Serve the node API under the `node` prefix, backed by the network's data source.
        let mut app = App::<_, Error>::with_state(ApiState::from(network.data_source()));
        app.register_module(
            "node",
            define_api(
                &Options {
                    window_limit,
                    ..Default::default()
                },
                MockBase::instance(),
                "1.0.0".parse().unwrap(),
            )
            .unwrap(),
        )
        .unwrap();
        network.spawn(
            "server",
            app.serve(format!("0.0.0.0:{port}"), MockBase::instance()),
        );

        let client = Client::<Error, MockBase>::new(
            format!("http://localhost:{port}/node").parse().unwrap(),
        );
        assert!(client.connect(Some(Duration::from_secs(60))).await);

        // The limits endpoint reports the configured window limit.
        assert_eq!(
            client.get::<Limits>("limits").send().await.unwrap(),
            Limits { window_limit }
        );

        // Poll once per second until the chain is taller than the number of nodes.
        let block_height = loop {
            let block_height = client.get::<usize>("block-height").send().await.unwrap();
            if block_height > network.num_nodes() {
                break block_height;
            }
            sleep(Duration::from_secs(1)).await;
        };

        // No transactions were submitted, so both aggregates must be zero.
        assert_eq!(
            client
                .get::<u64>("transactions/count")
                .send()
                .await
                .unwrap(),
            0
        );
        assert_eq!(
            client
                .get::<u64>("payloads/total-size")
                .send()
                .await
                .unwrap(),
            0
        );

        // Headers of every decided leaf visited below, used later for the window queries.
        let mut headers = vec![];

        tracing::info!(block_height, "checking VID shares");
        // Walk decide events, visiting each leaf chain in reverse (presumably oldest first),
        // until reaching a leaf at or above `block_height`. That final header is still
        // recorded before breaking out.
        'outer: while let Some(event) = events.next().await {
            let EventType::Decide { leaf_chain, .. } = event.event else {
                continue;
            };
            for LeafInfo {
                leaf, vid_share, ..
            } in leaf_chain.iter().rev()
            {
                headers.push(leaf.block_header().clone());
                if leaf.block_header().block_number >= block_height as u64 {
                    break 'outer;
                }
                tracing::info!(height = leaf.block_header().block_number, "checking share");

                let share = client
                    .get::<VidShare>(&format!("vid/share/{}", leaf.block_header().block_number))
                    .send()
                    .await
                    .unwrap();
                // When the decide event carries a share for this leaf, the API must serve the
                // same one. Only V0 shares are expected here.
                if let Some(vid_share) = vid_share.as_ref() {
                    let VidDisperseShare::V0(new_share) = vid_share else {
                        panic!("VID share is not V0");
                    };
                    assert_eq!(share, VidShare::V0(new_share.share.clone()));
                }

                // Looking the share up by header hash or by payload hash yields the same result.
                assert_eq!(
                    share,
                    client
                        .get(&format!("vid/share/hash/{}", leaf.block_header().commit()))
                        .send()
                        .await
                        .unwrap()
                );
                assert_eq!(
                    share,
                    client
                        .get(&format!(
                            "vid/share/payload-hash/{}",
                            leaf.block_header().payload_commitment
                        ))
                        .send()
                        .await
                        .unwrap()
                );
            }
        }

        // Brief pause before querying header windows (presumably so the data source has
        // caught up with the collected headers).
        sleep(Duration::from_secs(2)).await;
        let first_header = &headers[0];
        let last_header = &headers.last().unwrap();
        // Time window spanning every collected header. The end is `last.timestamp + 1` so the
        // last header falls inside (the end bound is presumably exclusive).
        let window: TimeWindowQueryData<Header<MockTypes>> = client
            .get(&format!(
                "header/window/{}/{}",
                first_header.timestamp,
                last_header.timestamp + 1
            ))
            .send()
            .await
            .unwrap();
        assert!(window.window.contains(first_header));
        assert!(window.window.contains(last_header));
        // The network keeps producing blocks, so a header after the window should exist.
        assert!(window.next.is_some());

        // Starting the same window from height 0 or from the first header's hash must give an
        // identical result (the first collected header is presumably block 0).
        assert_eq!(
            window,
            client
                .get(&format!(
                    "header/window/from/0/{}",
                    last_header.timestamp + 1
                ))
                .send()
                .await
                .unwrap()
        );
        assert_eq!(
            window,
            client
                .get(&format!(
                    "header/window/from/hash/{}/{}",
                    first_header.commit(),
                    last_header.timestamp + 1
                ))
                .send()
                .await
                .unwrap()
        );

        // The data source reports no missing blocks or leaves.
        let sync_status = client
            .get::<SyncStatus>("sync-status")
            .send()
            .await
            .unwrap();
        assert_eq!(sync_status.missing_blocks, 0);
        assert_eq!(sync_status.missing_leaves, 0);

        network.shut_down().await;
    }

    /// Checks the transaction-count and payload-size aggregates over various height ranges,
    /// using the SQL-backed mock data source.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_aggregate_ranges() {
        let mut network = MockNetwork::<MockSqlDataSource, MockVersions>::init().await;
        // Event stream taken before starting the network, as in `test_api`.
        let mut events = network.handle().event_stream();
        network.start().await;

        let port = pick_unused_port().unwrap();
        // Unlike `test_api`, the client below is rooted at the server, so paths carry the
        // `node/` prefix.
        let mut app = App::<_, Error>::with_state(ApiState::from(network.data_source()));
        app.register_module(
            "node",
            define_api(
                &Default::default(),
                MockBase::instance(),
                "1.0.0".parse().unwrap(),
            )
            .unwrap(),
        )
        .unwrap();
        network.spawn(
            "server",
            app.serve(format!("0.0.0.0:{port}"), MockBase::instance()),
        );

        let client =
            Client::<Error, MockBase>::new(format!("http://localhost:{port}").parse().unwrap());
        assert!(client.connect(Some(Duration::from_secs(60))).await);

        // For each submitted transaction: the height of the leaf that includes it, and the
        // encoded size of that leaf's block payload.
        let mut tx_heights = vec![];
        let mut tx_sizes = vec![];
        // Submit two transactions built from payloads of 1 and 2 zero bytes, one at a time.
        for i in [1, 2] {
            let txn = mock_transaction(vec![0; i]);
            let hash = txn.commit();

            network.submit_transaction(txn).await;

            // Consume decide events until one contains a leaf whose payload includes the
            // transaction's commitment.
            let leaf = 'outer: loop {
                let EventType::Decide { leaf_chain, .. } = events.next().await.unwrap().event
                else {
                    continue;
                };
                for info in leaf_chain.iter().rev() {
                    let leaf = &info.leaf;
                    if BlockPayload::<MockTypes>::transaction_commitments(
                        &leaf.block_payload().unwrap(),
                        BlockHeader::<MockTypes>::metadata(leaf.block_header()),
                    )
                    .contains(&hash)
                    {
                        break 'outer leaf.clone();
                    }
                }

                tracing::info!("waiting for tx {i}");
                sleep(Duration::from_secs(1)).await;
            };
            tx_heights.push(leaf.height());
            tx_sizes.push(leaf.block_payload().unwrap().encode().len());
        }
        tracing::info!(?tx_heights, ?tx_sizes, "transactions sequenced");

        // Wait for the aggregates to cover the second transaction's block: until then the
        // count endpoint answers 404, and any other error fails the test.
        while let Err(err) = client
            .get::<usize>(&format!("node/transactions/count/{}", tx_heights[1]))
            .send()
            .await
        {
            if err.status() == StatusCode::NOT_FOUND {
                tracing::info!(?tx_heights, "waiting for aggregator");
                sleep(Duration::from_secs(1)).await;
                continue;
            } else {
                panic!("unexpected error: {err:#}");
            }
        }

        // Up to height 0: no transactions and an empty payload.
        assert_eq!(
            0,
            client
                .get::<usize>("node/transactions/count/0")
                .send()
                .await
                .unwrap()
        );
        assert_eq!(
            0,
            client
                .get::<usize>("node/payloads/size/0")
                .send()
                .await
                .unwrap()
        );

        // With a single path parameter (seemingly the inclusive upper height), the range up to
        // the first transaction's block holds exactly that transaction and its payload size.
        assert_eq!(
            1,
            client
                .get::<usize>(&format!("node/transactions/count/{}", tx_heights[0]))
                .send()
                .await
                .unwrap()
        );
        assert_eq!(
            tx_sizes[0],
            client
                .get::<usize>(&format!("node/payloads/size/{}", tx_heights[0]))
                .send()
                .await
                .unwrap()
        );

        // With two path parameters, the range just after the first transaction's block
        // through the second transaction's block holds only the second transaction.
        assert_eq!(
            1,
            client
                .get::<usize>(&format!(
                    "node/transactions/count/{}/{}",
                    tx_heights[0] + 1,
                    tx_heights[1]
                ))
                .send()
                .await
                .unwrap()
        );
        assert_eq!(
            tx_sizes[1],
            client
                .get::<usize>(&format!(
                    "node/payloads/size/{}/{}",
                    tx_heights[0] + 1,
                    tx_heights[1]
                ))
                .send()
                .await
                .unwrap()
        );

        // With no bounds, the aggregates cover both transactions.
        assert_eq!(
            2,
            client
                .get::<usize>("node/transactions/count",)
                .send()
                .await
                .unwrap()
        );
        assert_eq!(
            tx_sizes[0] + tx_sizes[1],
            client
                .get::<usize>("node/payloads/size",)
                .send()
                .await
                .unwrap()
        );

        network.shut_down().await;
    }

    /// Checks that routes declared in `Options::extensions` can be given handlers on the
    /// `Api` returned by `define_api`. Those handlers read and write the extra `u64` state of
    /// an `ExtensibleDataSource`, and the built-in routes keep working alongside them.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_extensions() {
        let dir = TempDir::with_prefix("test_node_extensions").unwrap();
        // Extension state starts at 0.
        let data_source = ExtensibleDataSource::new(
            MockDataSource::create(dir.path(), Default::default())
                .await
                .unwrap(),
            0,
        );

        // Two extra routes: `POST /ext/:val` takes an integer, `GET /ext` takes nothing.
        let extensions = toml! {
            [route.post_ext]
            PATH = ["/ext/:val"]
            METHOD = "POST"
            ":val" = "Integer"

            [route.get_ext]
            PATH = ["/ext"]
            METHOD = "GET"
        };

        let mut api =
            define_api::<RwLock<ExtensibleDataSource<MockDataSource, u64>>, MockTypes, MockBase>(
                &Options {
                    extensions: vec![extensions.into()],
                    ..Default::default()
                },
                MockBase::instance(),
                "1.0.0".parse().unwrap(),
            )
            .unwrap();
        // `get_ext` returns the extension state; `post_ext` overwrites it with `:val`.
        api.get("get_ext", |_, state| {
            async move { Ok(*state.as_ref()) }.boxed()
        })
        .unwrap()
        .post("post_ext", |req, state| {
            async move {
                *state.as_mut() = req.integer_param("val")?;
                Ok(())
            }
            .boxed()
        })
        .unwrap();

        let mut app = App::<_, Error>::with_state(RwLock::new(data_source));
        app.register_module("node", api).unwrap();

        let port = pick_unused_port().unwrap();
        let _server = BackgroundTask::spawn(
            "server",
            app.serve(format!("0.0.0.0:{port}"), MockBase::instance()),
        );

        let client = Client::<Error, MockBase>::new(
            format!("http://localhost:{port}/node").parse().unwrap(),
        );
        assert!(client.connect(Some(Duration::from_secs(60))).await);

        // Read the initial state, write 42, and read it back.
        assert_eq!(client.get::<u64>("ext").send().await.unwrap(), 0);
        client.post::<()>("ext/42").send().await.unwrap();
        assert_eq!(client.get::<u64>("ext").send().await.unwrap(), 42);

        // A built-in route still works on the extended API. The fresh, never-populated data
        // source is expected to report itself fully synced.
        let sync_status: SyncStatus = client.get("sync-status").send().await.unwrap();
        assert!(sync_status.is_fully_synced(), "{sync_status:?}");
    }
}