hotshot_libp2p_networking/network/behaviours/dht/store/
persistent.rs

1//! This file contains the `PersistentStore` struct, which is a wrapper around a `RecordStore`
2//! that occasionally saves the DHT to a persistent storage.
3
4use std::{
5    sync::{
6        atomic::{AtomicU64, Ordering},
7        Arc,
8    },
9    time::{Duration, Instant, SystemTime, UNIX_EPOCH},
10};
11
12use anyhow::Context;
13use async_trait::async_trait;
14use delegate::delegate;
15use libp2p::kad::store::{RecordStore, Result};
16use serde::{Deserialize, Serialize};
17use tokio::{sync::Semaphore, time::timeout};
18use tracing::{debug, warn};
19
20/// A trait that we use to save and load the DHT to a file on disk
21/// or other storage medium
22#[async_trait]
23pub trait DhtPersistentStorage: Send + Sync + 'static + Clone {
24    /// Save the DHT (as a list of serializable records) to the persistent storage
25    ///
26    /// # Errors
27    /// - If we fail to save the DHT to the persistent storage provider
28    async fn save(&self, _records: Vec<SerializableRecord>) -> anyhow::Result<()>;
29
30    /// Load the DHT (as a list of serializable records) from the persistent storage
31    ///
32    /// # Errors
33    /// - If we fail to load the DHT from the persistent storage provider
34    async fn load(&self) -> anyhow::Result<Vec<SerializableRecord>>;
35}
36
/// A no-op `PersistentStorage` that does not persist the DHT.
///
/// Being a unit struct it carries no state, so it is trivially cloneable,
/// default-constructible and printable for diagnostics.
#[derive(Clone, Debug, Default)]
pub struct DhtNoPersistence;
40
41#[async_trait]
42impl DhtPersistentStorage for DhtNoPersistence {
43    async fn save(&self, _records: Vec<SerializableRecord>) -> anyhow::Result<()> {
44        Ok(())
45    }
46
47    async fn load(&self) -> anyhow::Result<Vec<SerializableRecord>> {
48        Ok(vec![])
49    }
50}
51
52#[async_trait]
53impl<D: DhtPersistentStorage> DhtPersistentStorage for Arc<D> {
54    async fn save(&self, records: Vec<SerializableRecord>) -> anyhow::Result<()> {
55        self.as_ref().save(records).await
56    }
57
58    async fn load(&self) -> anyhow::Result<Vec<SerializableRecord>> {
59        self.as_ref().load().await
60    }
61}
62
/// A `PersistentStorage` that persists the DHT to a file on disk. Used mostly for
/// testing.
#[derive(Clone, Debug)]
pub struct DhtFilePersistence {
    /// The path to the file on disk
    path: String,
}

impl DhtFilePersistence {
    /// Create a new `DhtFilePersistence` that reads from and writes to the file at `path`.
    ///
    /// The file is not touched here; it is only accessed when the DHT is
    /// saved or loaded.
    #[must_use]
    pub fn new(path: String) -> Self {
        Self { path }
    }
}
78
79#[async_trait]
80impl DhtPersistentStorage for DhtFilePersistence {
81    /// Save the DHT to the file on disk
82    ///
83    /// # Errors
84    /// - If we fail to serialize the records
85    /// - If we fail to write the serialized records to the file
86    async fn save(&self, records: Vec<SerializableRecord>) -> anyhow::Result<()> {
87        // Bincode-serialize the records
88        let to_save =
89            bincode::serialize(&records).with_context(|| "Failed to serialize records")?;
90
91        // Write the serialized records to the file
92        std::fs::write(&self.path, to_save).with_context(|| "Failed to write records to file")?;
93
94        Ok(())
95    }
96
97    /// Load the DHT from the file on disk
98    ///
99    /// # Errors
100    /// - If we fail to read the file
101    /// - If we fail to deserialize the records
102    async fn load(&self) -> anyhow::Result<Vec<SerializableRecord>> {
103        // Read the contents of the file
104        let contents =
105            std::fs::read(&self.path).with_context(|| "Failed to read records from file")?;
106
107        // Deserialize the contents
108        let records: Vec<SerializableRecord> =
109            bincode::deserialize(&contents).with_context(|| "Failed to deserialize records")?;
110
111        Ok(records)
112    }
113}
114
115/// A `RecordStore` wrapper that occasionally saves the DHT to a persistent storage.
116pub struct PersistentStore<R: RecordStore, D: DhtPersistentStorage> {
117    /// The underlying record store
118    underlying_record_store: R,
119
120    /// The persistent storage
121    persistent_storage: D,
122
123    /// The semaphore for limiting the number of concurrent operations (to one)
124    semaphore: Arc<Semaphore>,
125
126    /// The maximum number of records that can be added to the store before the store is saved to the persistent storage
127    max_record_delta: u64,
128
129    /// The running delta between the records in the persistent storage and the records in the underlying store
130    record_delta: Arc<AtomicU64>,
131}
132
133/// A serializable version of a Libp2p `Record`
134#[derive(Serialize, Deserialize)]
135pub struct SerializableRecord {
136    /// The key of the record
137    pub key: libp2p::kad::RecordKey,
138    /// The value of the record
139    pub value: Vec<u8>,
140    /// The (original) publisher of the record.
141    pub publisher: Option<libp2p::PeerId>,
142    /// The record expiration time in seconds since the Unix epoch
143    ///
144    /// This is an approximation of the expiration time because we can't
145    /// serialize an `Instant` directly.
146    pub expires_unix_secs: Option<u64>,
147}
148
149/// Approximate an `Instant` to the number of seconds since the Unix epoch
150fn instant_to_unix_seconds(instant: Instant) -> anyhow::Result<u64> {
151    // Get the current instant and system time
152    let now_instant = Instant::now();
153    let now_system = SystemTime::now();
154
155    // Get the duration of time between the instant and now
156    if instant > now_instant {
157        Ok(now_system
158            .checked_add(instant - now_instant)
159            .with_context(|| "Overflow when approximating expiration time")?
160            .duration_since(UNIX_EPOCH)
161            .with_context(|| "Failed to get duration since Unix epoch")?
162            .as_secs())
163    } else {
164        Ok(now_system
165            .checked_sub(now_instant - instant)
166            .with_context(|| "Underflow when approximating expiration time")?
167            .duration_since(UNIX_EPOCH)
168            .with_context(|| "Failed to get duration since Unix epoch")?
169            .as_secs())
170    }
171}
172
173/// Convert a unix-second timestamp to an `Instant`
174fn unix_seconds_to_instant(unix_secs: u64) -> anyhow::Result<Instant> {
175    // Get the current instant and unix time
176    let now_instant = Instant::now();
177    let unix_secs_now = SystemTime::now()
178        .duration_since(UNIX_EPOCH)
179        .with_context(|| "Failed to get duration since Unix epoch")?
180        .as_secs();
181
182    if unix_secs > unix_secs_now {
183        // If the instant is in the future, add the duration to the current time
184        now_instant
185            .checked_add(Duration::from_secs(unix_secs - unix_secs_now))
186            .with_context(|| "Overflow when calculating future instant")
187    } else {
188        // If the instant is in the past, subtract the duration from the current time
189        now_instant
190            .checked_sub(Duration::from_secs(unix_secs_now - unix_secs))
191            .with_context(|| "Underflow when calculating past instant")
192    }
193}
194
195/// Allow conversion from a `libp2p::kad::Record` to a `SerializableRecord`
196impl TryFrom<libp2p::kad::Record> for SerializableRecord {
197    type Error = anyhow::Error;
198
199    fn try_from(record: libp2p::kad::Record) -> anyhow::Result<Self> {
200        Ok(SerializableRecord {
201            key: record.key,
202            value: record.value,
203            publisher: record.publisher,
204            expires_unix_secs: record.expires.map(instant_to_unix_seconds).transpose()?,
205        })
206    }
207}
208
209/// Allow conversion from a `SerializableRecord` to a `libp2p::kad::Record`
210impl TryFrom<SerializableRecord> for libp2p::kad::Record {
211    type Error = anyhow::Error;
212
213    fn try_from(record: SerializableRecord) -> anyhow::Result<Self> {
214        Ok(libp2p::kad::Record {
215            key: record.key,
216            value: record.value,
217            publisher: record.publisher,
218            expires: record
219                .expires_unix_secs
220                .map(unix_seconds_to_instant)
221                .transpose()?,
222        })
223    }
224}
225
226impl<R: RecordStore, D: DhtPersistentStorage> PersistentStore<R, D> {
227    /// Create a new `PersistentStore` with the given underlying store and path.
228    /// On creation, the DHT is restored from the persistent storage if possible.
229    ///
230    /// `max_record_delta` is the maximum number of records that can be added to the store before
231    /// the store is saved to the persistent storage.
232    pub async fn new(
233        underlying_record_store: R,
234        persistent_storage: D,
235        max_record_delta: u64,
236    ) -> Self {
237        // Create the new store
238        let mut store = PersistentStore {
239            underlying_record_store,
240            persistent_storage,
241            max_record_delta,
242            record_delta: Arc::new(AtomicU64::new(0)),
243            semaphore: Arc::new(Semaphore::new(1)),
244        };
245
246        // Try to restore the DHT from the persistent store. If it fails, warn and start with an empty store
247        if let Err(err) = store.restore_from_persistent_storage().await {
248            warn!(
249                "Failed to restore DHT from persistent storage: {:?}. Starting with empty store",
250                err
251            );
252        }
253
254        // Return the new store
255        store
256    }
257
258    /// Try saving the DHT to persistent storage if a task is not already in progress.
259    ///
260    /// Returns `true` if the DHT was saved, `false` otherwise.
261    fn try_save_to_persistent_storage(&mut self) -> bool {
262        // Try to acquire the semaphore, warning if another save operation is already in progress
263        let Ok(permit) = Arc::clone(&self.semaphore).try_acquire_owned() else {
264            warn!("Skipping DHT save to persistent storage - another save operation is already in progress");
265            return false;
266        };
267
268        // Get all records and convert them to their serializable counterparts
269        let serializable_records: Vec<_> = self
270            .underlying_record_store
271            .records()
272            .filter_map(|record| {
273                SerializableRecord::try_from(record.into_owned())
274                    .map_err(|err| {
275                        warn!("Failed to convert record to serializable record: {:?}", err);
276                    })
277                    .ok()
278            })
279            .collect();
280
281        // Spawn a task to save the DHT to the persistent storage
282        let persistent_storage = self.persistent_storage.clone();
283        let record_delta = Arc::clone(&self.record_delta);
284        tokio::spawn(async move {
285            debug!("Saving DHT to persistent storage");
286
287            // Save the DHT to the persistent storage
288            match timeout(
289                Duration::from_secs(10),
290                persistent_storage.save(serializable_records),
291            )
292            .await
293            .map_err(|_| anyhow::anyhow!("save operation timed out"))
294            {
295                Ok(Ok(())) => {},
296                Ok(Err(error)) | Err(error) => {
297                    warn!("Failed to save DHT to persistent storage: {error}");
298                },
299            };
300
301            // Reset the record delta
302            record_delta.store(0, Ordering::Release);
303
304            drop(permit);
305
306            debug!("Saved DHT to persistent storage");
307        });
308
309        true
310    }
311
312    /// Attempt to restore the DHT to the underlying store from the persistent storage
313    ///
314    /// # Errors
315    /// - If we fail to load from the persistent storage
316    pub async fn restore_from_persistent_storage(&mut self) -> anyhow::Result<()> {
317        debug!("Restoring DHT from persistent storage");
318
319        // Read the contents of the persistent store
320        let serializable_records = self
321            .persistent_storage
322            .load()
323            .await
324            .with_context(|| "Failed to read DHT from persistent storage")?;
325
326        // Put all records into the new store
327        for serializable_record in serializable_records {
328            // Convert the serializable record back to a `libp2p::kad::Record`
329            match libp2p::kad::Record::try_from(serializable_record) {
330                Ok(record) => {
331                    // If the record doesn't have an expiration time, or has expired, skip it
332                    if record.expires.is_none() || record.expires.unwrap() < Instant::now() {
333                        continue;
334                    }
335
336                    // Put the record into the new store
337                    if let Err(err) = self.underlying_record_store.put(record) {
338                        warn!(
339                            "Failed to restore record from persistent storage: {:?}",
340                            err
341                        );
342                    }
343                },
344                Err(err) => {
345                    warn!("Failed to parse record from persistent storage: {:?}", err);
346                },
347            };
348        }
349
350        debug!("Restored DHT from persistent storage");
351
352        Ok(())
353    }
354}
355
356/// Implement the `RecordStore` trait for `PersistentStore`
357impl<R: RecordStore, D: DhtPersistentStorage> RecordStore for PersistentStore<R, D> {
358    type ProvidedIter<'a>
359        = R::ProvidedIter<'a>
360    where
361        R: 'a,
362        D: 'a;
363    type RecordsIter<'a>
364        = R::RecordsIter<'a>
365    where
366        R: 'a,
367        D: 'a;
368    // Delegate all `RecordStore` methods except `put` to the inner store
369    delegate! {
370        to self.underlying_record_store {
371            fn add_provider(&mut self, record: libp2p::kad::ProviderRecord) -> libp2p::kad::store::Result<()>;
372            fn get(&self, k: &libp2p::kad::RecordKey) -> Option<std::borrow::Cow<'_, libp2p::kad::Record>>;
373            fn provided(&self) -> Self::ProvidedIter<'_>;
374            fn providers(&self, key: &libp2p::kad::RecordKey) -> Vec<libp2p::kad::ProviderRecord>;
375            fn records(&self) -> Self::RecordsIter<'_>;
376            fn remove_provider(&mut self, k: &libp2p::kad::RecordKey, p: &libp2p::PeerId);
377        }
378    }
379
380    /// Override the `put` method to potentially sync the DHT to the persistent store
381    fn put(&mut self, record: libp2p::kad::Record) -> Result<()> {
382        // Try to write to the underlying store
383        let result = self.underlying_record_store.put(record);
384
385        // If the record was successfully written,
386        if result.is_ok() {
387            // Update the record delta
388            self.record_delta.fetch_add(1, Ordering::Relaxed);
389
390            // Check if it's above the maximum record delta
391            if self.record_delta.load(Ordering::Relaxed) > self.max_record_delta {
392                // Try to save the DHT to persistent storage
393                self.try_save_to_persistent_storage();
394            }
395        }
396
397        result
398    }
399
400    /// Overwrite the `remove` method to potentially sync the DHT to the persistent store
401    fn remove(&mut self, k: &libp2p::kad::RecordKey) {
402        // Remove the record from the underlying store
403        self.underlying_record_store.remove(k);
404
405        // Update the record delta
406        self.record_delta.fetch_add(1, Ordering::Relaxed);
407
408        // Check if it's above the maximum record delta
409        if self.record_delta.load(Ordering::Relaxed) > self.max_record_delta {
410            // Try to save the DHT to persistent storage
411            self.try_save_to_persistent_storage();
412        }
413    }
414}
415
#[cfg(test)]
mod tests {
    use libp2p::{
        kad::{store::MemoryStore, RecordKey},
        PeerId,
    };
    use tracing_subscriber::EnvFilter;

    use super::*;

    /// Build a path for a test DHT file inside the platform's temporary directory.
    ///
    /// The process id is part of the file name so that concurrent test runs on the
    /// same machine do not read or overwrite each other's files.
    fn temp_dht_path(name: &str) -> String {
        std::env::temp_dir()
            .join(format!("{name}-{}.dht", std::process::id()))
            .to_string_lossy()
            .into_owned()
    }

    /// Wait until no save task is running for the given store.
    ///
    /// A save task holds the store's only semaphore permit until it is done, so
    /// acquiring the permit here returns only after the task has finished.
    async fn wait_for_pending_save<R: RecordStore, D: DhtPersistentStorage>(
        store: &PersistentStore<R, D>,
    ) {
        let _permit = store
            .semaphore
            .acquire()
            .await
            .expect("the save semaphore is never closed");
    }

    #[tokio::test]
    async fn test_save_and_restore() {
        // Try initializing tracing
        let _ = tracing_subscriber::fmt()
            .with_env_filter(EnvFilter::from_default_env())
            .try_init();

        // The file this test persists the DHT to
        let path = temp_dht_path("test_save_and_restore");

        // Create a test store
        let mut store = PersistentStore::new(
            MemoryStore::new(PeerId::random()),
            DhtFilePersistence::new(path.clone()),
            10,
        )
        .await;

        // The key is a random 16-byte array
        let key = RecordKey::new(&rand::random::<[u8; 16]>().to_vec());

        // The value is a random 16-byte array
        let random_value = rand::random::<[u8; 16]>();

        // Create a new record
        let mut record = libp2p::kad::Record::new(key.clone(), random_value.to_vec());

        // Set the expiry time to 10 seconds in the future
        record.expires = Some(Instant::now() + Duration::from_secs(10));

        // Put a record into the store
        store.put(record).expect("Failed to put record into store");

        // Try to save the store to a persistent storage
        assert!(store.try_save_to_persistent_storage());

        // Wait for the save to complete
        wait_for_pending_save(&store).await;

        // Create a new store from the persistent storage
        let new_store = PersistentStore::new(
            MemoryStore::new(PeerId::random()),
            DhtFilePersistence::new(path.clone()),
            10,
        )
        .await;

        // Check that the new store has the record
        let restored_record = new_store
            .get(&key)
            .expect("Failed to get record from store");

        // Check that the restored record has the same value as the original record
        assert_eq!(restored_record.value, random_value.to_vec());

        // Best-effort cleanup of the test file
        let _ = std::fs::remove_file(&path);
    }

    #[tokio::test]
    async fn test_record_delta() {
        // Try initializing tracing
        let _ = tracing_subscriber::fmt()
            .with_env_filter(EnvFilter::from_default_env())
            .try_init();

        // The file this test persists the DHT to
        let path = temp_dht_path("test_record_delta");

        // Create a test store
        let mut store = PersistentStore::new(
            MemoryStore::new(PeerId::random()),
            DhtFilePersistence::new(path.clone()),
            10,
        )
        .await;

        let mut keys = Vec::new();
        let mut values = Vec::new();

        // Put 10 records into the store
        for _ in 0..10 {
            // Create a random key and value
            let key = RecordKey::new(&rand::random::<[u8; 16]>().to_vec());
            let value = rand::random::<[u8; 16]>();

            keys.push(key.clone());
            values.push(value);

            // Create a new record
            let mut record = libp2p::kad::Record::new(key, value.to_vec());

            // Set the expiry time to 10 seconds in the future
            record.expires = Some(Instant::now() + Duration::from_secs(10));

            store.put(record).expect("Failed to put record into store");
        }

        // Create a new store from the allegedly unpersisted DHT
        let new_store = PersistentStore::new(
            MemoryStore::new(PeerId::random()),
            DhtFilePersistence::new(path.clone()),
            10,
        )
        .await;

        // Check that the new store has none of the records
        for key in &keys {
            assert!(new_store.get(key).is_none());
        }

        // Create a new record
        let mut record = libp2p::kad::Record::new(keys[0].clone(), values[0].to_vec());

        // Set the expiry time to 10 seconds in the future
        record.expires = Some(Instant::now() + Duration::from_secs(10));

        // Store one more record into the original store, pushing the delta past the maximum
        store.put(record).expect("Failed to put record into store");

        // Wait for the triggered save to complete
        wait_for_pending_save(&store).await;

        // Create a new store from the allegedly saved DHT
        let new_store = PersistentStore::new(
            MemoryStore::new(PeerId::random()),
            DhtFilePersistence::new(path.clone()),
            10,
        )
        .await;

        // Check that the new store has all of the records
        for (i, key) in keys.iter().enumerate() {
            let restored_record = new_store.get(key).expect("Failed to get record from store");
            assert_eq!(restored_record.value, values[i]);
        }

        // Check that the record delta is 0
        assert_eq!(store.record_delta.load(Ordering::Relaxed), 0);

        // Best-effort cleanup of the test file
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_approximate_instant() {
        // Create an expiry time in the future
        let expiry_future = Instant::now() + Duration::from_secs(10);

        // Approximate the expiry time
        let approximate_expiry =
            unix_seconds_to_instant(instant_to_unix_seconds(expiry_future).unwrap())
                .unwrap()
                .duration_since(Instant::now());

        // Make sure it's close to 10 seconds in the future
        assert!(approximate_expiry >= Duration::from_secs(9));
        assert!(approximate_expiry <= Duration::from_secs(11));

        // Create an expiry time in the past
        let expiry_past = Instant::now().checked_sub(Duration::from_secs(10)).unwrap();

        // Approximate the expiry time
        let approximate_expiry =
            unix_seconds_to_instant(instant_to_unix_seconds(expiry_past).unwrap()).unwrap();
        let time_difference = approximate_expiry.elapsed();

        // Make sure it's close to 10 seconds in the past
        assert!(time_difference >= Duration::from_secs(9));
        assert!(time_difference <= Duration::from_secs(11));
    }
}