hotshot_libp2p_networking/network/behaviours/dht/store/
persistent.rs

1//! This file contains the `PersistentStore` struct, which is a wrapper around a `RecordStore`
2//! that occasionally saves the DHT to a persistent storage.
3
4use std::{
5    sync::{
6        atomic::{AtomicU64, Ordering},
7        Arc,
8    },
9    time::{Duration, Instant, SystemTime, UNIX_EPOCH},
10};
11
12use anyhow::Context;
13use async_trait::async_trait;
14use delegate::delegate;
15use libp2p::kad::store::{RecordStore, Result};
16use serde::{Deserialize, Serialize};
17use tokio::{sync::Semaphore, time::timeout};
18use tracing::{debug, warn};
19
20/// A trait that we use to save and load the DHT to a file on disk
21/// or other storage medium
22#[async_trait]
23pub trait DhtPersistentStorage: Send + Sync + 'static + Clone {
24    /// Save the DHT (as a list of serializable records) to the persistent storage
25    ///
26    /// # Errors
27    /// - If we fail to save the DHT to the persistent storage provider
28    async fn save(&self, _records: Vec<SerializableRecord>) -> anyhow::Result<()>;
29
30    /// Load the DHT (as a list of serializable records) from the persistent storage
31    ///
32    /// # Errors
33    /// - If we fail to load the DHT from the persistent storage provider
34    async fn load(&self) -> anyhow::Result<Vec<SerializableRecord>>;
35}
36
37/// A no-op `PersistentStorage` that does not persist the DHT
38#[derive(Clone)]
39pub struct DhtNoPersistence;
40
41#[async_trait]
42impl DhtPersistentStorage for DhtNoPersistence {
43    async fn save(&self, _records: Vec<SerializableRecord>) -> anyhow::Result<()> {
44        Ok(())
45    }
46
47    async fn load(&self) -> anyhow::Result<Vec<SerializableRecord>> {
48        Ok(vec![])
49    }
50}
51
52#[async_trait]
53impl<D: DhtPersistentStorage> DhtPersistentStorage for Arc<D> {
54    async fn save(&self, records: Vec<SerializableRecord>) -> anyhow::Result<()> {
55        self.as_ref().save(records).await
56    }
57
58    async fn load(&self) -> anyhow::Result<Vec<SerializableRecord>> {
59        self.as_ref().load().await
60    }
61}
62
/// A file-backed `DhtPersistentStorage`: the whole DHT is kept in a single
/// file on disk. Used mostly for testing.
#[derive(Clone)]
pub struct DhtFilePersistence {
    /// Location of the backing file on disk
    path: String,
}

impl DhtFilePersistence {
    /// Build a `DhtFilePersistence` that reads from and writes to `path`
    #[must_use]
    pub fn new(path: String) -> Self {
        DhtFilePersistence { path }
    }
}
78
79#[async_trait]
80impl DhtPersistentStorage for DhtFilePersistence {
81    /// Save the DHT to the file on disk
82    ///
83    /// # Errors
84    /// - If we fail to serialize the records
85    /// - If we fail to write the serialized records to the file
86    async fn save(&self, records: Vec<SerializableRecord>) -> anyhow::Result<()> {
87        // Bincode-serialize the records
88        let to_save =
89            bincode::serialize(&records).with_context(|| "Failed to serialize records")?;
90
91        // Write the serialized records to the file
92        std::fs::write(&self.path, to_save).with_context(|| "Failed to write records to file")?;
93
94        Ok(())
95    }
96
97    /// Load the DHT from the file on disk
98    ///
99    /// # Errors
100    /// - If we fail to read the file
101    /// - If we fail to deserialize the records
102    async fn load(&self) -> anyhow::Result<Vec<SerializableRecord>> {
103        // Read the contents of the file
104        let contents =
105            std::fs::read(&self.path).with_context(|| "Failed to read records from file")?;
106
107        // Deserialize the contents
108        let records: Vec<SerializableRecord> =
109            bincode::deserialize(&contents).with_context(|| "Failed to deserialize records")?;
110
111        Ok(records)
112    }
113}
114
115/// A `RecordStore` wrapper that occasionally saves the DHT to a persistent storage.
116pub struct PersistentStore<R: RecordStore, D: DhtPersistentStorage> {
117    /// The underlying record store
118    underlying_record_store: R,
119
120    /// The persistent storage
121    persistent_storage: D,
122
123    /// The semaphore for limiting the number of concurrent operations (to one)
124    semaphore: Arc<Semaphore>,
125
126    /// The maximum number of records that can be added to the store before the store is saved to the persistent storage
127    max_record_delta: u64,
128
129    /// The running delta between the records in the persistent storage and the records in the underlying store
130    record_delta: Arc<AtomicU64>,
131}
132
133/// A serializable version of a Libp2p `Record`
134#[derive(Serialize, Deserialize)]
135pub struct SerializableRecord {
136    /// The key of the record
137    pub key: libp2p::kad::RecordKey,
138    /// The value of the record
139    pub value: Vec<u8>,
140    /// The (original) publisher of the record.
141    pub publisher: Option<libp2p::PeerId>,
142    /// The record expiration time in seconds since the Unix epoch
143    ///
144    /// This is an approximation of the expiration time because we can't
145    /// serialize an `Instant` directly.
146    pub expires_unix_secs: Option<u64>,
147}
148
149/// Approximate an `Instant` to the number of seconds since the Unix epoch
150fn instant_to_unix_seconds(instant: Instant) -> anyhow::Result<u64> {
151    // Get the current instant and system time
152    let now_instant = Instant::now();
153    let now_system = SystemTime::now();
154
155    // Get the duration of time between the instant and now
156    if instant > now_instant {
157        Ok(now_system
158            .checked_add(instant - now_instant)
159            .with_context(|| "Overflow when approximating expiration time")?
160            .duration_since(UNIX_EPOCH)
161            .with_context(|| "Failed to get duration since Unix epoch")?
162            .as_secs())
163    } else {
164        Ok(now_system
165            .checked_sub(now_instant - instant)
166            .with_context(|| "Underflow when approximating expiration time")?
167            .duration_since(UNIX_EPOCH)
168            .with_context(|| "Failed to get duration since Unix epoch")?
169            .as_secs())
170    }
171}
172
173/// Convert a unix-second timestamp to an `Instant`
174fn unix_seconds_to_instant(unix_secs: u64) -> anyhow::Result<Instant> {
175    // Get the current instant and unix time
176    let now_instant = Instant::now();
177    let unix_secs_now = SystemTime::now()
178        .duration_since(UNIX_EPOCH)
179        .with_context(|| "Failed to get duration since Unix epoch")?
180        .as_secs();
181
182    if unix_secs > unix_secs_now {
183        // If the instant is in the future, add the duration to the current time
184        now_instant
185            .checked_add(Duration::from_secs(unix_secs - unix_secs_now))
186            .with_context(|| "Overflow when calculating future instant")
187    } else {
188        // If the instant is in the past, subtract the duration from the current time
189        now_instant
190            .checked_sub(Duration::from_secs(unix_secs_now - unix_secs))
191            .with_context(|| "Underflow when calculating past instant")
192    }
193}
194
195/// Allow conversion from a `libp2p::kad::Record` to a `SerializableRecord`
196impl TryFrom<libp2p::kad::Record> for SerializableRecord {
197    type Error = anyhow::Error;
198
199    fn try_from(record: libp2p::kad::Record) -> anyhow::Result<Self> {
200        Ok(SerializableRecord {
201            key: record.key,
202            value: record.value,
203            publisher: record.publisher,
204            expires_unix_secs: record.expires.map(instant_to_unix_seconds).transpose()?,
205        })
206    }
207}
208
209/// Allow conversion from a `SerializableRecord` to a `libp2p::kad::Record`
210impl TryFrom<SerializableRecord> for libp2p::kad::Record {
211    type Error = anyhow::Error;
212
213    fn try_from(record: SerializableRecord) -> anyhow::Result<Self> {
214        Ok(libp2p::kad::Record {
215            key: record.key,
216            value: record.value,
217            publisher: record.publisher,
218            expires: record
219                .expires_unix_secs
220                .map(unix_seconds_to_instant)
221                .transpose()?,
222        })
223    }
224}
225
226impl<R: RecordStore, D: DhtPersistentStorage> PersistentStore<R, D> {
227    /// Create a new `PersistentStore` with the given underlying store and path.
228    /// On creation, the DHT is restored from the persistent storage if possible.
229    ///
230    /// `max_record_delta` is the maximum number of records that can be added to the store before
231    /// the store is saved to the persistent storage.
232    pub async fn new(
233        underlying_record_store: R,
234        persistent_storage: D,
235        max_record_delta: u64,
236    ) -> Self {
237        // Create the new store
238        let mut store = PersistentStore {
239            underlying_record_store,
240            persistent_storage,
241            max_record_delta,
242            record_delta: Arc::new(AtomicU64::new(0)),
243            semaphore: Arc::new(Semaphore::new(1)),
244        };
245
246        // Try to restore the DHT from the persistent store. If it fails, warn and start with an empty store
247        if let Err(err) = store.restore_from_persistent_storage().await {
248            warn!(
249                "Failed to restore DHT from persistent storage: {err}. Starting with empty store",
250            );
251        }
252
253        // Return the new store
254        store
255    }
256
257    /// Try saving the DHT to persistent storage if a task is not already in progress.
258    ///
259    /// Returns `true` if the DHT was saved, `false` otherwise.
260    fn try_save_to_persistent_storage(&mut self) -> bool {
261        // Try to acquire the semaphore, warning if another save operation is already in progress
262        let Ok(permit) = Arc::clone(&self.semaphore).try_acquire_owned() else {
263            warn!(
264                "Skipping DHT save to persistent storage - another save operation is already in \
265                 progress"
266            );
267            return false;
268        };
269
270        // Get all records and convert them to their serializable counterparts
271        let serializable_records: Vec<_> = self
272            .underlying_record_store
273            .records()
274            .filter_map(|record| {
275                SerializableRecord::try_from(record.into_owned())
276                    .map_err(|err| {
277                        warn!("Failed to convert record to serializable record: {:?}", err);
278                    })
279                    .ok()
280            })
281            .collect();
282
283        // Spawn a task to save the DHT to the persistent storage
284        let persistent_storage = self.persistent_storage.clone();
285        let record_delta = Arc::clone(&self.record_delta);
286        tokio::spawn(async move {
287            debug!("Saving DHT to persistent storage");
288
289            // Save the DHT to the persistent storage
290            match timeout(
291                Duration::from_secs(10),
292                persistent_storage.save(serializable_records),
293            )
294            .await
295            .map_err(|_| anyhow::anyhow!("save operation timed out"))
296            {
297                Ok(Ok(())) => {},
298                Ok(Err(error)) | Err(error) => {
299                    warn!("Failed to save DHT to persistent storage: {error}");
300                },
301            };
302
303            // Reset the record delta
304            record_delta.store(0, Ordering::Release);
305
306            drop(permit);
307
308            debug!("Saved DHT to persistent storage");
309        });
310
311        true
312    }
313
314    /// Attempt to restore the DHT to the underlying store from the persistent storage
315    ///
316    /// # Errors
317    /// - If we fail to load from the persistent storage
318    pub async fn restore_from_persistent_storage(&mut self) -> anyhow::Result<()> {
319        debug!("Restoring DHT from persistent storage");
320
321        // Read the contents of the persistent store
322        let serializable_records = self
323            .persistent_storage
324            .load()
325            .await
326            .with_context(|| "Failed to read DHT from persistent storage")?;
327
328        // Put all records into the new store
329        for serializable_record in serializable_records {
330            // Convert the serializable record back to a `libp2p::kad::Record`
331            match libp2p::kad::Record::try_from(serializable_record) {
332                Ok(record) => {
333                    // If the record doesn't have an expiration time, or has expired, skip it
334                    if record.expires.is_none() || record.expires.unwrap() < Instant::now() {
335                        continue;
336                    }
337
338                    // Put the record into the new store
339                    if let Err(err) = self.underlying_record_store.put(record) {
340                        warn!(
341                            "Failed to restore record from persistent storage: {:?}",
342                            err
343                        );
344                    }
345                },
346                Err(err) => {
347                    warn!("Failed to parse record from persistent storage: {:?}", err);
348                },
349            };
350        }
351
352        debug!("Restored DHT from persistent storage");
353
354        Ok(())
355    }
356}
357
358/// Implement the `RecordStore` trait for `PersistentStore`
359impl<R: RecordStore, D: DhtPersistentStorage> RecordStore for PersistentStore<R, D> {
360    type ProvidedIter<'a>
361        = R::ProvidedIter<'a>
362    where
363        R: 'a,
364        D: 'a;
365    type RecordsIter<'a>
366        = R::RecordsIter<'a>
367    where
368        R: 'a,
369        D: 'a;
370    // Delegate all `RecordStore` methods except `put` to the inner store
371    delegate! {
372        to self.underlying_record_store {
373            fn add_provider(&mut self, record: libp2p::kad::ProviderRecord) -> libp2p::kad::store::Result<()>;
374            fn get(&self, k: &libp2p::kad::RecordKey) -> Option<std::borrow::Cow<'_, libp2p::kad::Record>>;
375            fn provided(&self) -> Self::ProvidedIter<'_>;
376            fn providers(&self, key: &libp2p::kad::RecordKey) -> Vec<libp2p::kad::ProviderRecord>;
377            fn records(&self) -> Self::RecordsIter<'_>;
378            fn remove_provider(&mut self, k: &libp2p::kad::RecordKey, p: &libp2p::PeerId);
379        }
380    }
381
382    /// Override the `put` method to potentially sync the DHT to the persistent store
383    fn put(&mut self, record: libp2p::kad::Record) -> Result<()> {
384        // Try to write to the underlying store
385        let result = self.underlying_record_store.put(record);
386
387        // If the record was successfully written,
388        if result.is_ok() {
389            // Update the record delta
390            self.record_delta.fetch_add(1, Ordering::Relaxed);
391
392            // Check if it's above the maximum record delta
393            if self.record_delta.load(Ordering::Relaxed) > self.max_record_delta {
394                // Try to save the DHT to persistent storage
395                self.try_save_to_persistent_storage();
396            }
397        }
398
399        result
400    }
401
402    /// Overwrite the `remove` method to potentially sync the DHT to the persistent store
403    fn remove(&mut self, k: &libp2p::kad::RecordKey) {
404        // Remove the record from the underlying store
405        self.underlying_record_store.remove(k);
406
407        // Update the record delta
408        self.record_delta.fetch_add(1, Ordering::Relaxed);
409
410        // Check if it's above the maximum record delta
411        if self.record_delta.load(Ordering::Relaxed) > self.max_record_delta {
412            // Try to save the DHT to persistent storage
413            self.try_save_to_persistent_storage();
414        }
415    }
416}
417
#[cfg(test)]
mod tests {
    use libp2p::{
        kad::{store::MemoryStore, RecordKey},
        PeerId,
    };
    use tracing_subscriber::EnvFilter;

    use super::*;

    /// The store type exercised by these tests
    type TestStore = PersistentStore<MemoryStore, DhtFilePersistence>;

    /// Try initializing tracing; a failure (e.g. already initialized by another test) is ignored
    fn init_tracing() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(EnvFilter::from_default_env())
            .try_init();
    }

    /// Path of a scratch DHT file inside the platform's temp directory.
    ///
    /// The process id is part of the name so that concurrent or earlier test runs
    /// cannot interfere with each other through a shared file.
    fn scratch_dht_path(name: &str) -> String {
        std::env::temp_dir()
            .join(format!("{name}-{}.dht", std::process::id()))
            .to_string_lossy()
            .into_owned()
    }

    /// Open a store backed by the file at `path`, with a maximum record delta of 10
    async fn open_store(path: &str) -> TestStore {
        PersistentStore::new(
            MemoryStore::new(PeerId::random()),
            DhtFilePersistence::new(path.to_string()),
            10,
        )
        .await
    }

    /// Create a record for `key`/`value` that expires 10 seconds in the future
    fn expiring_record(key: RecordKey, value: &[u8]) -> libp2p::kad::Record {
        let mut record = libp2p::kad::Record::new(key, value.to_vec());
        record.expires = Some(Instant::now() + Duration::from_secs(10));
        record
    }

    /// Wait until the background save task of `store` has released its permit.
    ///
    /// The permit is taken synchronously when a save is started and dropped by the task
    /// once it is done, so polling it is deterministic where a fixed sleep is not.
    async fn wait_for_save(store: &TestStore) {
        timeout(Duration::from_secs(5), async {
            while store.semaphore.available_permits() == 0 {
                tokio::time::sleep(Duration::from_millis(5)).await;
            }
        })
        .await
        .expect("Timed out waiting for the DHT save to complete");
    }

    #[tokio::test]
    async fn test_save_and_restore() {
        init_tracing();

        // Create a test store
        let path = scratch_dht_path("test_save_and_restore");
        let mut store = open_store(&path).await;

        // The key and the value are random 16-byte arrays
        let key = RecordKey::new(&rand::random::<[u8; 16]>().to_vec());
        let random_value = rand::random::<[u8; 16]>();

        // Put a record into the store
        store
            .put(expiring_record(key.clone(), &random_value))
            .expect("Failed to put record into store");

        // Try to save the store to a persistent storage
        assert!(store.try_save_to_persistent_storage());

        // Wait for the save to complete
        wait_for_save(&store).await;

        // Create a new store from the persistent storage
        let new_store = open_store(&path).await;

        // Check that the new store has the record
        let restored_record = new_store
            .get(&key)
            .expect("Failed to get record from store");

        // Check that the restored record has the same value as the original record
        assert_eq!(restored_record.value, random_value.to_vec());

        // Best-effort cleanup of the scratch file
        let _ = std::fs::remove_file(&path);
    }

    #[tokio::test]
    async fn test_record_delta() {
        init_tracing();

        // Create a test store
        let path = scratch_dht_path("test_record_delta");
        let mut store = open_store(&path).await;

        let mut keys = Vec::new();
        let mut values = Vec::new();

        // Put 10 records into the store; this reaches, but does not exceed, the maximum delta
        for _ in 0..10 {
            // Create a random key and value
            let key = RecordKey::new(&rand::random::<[u8; 16]>().to_vec());
            let value = rand::random::<[u8; 16]>();

            keys.push(key.clone());
            values.push(value);

            store
                .put(expiring_record(key, &value))
                .expect("Failed to put record into store");
        }

        // Create a new store from the allegedly unpersisted DHT
        let new_store = open_store(&path).await;

        // Check that the new store has none of the records
        for key in &keys {
            assert!(new_store.get(key).is_none());
        }

        // Store one more record into the original store; this pushes the delta over the
        // maximum and triggers a save
        store
            .put(expiring_record(keys[0].clone(), &values[0]))
            .expect("Failed to put record into store");

        // Wait for the save to complete
        wait_for_save(&store).await;

        // Create a new store from the allegedly saved DHT
        let new_store = open_store(&path).await;

        // Check that the new store has all of the records
        for (key, value) in keys.iter().zip(&values) {
            let restored_record = new_store.get(key).expect("Failed to get record from store");
            assert_eq!(restored_record.value, value.to_vec());
        }

        // Check that the record delta is 0
        assert_eq!(store.record_delta.load(Ordering::Relaxed), 0);

        // Best-effort cleanup of the scratch file
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_approximate_instant() {
        // Create an expiry time in the future
        let expiry_future = Instant::now() + Duration::from_secs(10);

        // Approximate the expiry time
        let approximate_expiry =
            unix_seconds_to_instant(instant_to_unix_seconds(expiry_future).unwrap())
                .unwrap()
                .duration_since(Instant::now());

        // Make sure it's close to 10 seconds in the future
        assert!(approximate_expiry >= Duration::from_secs(9));
        assert!(approximate_expiry <= Duration::from_secs(11));

        // Create an expiry time in the past
        let expiry_past = Instant::now().checked_sub(Duration::from_secs(10)).unwrap();

        // Approximate the expiry time
        let approximate_expiry =
            unix_seconds_to_instant(instant_to_unix_seconds(expiry_past).unwrap()).unwrap();
        let time_difference = approximate_expiry.elapsed();

        // Make sure it's close to 10 seconds in the past
        assert!(time_difference >= Duration::from_secs(9));
        assert!(time_difference <= Duration::from_secs(11));
    }
}