hotshot_libp2p_networking/network/behaviours/dht/store/
persistent.rs1use std::{
5 sync::{
6 atomic::{AtomicU64, Ordering},
7 Arc,
8 },
9 time::{Duration, Instant, SystemTime, UNIX_EPOCH},
10};
11
12use anyhow::Context;
13use async_trait::async_trait;
14use delegate::delegate;
15use libp2p::kad::store::{RecordStore, Result};
16use serde::{Deserialize, Serialize};
17use tokio::{sync::Semaphore, time::timeout};
18use tracing::{debug, warn};
19
20#[async_trait]
23pub trait DhtPersistentStorage: Send + Sync + 'static + Clone {
24 async fn save(&self, _records: Vec<SerializableRecord>) -> anyhow::Result<()>;
29
30 async fn load(&self) -> anyhow::Result<Vec<SerializableRecord>>;
35}
36
/// A storage backend that persists nothing.
///
/// Its `DhtPersistentStorage` implementation accepts every save without storing
/// anything and always loads an empty record list.
#[derive(Clone, Debug, Default)]
pub struct DhtNoPersistence;
40
41#[async_trait]
42impl DhtPersistentStorage for DhtNoPersistence {
43 async fn save(&self, _records: Vec<SerializableRecord>) -> anyhow::Result<()> {
44 Ok(())
45 }
46
47 async fn load(&self) -> anyhow::Result<Vec<SerializableRecord>> {
48 Ok(vec![])
49 }
50}
51
52#[async_trait]
53impl<D: DhtPersistentStorage> DhtPersistentStorage for Arc<D> {
54 async fn save(&self, records: Vec<SerializableRecord>) -> anyhow::Result<()> {
55 self.as_ref().save(records).await
56 }
57
58 async fn load(&self) -> anyhow::Result<Vec<SerializableRecord>> {
59 self.as_ref().load().await
60 }
61}
62
/// A storage backend that keeps the DHT snapshot in a single file on disk.
#[derive(Clone, Debug)]
pub struct DhtFilePersistence {
    /// Path of the file the snapshot is written to and read from.
    path: String,
}

impl DhtFilePersistence {
    /// Create a file-backed storage that reads from and writes to `path`.
    ///
    /// The file is not touched here; it is only accessed when records are saved
    /// or loaded.
    #[must_use]
    pub fn new(path: String) -> Self {
        Self { path }
    }
}
78
79#[async_trait]
80impl DhtPersistentStorage for DhtFilePersistence {
81 async fn save(&self, records: Vec<SerializableRecord>) -> anyhow::Result<()> {
87 let to_save =
89 bincode::serialize(&records).with_context(|| "Failed to serialize records")?;
90
91 std::fs::write(&self.path, to_save).with_context(|| "Failed to write records to file")?;
93
94 Ok(())
95 }
96
97 async fn load(&self) -> anyhow::Result<Vec<SerializableRecord>> {
103 let contents =
105 std::fs::read(&self.path).with_context(|| "Failed to read records from file")?;
106
107 let records: Vec<SerializableRecord> =
109 bincode::deserialize(&contents).with_context(|| "Failed to deserialize records")?;
110
111 Ok(records)
112 }
113}
114
115pub struct PersistentStore<R: RecordStore, D: DhtPersistentStorage> {
117 underlying_record_store: R,
119
120 persistent_storage: D,
122
123 semaphore: Arc<Semaphore>,
125
126 max_record_delta: u64,
128
129 record_delta: Arc<AtomicU64>,
131}
132
133#[derive(Serialize, Deserialize)]
135pub struct SerializableRecord {
136 pub key: libp2p::kad::RecordKey,
138 pub value: Vec<u8>,
140 pub publisher: Option<libp2p::PeerId>,
142 pub expires_unix_secs: Option<u64>,
147}
148
149fn instant_to_unix_seconds(instant: Instant) -> anyhow::Result<u64> {
151 let now_instant = Instant::now();
153 let now_system = SystemTime::now();
154
155 if instant > now_instant {
157 Ok(now_system
158 .checked_add(instant - now_instant)
159 .with_context(|| "Overflow when approximating expiration time")?
160 .duration_since(UNIX_EPOCH)
161 .with_context(|| "Failed to get duration since Unix epoch")?
162 .as_secs())
163 } else {
164 Ok(now_system
165 .checked_sub(now_instant - instant)
166 .with_context(|| "Underflow when approximating expiration time")?
167 .duration_since(UNIX_EPOCH)
168 .with_context(|| "Failed to get duration since Unix epoch")?
169 .as_secs())
170 }
171}
172
173fn unix_seconds_to_instant(unix_secs: u64) -> anyhow::Result<Instant> {
175 let now_instant = Instant::now();
177 let unix_secs_now = SystemTime::now()
178 .duration_since(UNIX_EPOCH)
179 .with_context(|| "Failed to get duration since Unix epoch")?
180 .as_secs();
181
182 if unix_secs > unix_secs_now {
183 now_instant
185 .checked_add(Duration::from_secs(unix_secs - unix_secs_now))
186 .with_context(|| "Overflow when calculating future instant")
187 } else {
188 now_instant
190 .checked_sub(Duration::from_secs(unix_secs_now - unix_secs))
191 .with_context(|| "Underflow when calculating past instant")
192 }
193}
194
195impl TryFrom<libp2p::kad::Record> for SerializableRecord {
197 type Error = anyhow::Error;
198
199 fn try_from(record: libp2p::kad::Record) -> anyhow::Result<Self> {
200 Ok(SerializableRecord {
201 key: record.key,
202 value: record.value,
203 publisher: record.publisher,
204 expires_unix_secs: record.expires.map(instant_to_unix_seconds).transpose()?,
205 })
206 }
207}
208
209impl TryFrom<SerializableRecord> for libp2p::kad::Record {
211 type Error = anyhow::Error;
212
213 fn try_from(record: SerializableRecord) -> anyhow::Result<Self> {
214 Ok(libp2p::kad::Record {
215 key: record.key,
216 value: record.value,
217 publisher: record.publisher,
218 expires: record
219 .expires_unix_secs
220 .map(unix_seconds_to_instant)
221 .transpose()?,
222 })
223 }
224}
225
226impl<R: RecordStore, D: DhtPersistentStorage> PersistentStore<R, D> {
227 pub async fn new(
233 underlying_record_store: R,
234 persistent_storage: D,
235 max_record_delta: u64,
236 ) -> Self {
237 let mut store = PersistentStore {
239 underlying_record_store,
240 persistent_storage,
241 max_record_delta,
242 record_delta: Arc::new(AtomicU64::new(0)),
243 semaphore: Arc::new(Semaphore::new(1)),
244 };
245
246 if let Err(err) = store.restore_from_persistent_storage().await {
248 warn!(
249 "Failed to restore DHT from persistent storage: {:?}. Starting with empty store",
250 err
251 );
252 }
253
254 store
256 }
257
258 fn try_save_to_persistent_storage(&mut self) -> bool {
262 let Ok(permit) = Arc::clone(&self.semaphore).try_acquire_owned() else {
264 warn!("Skipping DHT save to persistent storage - another save operation is already in progress");
265 return false;
266 };
267
268 let serializable_records: Vec<_> = self
270 .underlying_record_store
271 .records()
272 .filter_map(|record| {
273 SerializableRecord::try_from(record.into_owned())
274 .map_err(|err| {
275 warn!("Failed to convert record to serializable record: {:?}", err);
276 })
277 .ok()
278 })
279 .collect();
280
281 let persistent_storage = self.persistent_storage.clone();
283 let record_delta = Arc::clone(&self.record_delta);
284 tokio::spawn(async move {
285 debug!("Saving DHT to persistent storage");
286
287 match timeout(
289 Duration::from_secs(10),
290 persistent_storage.save(serializable_records),
291 )
292 .await
293 .map_err(|_| anyhow::anyhow!("save operation timed out"))
294 {
295 Ok(Ok(())) => {},
296 Ok(Err(error)) | Err(error) => {
297 warn!("Failed to save DHT to persistent storage: {error}");
298 },
299 };
300
301 record_delta.store(0, Ordering::Release);
303
304 drop(permit);
305
306 debug!("Saved DHT to persistent storage");
307 });
308
309 true
310 }
311
312 pub async fn restore_from_persistent_storage(&mut self) -> anyhow::Result<()> {
317 debug!("Restoring DHT from persistent storage");
318
319 let serializable_records = self
321 .persistent_storage
322 .load()
323 .await
324 .with_context(|| "Failed to read DHT from persistent storage")?;
325
326 for serializable_record in serializable_records {
328 match libp2p::kad::Record::try_from(serializable_record) {
330 Ok(record) => {
331 if record.expires.is_none() || record.expires.unwrap() < Instant::now() {
333 continue;
334 }
335
336 if let Err(err) = self.underlying_record_store.put(record) {
338 warn!(
339 "Failed to restore record from persistent storage: {:?}",
340 err
341 );
342 }
343 },
344 Err(err) => {
345 warn!("Failed to parse record from persistent storage: {:?}", err);
346 },
347 };
348 }
349
350 debug!("Restored DHT from persistent storage");
351
352 Ok(())
353 }
354}
355
356impl<R: RecordStore, D: DhtPersistentStorage> RecordStore for PersistentStore<R, D> {
358 type ProvidedIter<'a>
359 = R::ProvidedIter<'a>
360 where
361 R: 'a,
362 D: 'a;
363 type RecordsIter<'a>
364 = R::RecordsIter<'a>
365 where
366 R: 'a,
367 D: 'a;
368 delegate! {
370 to self.underlying_record_store {
371 fn add_provider(&mut self, record: libp2p::kad::ProviderRecord) -> libp2p::kad::store::Result<()>;
372 fn get(&self, k: &libp2p::kad::RecordKey) -> Option<std::borrow::Cow<'_, libp2p::kad::Record>>;
373 fn provided(&self) -> Self::ProvidedIter<'_>;
374 fn providers(&self, key: &libp2p::kad::RecordKey) -> Vec<libp2p::kad::ProviderRecord>;
375 fn records(&self) -> Self::RecordsIter<'_>;
376 fn remove_provider(&mut self, k: &libp2p::kad::RecordKey, p: &libp2p::PeerId);
377 }
378 }
379
380 fn put(&mut self, record: libp2p::kad::Record) -> Result<()> {
382 let result = self.underlying_record_store.put(record);
384
385 if result.is_ok() {
387 self.record_delta.fetch_add(1, Ordering::Relaxed);
389
390 if self.record_delta.load(Ordering::Relaxed) > self.max_record_delta {
392 self.try_save_to_persistent_storage();
394 }
395 }
396
397 result
398 }
399
400 fn remove(&mut self, k: &libp2p::kad::RecordKey) {
402 self.underlying_record_store.remove(k);
404
405 self.record_delta.fetch_add(1, Ordering::Relaxed);
407
408 if self.record_delta.load(Ordering::Relaxed) > self.max_record_delta {
410 self.try_save_to_persistent_storage();
412 }
413 }
414}
415
#[cfg(test)]
mod tests {
    use libp2p::{
        kad::{store::MemoryStore, RecordKey},
        PeerId,
    };
    use tracing_subscriber::EnvFilter;

    use super::*;

    /// Build a per-process snapshot path inside the platform's temporary directory,
    /// so the tests are portable and concurrent test runs do not share files.
    fn temp_dht_path(name: &str) -> String {
        std::env::temp_dir()
            .join(format!("{name}-{}.dht", std::process::id()))
            .to_string_lossy()
            .into_owned()
    }

    /// A record that is saved by one store can be read back by a fresh store that
    /// uses the same file.
    #[tokio::test]
    async fn test_save_and_restore() {
        // Try initializing tracing; ignore the error if it is already set up.
        let _ = tracing_subscriber::fmt()
            .with_env_filter(EnvFilter::from_default_env())
            .try_init();

        let path = temp_dht_path("test_save_and_restore");
        // Start from a clean slate in case an earlier run left a file behind.
        let _ = std::fs::remove_file(&path);

        let mut store = PersistentStore::new(
            MemoryStore::new(PeerId::random()),
            DhtFilePersistence::new(path.clone()),
            10,
        )
        .await;

        let key = RecordKey::new(&rand::random::<[u8; 16]>().to_vec());
        let random_value = rand::random::<[u8; 16]>();

        // Records without an expiry are skipped on restore, so give it one.
        let mut record = libp2p::kad::Record::new(key.clone(), random_value.to_vec());
        record.expires = Some(Instant::now() + Duration::from_secs(10));

        store.put(record).expect("Failed to put record into store");

        // Force a save regardless of the record delta.
        assert!(store.try_save_to_persistent_storage());

        // Give the background save task time to finish.
        tokio::time::sleep(Duration::from_millis(100)).await;

        let new_store = PersistentStore::new(
            MemoryStore::new(PeerId::random()),
            DhtFilePersistence::new(path.clone()),
            10,
        )
        .await;

        let restored_record = new_store
            .get(&key)
            .expect("Failed to get record from store");

        assert_eq!(restored_record.value, random_value.to_vec());

        let _ = std::fs::remove_file(&path);
    }

    /// Nothing is saved until the number of changes exceeds `max_record_delta`;
    /// once it does, everything is saved and the counter is reset.
    #[tokio::test]
    async fn test_record_delta() {
        // Try initializing tracing; ignore the error if it is already set up.
        let _ = tracing_subscriber::fmt()
            .with_env_filter(EnvFilter::from_default_env())
            .try_init();

        let path = temp_dht_path("test_record_delta");
        // Start from a clean slate in case an earlier run left a file behind.
        let _ = std::fs::remove_file(&path);

        let mut store = PersistentStore::new(
            MemoryStore::new(PeerId::random()),
            DhtFilePersistence::new(path.clone()),
            10,
        )
        .await;

        let mut keys = Vec::new();
        let mut values = Vec::new();

        // Ten puts bring the delta to exactly the threshold, which must not trigger a save.
        for _ in 0..10 {
            let key = RecordKey::new(&rand::random::<[u8; 16]>().to_vec());
            let value = rand::random::<[u8; 16]>();

            keys.push(key.clone());
            values.push(value);

            let mut record = libp2p::kad::Record::new(key, value.to_vec());
            record.expires = Some(Instant::now() + Duration::from_secs(10));

            store.put(record).expect("Failed to put record into store");
        }

        let new_store = PersistentStore::new(
            MemoryStore::new(PeerId::random()),
            DhtFilePersistence::new(path.clone()),
            10,
        )
        .await;

        // Nothing has been saved yet, so a fresh store sees none of the records.
        for key in &keys {
            assert!(new_store.get(key).is_none());
        }

        // The eleventh put pushes the delta over the threshold and triggers a save.
        let mut record = libp2p::kad::Record::new(keys[0].clone(), values[0].to_vec());
        record.expires = Some(Instant::now() + Duration::from_secs(10));

        store.put(record).expect("Failed to put record into store");

        // Give the background save task time to finish.
        tokio::time::sleep(Duration::from_millis(100)).await;

        let new_store = PersistentStore::new(
            MemoryStore::new(PeerId::random()),
            DhtFilePersistence::new(path.clone()),
            10,
        )
        .await;

        for (i, key) in keys.iter().enumerate() {
            let restored_record = new_store.get(key).expect("Failed to get record from store");
            assert_eq!(restored_record.value, values[i]);
        }

        // The save must have reset the change counter.
        assert_eq!(store.record_delta.load(Ordering::Relaxed), 0);

        let _ = std::fs::remove_file(&path);
    }

    /// Round-tripping an `Instant` through Unix seconds stays within a second of
    /// the original, for both future and past instants.
    #[test]
    fn test_approximate_instant() {
        let expiry_future = Instant::now() + Duration::from_secs(10);

        let approximate_expiry =
            unix_seconds_to_instant(instant_to_unix_seconds(expiry_future).unwrap())
                .unwrap()
                .duration_since(Instant::now());

        assert!(approximate_expiry >= Duration::from_secs(9));
        assert!(approximate_expiry <= Duration::from_secs(11));

        let expiry_past = Instant::now().checked_sub(Duration::from_secs(10)).unwrap();

        let approximate_expiry =
            unix_seconds_to_instant(instant_to_unix_seconds(expiry_past).unwrap()).unwrap();
        let time_difference = approximate_expiry.elapsed();

        assert!(time_difference >= Duration::from_secs(9));
        assert!(time_difference <= Duration::from_secs(11));
    }
}