hotshot_libp2p_networking/network/behaviours/dht/store/
persistent.rs1use std::{
5 sync::{
6 atomic::{AtomicU64, Ordering},
7 Arc,
8 },
9 time::{Duration, Instant, SystemTime, UNIX_EPOCH},
10};
11
12use anyhow::Context;
13use async_trait::async_trait;
14use delegate::delegate;
15use libp2p::kad::store::{RecordStore, Result};
16use serde::{Deserialize, Serialize};
17use tokio::{sync::Semaphore, time::timeout};
18use tracing::{debug, warn};
19
20#[async_trait]
23pub trait DhtPersistentStorage: Send + Sync + 'static + Clone {
24 async fn save(&self, _records: Vec<SerializableRecord>) -> anyhow::Result<()>;
29
30 async fn load(&self) -> anyhow::Result<Vec<SerializableRecord>>;
35}
36
37#[derive(Clone)]
39pub struct DhtNoPersistence;
40
41#[async_trait]
42impl DhtPersistentStorage for DhtNoPersistence {
43 async fn save(&self, _records: Vec<SerializableRecord>) -> anyhow::Result<()> {
44 Ok(())
45 }
46
47 async fn load(&self) -> anyhow::Result<Vec<SerializableRecord>> {
48 Ok(vec![])
49 }
50}
51
52#[async_trait]
53impl<D: DhtPersistentStorage> DhtPersistentStorage for Arc<D> {
54 async fn save(&self, records: Vec<SerializableRecord>) -> anyhow::Result<()> {
55 self.as_ref().save(records).await
56 }
57
58 async fn load(&self) -> anyhow::Result<Vec<SerializableRecord>> {
59 self.as_ref().load().await
60 }
61}
62
/// A file-backed [`DhtPersistentStorage`] that keeps all records in the single file at `path`.
#[derive(Clone)]
pub struct DhtFilePersistence {
    /// Filesystem location of the file holding the serialized records.
    path: String,
}

impl DhtFilePersistence {
    /// Create a storage backend that reads from and writes to the file at `path`.
    #[must_use]
    pub fn new(path: String) -> Self {
        DhtFilePersistence { path }
    }
}
78
79#[async_trait]
80impl DhtPersistentStorage for DhtFilePersistence {
81 async fn save(&self, records: Vec<SerializableRecord>) -> anyhow::Result<()> {
87 let to_save =
89 bincode::serialize(&records).with_context(|| "Failed to serialize records")?;
90
91 std::fs::write(&self.path, to_save).with_context(|| "Failed to write records to file")?;
93
94 Ok(())
95 }
96
97 async fn load(&self) -> anyhow::Result<Vec<SerializableRecord>> {
103 let contents =
105 std::fs::read(&self.path).with_context(|| "Failed to read records from file")?;
106
107 let records: Vec<SerializableRecord> =
109 bincode::deserialize(&contents).with_context(|| "Failed to deserialize records")?;
110
111 Ok(records)
112 }
113}
114
115pub struct PersistentStore<R: RecordStore, D: DhtPersistentStorage> {
117 underlying_record_store: R,
119
120 persistent_storage: D,
122
123 semaphore: Arc<Semaphore>,
125
126 max_record_delta: u64,
128
129 record_delta: Arc<AtomicU64>,
131}
132
133#[derive(Serialize, Deserialize)]
135pub struct SerializableRecord {
136 pub key: libp2p::kad::RecordKey,
138 pub value: Vec<u8>,
140 pub publisher: Option<libp2p::PeerId>,
142 pub expires_unix_secs: Option<u64>,
147}
148
149fn instant_to_unix_seconds(instant: Instant) -> anyhow::Result<u64> {
151 let now_instant = Instant::now();
153 let now_system = SystemTime::now();
154
155 if instant > now_instant {
157 Ok(now_system
158 .checked_add(instant - now_instant)
159 .with_context(|| "Overflow when approximating expiration time")?
160 .duration_since(UNIX_EPOCH)
161 .with_context(|| "Failed to get duration since Unix epoch")?
162 .as_secs())
163 } else {
164 Ok(now_system
165 .checked_sub(now_instant - instant)
166 .with_context(|| "Underflow when approximating expiration time")?
167 .duration_since(UNIX_EPOCH)
168 .with_context(|| "Failed to get duration since Unix epoch")?
169 .as_secs())
170 }
171}
172
173fn unix_seconds_to_instant(unix_secs: u64) -> anyhow::Result<Instant> {
175 let now_instant = Instant::now();
177 let unix_secs_now = SystemTime::now()
178 .duration_since(UNIX_EPOCH)
179 .with_context(|| "Failed to get duration since Unix epoch")?
180 .as_secs();
181
182 if unix_secs > unix_secs_now {
183 now_instant
185 .checked_add(Duration::from_secs(unix_secs - unix_secs_now))
186 .with_context(|| "Overflow when calculating future instant")
187 } else {
188 now_instant
190 .checked_sub(Duration::from_secs(unix_secs_now - unix_secs))
191 .with_context(|| "Underflow when calculating past instant")
192 }
193}
194
195impl TryFrom<libp2p::kad::Record> for SerializableRecord {
197 type Error = anyhow::Error;
198
199 fn try_from(record: libp2p::kad::Record) -> anyhow::Result<Self> {
200 Ok(SerializableRecord {
201 key: record.key,
202 value: record.value,
203 publisher: record.publisher,
204 expires_unix_secs: record.expires.map(instant_to_unix_seconds).transpose()?,
205 })
206 }
207}
208
209impl TryFrom<SerializableRecord> for libp2p::kad::Record {
211 type Error = anyhow::Error;
212
213 fn try_from(record: SerializableRecord) -> anyhow::Result<Self> {
214 Ok(libp2p::kad::Record {
215 key: record.key,
216 value: record.value,
217 publisher: record.publisher,
218 expires: record
219 .expires_unix_secs
220 .map(unix_seconds_to_instant)
221 .transpose()?,
222 })
223 }
224}
225
226impl<R: RecordStore, D: DhtPersistentStorage> PersistentStore<R, D> {
227 pub async fn new(
233 underlying_record_store: R,
234 persistent_storage: D,
235 max_record_delta: u64,
236 ) -> Self {
237 let mut store = PersistentStore {
239 underlying_record_store,
240 persistent_storage,
241 max_record_delta,
242 record_delta: Arc::new(AtomicU64::new(0)),
243 semaphore: Arc::new(Semaphore::new(1)),
244 };
245
246 if let Err(err) = store.restore_from_persistent_storage().await {
248 warn!(
249 "Failed to restore DHT from persistent storage: {err}. Starting with empty store",
250 );
251 }
252
253 store
255 }
256
257 fn try_save_to_persistent_storage(&mut self) -> bool {
261 let Ok(permit) = Arc::clone(&self.semaphore).try_acquire_owned() else {
263 warn!(
264 "Skipping DHT save to persistent storage - another save operation is already in \
265 progress"
266 );
267 return false;
268 };
269
270 let serializable_records: Vec<_> = self
272 .underlying_record_store
273 .records()
274 .filter_map(|record| {
275 SerializableRecord::try_from(record.into_owned())
276 .map_err(|err| {
277 warn!("Failed to convert record to serializable record: {:?}", err);
278 })
279 .ok()
280 })
281 .collect();
282
283 let persistent_storage = self.persistent_storage.clone();
285 let record_delta = Arc::clone(&self.record_delta);
286 tokio::spawn(async move {
287 debug!("Saving DHT to persistent storage");
288
289 match timeout(
291 Duration::from_secs(10),
292 persistent_storage.save(serializable_records),
293 )
294 .await
295 .map_err(|_| anyhow::anyhow!("save operation timed out"))
296 {
297 Ok(Ok(())) => {},
298 Ok(Err(error)) | Err(error) => {
299 warn!("Failed to save DHT to persistent storage: {error}");
300 },
301 };
302
303 record_delta.store(0, Ordering::Release);
305
306 drop(permit);
307
308 debug!("Saved DHT to persistent storage");
309 });
310
311 true
312 }
313
314 pub async fn restore_from_persistent_storage(&mut self) -> anyhow::Result<()> {
319 debug!("Restoring DHT from persistent storage");
320
321 let serializable_records = self
323 .persistent_storage
324 .load()
325 .await
326 .with_context(|| "Failed to read DHT from persistent storage")?;
327
328 for serializable_record in serializable_records {
330 match libp2p::kad::Record::try_from(serializable_record) {
332 Ok(record) => {
333 if record.expires.is_none() || record.expires.unwrap() < Instant::now() {
335 continue;
336 }
337
338 if let Err(err) = self.underlying_record_store.put(record) {
340 warn!(
341 "Failed to restore record from persistent storage: {:?}",
342 err
343 );
344 }
345 },
346 Err(err) => {
347 warn!("Failed to parse record from persistent storage: {:?}", err);
348 },
349 };
350 }
351
352 debug!("Restored DHT from persistent storage");
353
354 Ok(())
355 }
356}
357
358impl<R: RecordStore, D: DhtPersistentStorage> RecordStore for PersistentStore<R, D> {
360 type ProvidedIter<'a>
361 = R::ProvidedIter<'a>
362 where
363 R: 'a,
364 D: 'a;
365 type RecordsIter<'a>
366 = R::RecordsIter<'a>
367 where
368 R: 'a,
369 D: 'a;
370 delegate! {
372 to self.underlying_record_store {
373 fn add_provider(&mut self, record: libp2p::kad::ProviderRecord) -> libp2p::kad::store::Result<()>;
374 fn get(&self, k: &libp2p::kad::RecordKey) -> Option<std::borrow::Cow<'_, libp2p::kad::Record>>;
375 fn provided(&self) -> Self::ProvidedIter<'_>;
376 fn providers(&self, key: &libp2p::kad::RecordKey) -> Vec<libp2p::kad::ProviderRecord>;
377 fn records(&self) -> Self::RecordsIter<'_>;
378 fn remove_provider(&mut self, k: &libp2p::kad::RecordKey, p: &libp2p::PeerId);
379 }
380 }
381
382 fn put(&mut self, record: libp2p::kad::Record) -> Result<()> {
384 let result = self.underlying_record_store.put(record);
386
387 if result.is_ok() {
389 self.record_delta.fetch_add(1, Ordering::Relaxed);
391
392 if self.record_delta.load(Ordering::Relaxed) > self.max_record_delta {
394 self.try_save_to_persistent_storage();
396 }
397 }
398
399 result
400 }
401
402 fn remove(&mut self, k: &libp2p::kad::RecordKey) {
404 self.underlying_record_store.remove(k);
406
407 self.record_delta.fetch_add(1, Ordering::Relaxed);
409
410 if self.record_delta.load(Ordering::Relaxed) > self.max_record_delta {
412 self.try_save_to_persistent_storage();
414 }
415 }
416}
417
#[cfg(test)]
mod tests {
    use libp2p::{
        kad::{store::MemoryStore, RecordKey},
        PeerId,
    };
    use tracing_subscriber::EnvFilter;

    use super::*;

    /// Build a DHT file path in the OS temp directory that is unique to `test_name` and this
    /// process.
    ///
    /// This avoids hard-coding `/tmp`, which does not exist on every platform, and avoids
    /// collisions with concurrent test runs.
    fn temp_dht_path(test_name: &str) -> String {
        std::env::temp_dir()
            .join(format!("{test_name}-{}.dht", std::process::id()))
            .to_string_lossy()
            .into_owned()
    }

    #[tokio::test]
    async fn test_save_and_restore() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(EnvFilter::from_default_env())
            .try_init();

        let path = temp_dht_path("test_save_and_restore");

        let mut store = PersistentStore::new(
            MemoryStore::new(PeerId::random()),
            DhtFilePersistence::new(path.clone()),
            10,
        )
        .await;

        let key = RecordKey::new(&rand::random::<[u8; 16]>().to_vec());
        let random_value = rand::random::<[u8; 16]>();

        let mut record = libp2p::kad::Record::new(key.clone(), random_value.to_vec());
        record.expires = Some(Instant::now() + Duration::from_secs(10));

        store.put(record).expect("Failed to put record into store");

        assert!(store.try_save_to_persistent_storage());

        // Give the background save task time to finish.
        tokio::time::sleep(Duration::from_millis(100)).await;

        let new_store = PersistentStore::new(
            MemoryStore::new(PeerId::random()),
            DhtFilePersistence::new(path.clone()),
            10,
        )
        .await;

        let restored_record = new_store
            .get(&key)
            .expect("Failed to get record from store");

        assert_eq!(restored_record.value, random_value.to_vec());

        let _ = std::fs::remove_file(&path);
    }

    #[tokio::test]
    async fn test_record_delta() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(EnvFilter::from_default_env())
            .try_init();

        let path = temp_dht_path("test_record_delta");

        let mut store = PersistentStore::new(
            MemoryStore::new(PeerId::random()),
            DhtFilePersistence::new(path.clone()),
            10,
        )
        .await;

        let mut keys = Vec::new();
        let mut values = Vec::new();

        // Exactly `max_record_delta` changes: not enough to trigger a save yet.
        for _ in 0..10 {
            let key = RecordKey::new(&rand::random::<[u8; 16]>().to_vec());
            let value = rand::random::<[u8; 16]>();

            keys.push(key.clone());
            values.push(value);

            let mut record = libp2p::kad::Record::new(key, value.to_vec());
            record.expires = Some(Instant::now() + Duration::from_secs(10));

            store.put(record).expect("Failed to put record into store");
        }

        let new_store = PersistentStore::new(
            MemoryStore::new(PeerId::random()),
            DhtFilePersistence::new(path.clone()),
            10,
        )
        .await;

        for key in &keys {
            assert!(new_store.get(key).is_none());
        }

        // One more change pushes the delta over the threshold and triggers a save.
        let mut record = libp2p::kad::Record::new(keys[0].clone(), values[0].to_vec());
        record.expires = Some(Instant::now() + Duration::from_secs(10));

        store.put(record).expect("Failed to put record into store");

        tokio::time::sleep(Duration::from_millis(100)).await;

        let new_store = PersistentStore::new(
            MemoryStore::new(PeerId::random()),
            DhtFilePersistence::new(path.clone()),
            10,
        )
        .await;

        for (i, key) in keys.iter().enumerate() {
            let restored_record = new_store.get(key).expect("Failed to get record from store");
            assert_eq!(restored_record.value, values[i]);
        }

        assert_eq!(store.record_delta.load(Ordering::Relaxed), 0);

        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_approximate_instant() {
        let expiry_future = Instant::now() + Duration::from_secs(10);

        let approximate_expiry =
            unix_seconds_to_instant(instant_to_unix_seconds(expiry_future).unwrap())
                .unwrap()
                .duration_since(Instant::now());

        assert!(approximate_expiry >= Duration::from_secs(9));
        assert!(approximate_expiry <= Duration::from_secs(11));

        let expiry_past = Instant::now().checked_sub(Duration::from_secs(10)).unwrap();

        let approximate_expiry =
            unix_seconds_to_instant(instant_to_unix_seconds(expiry_past).unwrap()).unwrap();
        let time_difference = approximate_expiry.elapsed();

        assert!(time_difference >= Duration::from_secs(9));
        assert!(time_difference <= Duration::from_secs(11));
    }
}