How to Build a Database From Scratch: Understanding LSM Trees and Storage Engines (Part 1)

DATE POSTED: October 28, 2024

Learn core database concepts by implementing a Python key-value store with crash recovery and efficient writes.

In this tutorial, we’ll build a simple but functional database from scratch with Python. Through this hands-on project, we’ll explore core database concepts like Write-Ahead Logging (WAL), Sorted String Tables (SSTables), Log-Structured Merge (LSM) trees, and other optimization techniques. By the end, you’ll have a deeper understanding of how real databases work under the hood.

Why Build a Database?

Before diving into implementation, let’s understand what a database is and why we’re building one. Building a database from scratch is not just an educational exercise — it helps us understand the tradeoffs and decisions that go into database design, making us better at choosing and using databases in our own applications. Whether you’re using MongoDB, PostgreSQL, or any other database, the concepts we’ll explore form the foundation of their implementations.

A database is like a digital filing cabinet that helps us:

  • Store data persistently (meaning it survives even when your computer restarts)
  • Find data quickly when we need it
  • Keep data organized and consistent
  • Handle multiple users accessing data at the same time
  • Recover from crashes without losing information

In this tutorial, we’ll build a simple database that stores key-value pairs. Think of it like a Python dictionary that saves to disk.

For example:

```python
# Regular Python dictionary (loses data when program ends)
my_dict = {
    "user:1": {"name": "Alice", "age": 30},
    "user:2": {"name": "Bob", "age": 25}
}

# Our database (keeps data even after program ends)
our_db.set("user:1", {"name": "Alice", "age": 30})
our_db.get("user:1")  # Returns: {"name": "Alice", "age": 30}
```

You might wonder: “Why not just save a Python dictionary to a file?” Let’s start there to understand the problems we’ll need to solve.

```python
import os
import pickle
from typing import Any, Dict, Optional

class SimpleStore:
    def __init__(self, filename: str):
        self.filename = filename
        self.data: Dict[str, Any] = {}
        self._load()

    def _load(self):
        """Load data from disk if it exists"""
        if os.path.exists(self.filename):
            with open(self.filename, 'rb') as f:
                self.data = pickle.load(f)

    def _save(self):
        """Save data to disk"""
        with open(self.filename, 'wb') as f:
            pickle.dump(self.data, f)

    def set(self, key: str, value: Any):
        self.data[key] = value
        self._save()  # Write to disk immediately

    def get(self, key: str) -> Optional[Any]:
        return self.data.get(key)
```

Let’s break down what this code does:

  • __init__(self, filename): Creates a new database using a file to store data. filename is where we'll save our data; self.data is our in-memory dictionary.

  • _load(self): Reads saved data from disk when we start. Uses Python’s pickle module to convert saved bytes back into Python objects. The underscore (_) means it’s an internal method not meant to be called directly.

  • _save(self): Writes all data to disk. Uses pickle to convert Python objects into bytes. Called every time we change data.

  • set(self, key, value) and get(self, key): Work like a dictionary’s [] operator. set saves to disk immediately; get returns None if the key doesn't exist.

This simple implementation has several problems:

  • Crash risk:

```python
# If the program crashes here, we lose data:
db.data["key"] = "value"  # Changed in memory
# ... crash before _save() ...
```

  • Performance issue:

```python
# This will be very slow because it writes the entire database for each set
for i in range(1000):
    db.set(f"key:{i}", f"value:{i}")  # Writes ALL data 1000 times!
```

  • Memory limitation:

```python
# All data must fit in RAM:
huge_data = {"key": "x" * 1000000000}  # 1GB of data
db.set("huge", huge_data)  # Might run out of memory!
```

  • Concurrency issue:

```python
# If two programs do this at once:
db.set("counter", db.get("counter") + 1)
# They might both read "5" and both write "6"
# Instead of one writing "6" and one writing "7"
```

Write-Ahead Logging (WAL) — Making It Durable

First, to ensure data persistence and recovery capabilities, we will use Write-Ahead Logging. Write-ahead logging is like keeping a diary of everything you’re going to do before you do it. If something goes wrong halfway through, you can look at your diary and finish the job. WAL is a reliability mechanism that records all changes before they are applied to the database.

Example:

```
# WAL entries are like diary entries:
Entry 1: "At 2024-03-20 14:30:15, set user:1 to {"name": "Alice"}"
Entry 2: "At 2024-03-20 14:30:16, set user:2 to {"name": "Bob"}"
Entry 3: "At 2024-03-20 14:30:17, delete user:1"
```


```python
import json
from datetime import datetime
from typing import Any

class WALEntry:
    def __init__(self, operation: str, key: str, value: Any):
        self.timestamp = datetime.utcnow().isoformat()
        self.operation = operation  # 'set' or 'delete'
        self.key = key
        self.value = value

    def serialize(self) -> str:
        """Convert the entry to a string for storage"""
        return json.dumps({
            'timestamp': self.timestamp,
            'operation': self.operation,
            'key': self.key,
            'value': self.value
        })
```

Let’s understand what each part means:

  • timestamp: When the operation happened

```
# Example timestamp
"2024-03-20T14:30:15.123456"  # ISO format: readable and sortable
```

  • operation: What we're doing

```
# Example operations
{"operation": "set", "key": "user:1", "value": {"name": "Alice"}}
{"operation": "delete", "key": "user:1", "value": null}
```

  • serialize: Converts to string for storage

```
# Instead of binary pickle format, we use JSON for readability:
{
  "timestamp": "2024-03-20T14:30:15.123456",
  "operation": "set",
  "key": "user:1",
  "value": {"name": "Alice"}
}
```


The WAL Store Implementation

```python
import os
import json
import pickle
import shutil
from typing import Any, Dict

class DatabaseError(Exception):
    """Base class for database exceptions"""
    pass

class WALStore:
    def __init__(self, data_file: str, wal_file: str):
        self.data_file = data_file
        self.wal_file = wal_file
        self.data: Dict[str, Any] = {}
        self._recover()

    def _append_wal(self, entry: WALEntry):
        """Write an entry to the log file"""
        try:
            with open(self.wal_file, "a") as f:
                f.write(entry.serialize() + "\n")
                f.flush()  # Ensure it's written to disk
                os.fsync(f.fileno())  # Force disk write
        except IOError as e:
            raise DatabaseError(f"Failed to write to WAL: {e}")

    def _recover(self):
        """Rebuild database state from WAL if needed"""
        try:
            # First load the last checkpoint
            if os.path.exists(self.data_file):
                with open(self.data_file, "rb") as f:
                    self.data = pickle.load(f)

            # Then replay any additional changes from WAL
            if os.path.exists(self.wal_file):
                with open(self.wal_file, "r") as f:
                    for line in f:
                        if line.strip():  # Skip empty lines
                            entry = json.loads(line)
                            if entry["operation"] == "set":
                                self.data[entry["key"]] = entry["value"]
                            elif entry["operation"] == "delete":
                                self.data.pop(entry["key"], None)
        except (IOError, json.JSONDecodeError, pickle.PickleError) as e:
            raise DatabaseError(f"Recovery failed: {e}")

    def set(self, key: str, value: Any):
        """Set a key-value pair with WAL"""
        entry = WALEntry("set", key, value)
        self._append_wal(entry)
        self.data[key] = value

    def delete(self, key: str):
        """Delete a key with WAL"""
        entry = WALEntry("delete", key, None)
        self._append_wal(entry)
        self.data.pop(key, None)

    def checkpoint(self):
        """Create a checkpoint of current state"""
        temp_file = f"{self.data_file}.tmp"
        try:
            # Write to temporary file first
            with open(temp_file, "wb") as f:
                pickle.dump(self.data, f)
                f.flush()
                os.fsync(f.fileno())

            # Atomically replace old file
            shutil.move(temp_file, self.data_file)

            # Clear WAL - just truncate instead of opening in 'w' mode
            if os.path.exists(self.wal_file):
                with open(self.wal_file, "r+") as f:
                    f.truncate(0)
                    f.flush()
                    os.fsync(f.fileno())
        except IOError as e:
            if os.path.exists(temp_file):
                os.remove(temp_file)
            raise DatabaseError(f"Checkpoint failed: {e}")
```

In this implementation, we keep track of 2 files: the data_file and wal_file. The data_file serves as our permanent storage where we save all data periodically (like a database backup), while the wal_file acts as our transaction log where we record every operation before executing it (like a diary of changes).

The recovery process works like this:

  • First loads the last saved state from data_file
  • Then replays all operations from the WAL file
  • Handles both ‘set’ and ‘delete’ operations

Example of recovery sequence:

```
# data_file contains:
{"user:1": {"name": "Alice"}}

# wal_file contains:
{"operation": "set", "key": "user:2", "value": {"name": "Bob"}}
{"operation": "delete", "key": "user:1"}

# After recovery, self.data contains:
{"user:2": {"name": "Bob"}}
```
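The replay half of recovery can be sketched in a few lines. This `replay` helper is a hypothetical standalone version of what `_recover` does after loading the checkpoint, assuming the same set/delete entry format shown above:

```python
import json

def replay(snapshot, wal_lines):
    """Apply JSON-encoded WAL entries on top of a checkpoint snapshot."""
    data = dict(snapshot)
    for line in wal_lines:
        entry = json.loads(line)
        if entry["operation"] == "set":
            data[entry["key"]] = entry["value"]
        elif entry["operation"] == "delete":
            data.pop(entry["key"], None)
    return data

snapshot = {"user:1": {"name": "Alice"}}
wal = [
    '{"operation": "set", "key": "user:2", "value": {"name": "Bob"}}',
    '{"operation": "delete", "key": "user:1"}',
]
print(replay(snapshot, wal))  # {'user:2': {'name': 'Bob'}}
```

Because the WAL is replayed in order, the final state is the same whether the crash happened before or after the in-memory dictionary was updated.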

Another important method of this implementation is checkpoint. It creates a permanent snapshot of the current state. Here is an example of a checkpointing process:
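The crash safety of checkpoint comes from writing to a temporary file first and then renaming it over the old one. Here is a standalone sketch of that pattern, using a hypothetical `atomic_write` helper; `os.replace` performs an atomic rename on the same filesystem, much like the `shutil.move` call in `checkpoint`:

```python
import os
import tempfile

def atomic_write(path: str, payload: bytes):
    """Write payload to path so readers never see a half-written file."""
    # Write to a temporary file in the same directory first
    fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(path) or ".")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(payload)
            f.flush()
            os.fsync(f.fileno())  # Force bytes to disk before the swap
        os.replace(tmp_path, path)  # Atomic swap: readers see old or new, never partial
    except OSError:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

atomic_write("checkpoint.db", b"snapshot-bytes")
```

If the process crashes before the rename, the old checkpoint is untouched; if it crashes after, the new one is complete. There is no in-between state.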

```
1. Current state:
   data_file: {"user:1": {"name": "Alice"}}
   wal_file:  [set user:2 {"name": "Bob"}]

2. During checkpoint:
   data_file:     {"user:1": {"name": "Alice"}}
   data_file.tmp: {"user:1": {"name": "Alice"}, "user:2": {"name": "Bob"}}
   wal_file:      [set user:2 {"name": "Bob"}]

3. After checkpoint:
   data_file: {"user:1": {"name": "Alice"}, "user:2": {"name": "Bob"}}
   wal_file:  []  # Empty
```

Memory Tables — Making It Fast

MemTables serve as our database’s fast write path, providing quick access to recently written data. They maintain sorted order in memory, enabling efficient reads and range queries while preparing data for eventual persistent storage. You can think of a MemTable as a sorting tray on your desk:


  • New items go here first (fast!).
  • Since all items are kept sorted, lookups and range queries can use binary search, which halves the search space at each step, giving O(log n) performance instead of checking every item.
  • When it gets full, you file them away properly (in SSTables).
```python
import bisect
from typing import Any, Iterator, List, Optional, Tuple

class MemTable:
    def __init__(self, max_size: int = 1000):
        self.entries: List[Tuple[str, Any]] = []
        self.max_size = max_size

    def add(self, key: str, value: Any):
        """Add or update a key-value pair"""
        idx = bisect.bisect_left([k for k, _ in self.entries], key)
        if idx < len(self.entries) and self.entries[idx][0] == key:
            self.entries[idx] = (key, value)
        else:
            self.entries.insert(idx, (key, value))

    def get(self, key: str) -> Optional[Any]:
        """Get value for key"""
        idx = bisect.bisect_left([k for k, _ in self.entries], key)
        if idx < len(self.entries) and self.entries[idx][0] == key:
            return self.entries[idx][1]
        return None

    def is_full(self) -> bool:
        """Check if memtable has reached max size"""
        return len(self.entries) >= self.max_size

    def range_scan(self, start_key: str, end_key: str) -> Iterator[Tuple[str, Any]]:
        """Scan entries within key range"""
        start_idx = bisect.bisect_left([k for k, _ in self.entries], start_key)
        end_idx = bisect.bisect_right([k for k, _ in self.entries], end_key)
        return iter(self.entries[start_idx:end_idx])
```

Let’s see how the memtable stays sorted:

```python
# Starting state:
entries = [
    ("apple", 1),
    ("cherry", 2),
    ("zebra", 3)
]

# Adding "banana" = 4:
# 1. Find insertion point (between "apple" and "cherry")
# 2. Insert new entry
# 3. Result:
entries = [
    ("apple", 1),
    ("banana", 4),
    ("cherry", 2),
    ("zebra", 3)
]
```
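The insertion-point search that keeps entries sorted is Python's built-in `bisect` module. A quick standalone look at the same logic `MemTable.add` and `MemTable.get` use:

```python
import bisect

keys = ["apple", "cherry", "zebra"]

# Where would "banana" go to keep the list sorted?
idx = bisect.bisect_left(keys, "banana")
print(idx)  # 1, i.e. between "apple" and "cherry"

keys.insert(idx, "banana")
print(keys)  # ['apple', 'banana', 'cherry', 'zebra']

# Exact-match lookup works the same way: probe the insertion point
i = bisect.bisect_left(keys, "cherry")
found = i < len(keys) and keys[i] == "cherry"
print(found)  # True
```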

While our implementation uses a simple sorted list with binary search, production databases like LevelDB and RocksDB typically use more sophisticated data structures like Red-Black trees or Skip Lists. We used a simpler approach here to focus on the core concepts, but keep in mind that real databases need these optimizations for better performance.

Sorted String Tables (SSTables) — Making It Scale

While MemTables provide fast writes, they can’t grow indefinitely. We need a way to persist them to disk efficiently. Since our memory is limited, when the MemTable gets full, we need to save it to disk. We use SSTables for this. SSTables (Sorted String Tables) are like sorted, immutable folders of data:

```python
import os
import pickle
import shutil
from typing import Any, Dict, Iterator, Optional, Tuple

class SSTable:
    def __init__(self, filename: str):
        self.filename = filename
        self.index: Dict[str, int] = {}
        if os.path.exists(filename):
            self._load_index()

    def _load_index(self):
        """Load index from existing SSTable file"""
        try:
            with open(self.filename, "rb") as f:
                # Read index position from start of file
                f.seek(0)
                index_pos = int.from_bytes(f.read(8), "big")
                # Read index from end of file
                f.seek(index_pos)
                self.index = pickle.load(f)
        except (IOError, pickle.PickleError) as e:
            raise DatabaseError(f"Failed to load SSTable index: {e}")

    def write_memtable(self, memtable: MemTable):
        """Save memtable to disk as SSTable"""
        temp_file = f"{self.filename}.tmp"
        try:
            with open(temp_file, "wb") as f:
                # Write index size for recovery
                index_pos = f.tell()
                f.write(b"\0" * 8)  # Placeholder for index position

                # Write data
                for key, value in memtable.entries:
                    offset = f.tell()
                    self.index[key] = offset
                    entry = pickle.dumps((key, value))
                    f.write(len(entry).to_bytes(4, "big"))
                    f.write(entry)

                # Write index at end
                index_offset = f.tell()
                pickle.dump(self.index, f)

                # Update index position at start of file
                f.seek(index_pos)
                f.write(index_offset.to_bytes(8, "big"))
                f.flush()
                os.fsync(f.fileno())

            # Atomically rename temp file
            shutil.move(temp_file, self.filename)
        except IOError as e:
            if os.path.exists(temp_file):
                os.remove(temp_file)
            raise DatabaseError(f"Failed to write SSTable: {e}")

    def get(self, key: str) -> Optional[Any]:
        """Get value for key from SSTable"""
        if key not in self.index:
            return None
        try:
            with open(self.filename, "rb") as f:
                f.seek(self.index[key])
                size = int.from_bytes(f.read(4), "big")
                entry = pickle.loads(f.read(size))
                return entry[1]
        except (IOError, pickle.PickleError) as e:
            raise DatabaseError(f"Failed to read from SSTable: {e}")

    def range_scan(self, start_key: str, end_key: str) -> Iterator[Tuple[str, Any]]:
        """Scan entries within key range"""
        keys = sorted(k for k in self.index.keys() if start_key <= k <= end_key)
        for key in keys:
            value = self.get(key)
            if value is not None:
                yield (key, value)
```

File format explanation:

```
File layout:
[index position: 8 bytes][size1][entry1][size2][entry2]...[pickled index]

Example:
[0x00000020][{"key": "apple", "value": 1}][0x00000024][{"key": "banana", "value": 4}]...
```
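The size-prefixed framing can be tried on its own. The helpers below are hypothetical (not part of the SSTable class), but they encode and decode one record the same way `write_memtable` and `get` do:

```python
import pickle

def encode_record(key, value) -> bytes:
    """Length-prefixed record: 4-byte big-endian size, then pickled payload."""
    payload = pickle.dumps((key, value))
    return len(payload).to_bytes(4, "big") + payload

def decode_record(buf: bytes, offset: int):
    """Read one record starting at offset; return the (key, value) tuple."""
    size = int.from_bytes(buf[offset:offset + 4], "big")
    return pickle.loads(buf[offset + 4:offset + 4 + size])

blob = encode_record("apple", 1) + encode_record("banana", 4)
second_offset = len(encode_record("apple", 1))
print(decode_record(blob, 0))              # ('apple', 1)
print(decode_record(blob, second_offset))  # ('banana', 4)
```

The size prefix is what lets the index jump straight to any record: given an offset, we read 4 bytes to learn the record length, then read exactly that many bytes.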


Putting It All Together — The LSM Tree

Now, we’ll combine everything we’ve learned (WAL, MemTable, and SSTables) into a simplified version of a Log-Structured Merge Tree (LSM Tree). While a true LSM tree uses multiple levels, we’ll start with a basic flat structure in Part 1 and upgrade to a proper leveled implementation in Part 2 of this series.

Imagine you’re organizing papers in an office:

  • New papers go into your inbox (MemTable).
    • This is like the fast “write path” of the database.
    • Papers can be added quickly without worrying about organization yet.
    • The inbox is kept in memory for quick access.

  • When the inbox is full, you sort them and put them in a folder (SSTable).
    • The papers are now sorted and stored efficiently.
    • Once in a folder, these papers never change (immutable).
    • Each folder has its own index for quick lookups.

  • When you have too many folders, you merge them (Compaction).
    • In our Part 1 implementation, we’ll simply merge all folders into one.
    • This is a simplified approach compared to real databases. While functional, it’s not as efficient as proper leveled merging.

Here’s how we implement our simplified version:

```python
from pathlib import Path
from threading import RLock
from typing import List

class LSMTree:
    def __init__(self, base_path: str):
        self.base_path = Path(base_path)
        try:
            # Check if path exists and is a file
            if self.base_path.exists() and self.base_path.is_file():
                raise DatabaseError(f"Cannot create database: '{base_path}' is a file")
            self.base_path.mkdir(parents=True, exist_ok=True)
        except (OSError, FileExistsError) as e:
            raise DatabaseError(
                f"Failed to initialize database at '{base_path}': {str(e)}"
            )

        # Our "Inbox" for new data
        self.memtable = MemTable(max_size=1000)

        # Our "Folders" of sorted data
        self.sstables: List[SSTable] = []
        self.max_sstables = 5  # Limit on number of SSTables

        self.lock = RLock()
        self.wal = WALStore(
            str(self.base_path / "data.db"),
            str(self.base_path / "wal.log")
        )
        self._load_sstables()
        if len(self.sstables) > self.max_sstables:
            self._compact()

    def _load_sstables(self):
        """Load existing SSTables from disk"""
        self.sstables.clear()
        for file in sorted(self.base_path.glob("sstable_*.db")):
            self.sstables.append(SSTable(str(file)))
```

To ensure thread safety in our database, we use locks to prevent multiple threads from causing problems:
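As a self-contained demonstration (independent of the database code), here is the lost-update hazard avoided with an RLock: five threads each increment a shared counter 1000 times, and the lock makes every read-modify-write atomic.

```python
from threading import Thread, RLock

counter = {"x": 0}
lock = RLock()

def increment_safely(n: int):
    for _ in range(n):
        with lock:  # the read-modify-write below happens atomically
            counter["x"] = counter["x"] + 1

threads = [Thread(target=increment_safely, args=(1000,)) for _ in range(5)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(counter["x"])  # 5000 -- without the lock, some updates could be lost
```

Remove the `with lock:` line and some increments may silently vanish, because two threads can read the same old value before either writes back.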

```
# Example of why we need locks:

# Without locks:
Thread 1: reads data["x"] = 5
Thread 2: reads data["x"] = 5
Thread 1: writes data["x"] = 6
Thread 2: writes data["x"] = 7
# Last write wins, first update lost!

# With locks:
Thread 1: acquires lock
Thread 1: reads data["x"] = 5
Thread 1: writes data["x"] = 6
Thread 1: releases lock
Thread 2: acquires lock
Thread 2: reads data["x"] = 6
Thread 2: writes data["x"] = 7
Thread 2: releases lock
```

Writing Data

Let’s see how writing data works step by step:

```python
def set(self, key: str, value: Any):
    """Set a key-value pair"""
    with self.lock:
        if not isinstance(key, str):
            raise ValueError("Key must be a string")

        # 1. Safety first: Write to WAL
        self.wal.set(key, value)

        # 2. Write to memory table (fast!)
        self.memtable.add(key, value)

        # 3. If memory table is full, save to disk
        if self.memtable.is_full():
            self._flush_memtable()

def _flush_memtable(self):
    """Flush memtable to disk as new SSTable"""
    if not self.memtable.entries:
        return  # Skip if empty

    # Create new SSTable with a unique name
    sstable = SSTable(str(self.base_path / f"sstable_{len(self.sstables)}.db"))
    sstable.write_memtable(self.memtable)

    # Add to our list of SSTables
    self.sstables.append(sstable)

    # Create fresh memory table
    self.memtable = MemTable()

    # Create a checkpoint in WAL
    self.wal.checkpoint()

    # Compact if we have too many SSTables
    if len(self.sstables) > self.max_sstables:
        self._compact()
```

Example of how data flows:

```
# Starting state:
memtable: empty
sstables: []

# After db.set("user:1", {"name": "Alice"})
memtable: [("user:1", {"name": "Alice"})]
sstables: []

# After 1000 more sets (memtable full)...
memtable: empty
sstables: [sstable_0.db]  # Contains sorted data

# After 1000 more sets...
memtable: empty
sstables: [sstable_0.db, sstable_1.db]
```

Reading Data

Reading needs to check multiple places, newest to oldest:

```python
def get(self, key: str) -> Optional[Any]:
    """Get value for key"""
    with self.lock:
        if not isinstance(key, str):
            raise ValueError("Key must be a string")

        # 1. Check memtable first (newest data)
        value = self.memtable.get(key)
        if value is not None:
            return value

        # 2. Check each SSTable, newest to oldest
        for sstable in reversed(self.sstables):
            value = sstable.get(key)
            if value is not None:
                return value

        # 3. Not found anywhere
        return None

def range_query(self, start_key: str, end_key: str) -> Iterator[Tuple[str, Any]]:
    """Perform a range query"""
    with self.lock:
        seen_keys = set()

        # Get from memtable first (newest data)
        for key, value in self.memtable.range_scan(start_key, end_key):
            seen_keys.add(key)  # Remember the key so older copies are skipped
            if value is not None:  # Skip tombstones
                yield (key, value)

        # Get from each SSTable, newest to oldest
        for sstable in reversed(self.sstables):
            for key, value in sstable.range_scan(start_key, end_key):
                if key not in seen_keys:
                    seen_keys.add(key)
                    if value is not None:  # Skip tombstones
                        yield (key, value)
```

Example of reading:

```
# Database state:
memtable: [("user:3", {"name": "Charlie"})]
sstables: [
    sstable_0.db: [("user:1", {"name": "Alice"})],
    sstable_1.db: [("user:2", {"name": "Bob"})]
]

# Reading "user:3" -> Finds it in memtable
# Reading "user:1" -> Checks memtable, then finds in sstable_0.db
# Reading "user:4" -> Checks everywhere, returns None
```

Compaction: Keeping Things Tidy

```python
def _compact(self):
    """Merge multiple SSTables into one"""
    try:
        # Create merged memtable
        merged = MemTable(max_size=float("inf"))

        # Merge all SSTables, oldest to newest
        for sstable in self.sstables:
            for key, value in sstable.range_scan("", "~"):  # Full range
                merged.add(key, value)

        # Write merged data to new SSTable
        new_sstable = SSTable(str(self.base_path / "sstable_compacted.db"))
        new_sstable.write_memtable(merged)

        # Remove old SSTables
        old_files = [sst.filename for sst in self.sstables]
        self.sstables = [new_sstable]

        # Delete old files
        for file in old_files:
            try:
                os.remove(file)
            except OSError:
                pass  # Ignore deletion errors
    except Exception as e:
        raise DatabaseError(f"Compaction failed: {e}")
```


```
# Before compaction:
sstables: [
    sstable_0.db: [("apple", 1), ("banana", 2)],
    sstable_1.db: [("banana", 3), ("cherry", 4)],
    sstable_2.db: [("apple", 5), ("date", 6)]
]

# After compaction:
sstables: [
    sstable_compacted.db: [
        ("apple", 5),   # Latest value wins
        ("banana", 3),  # Latest value wins
        ("cherry", 4),
        ("date", 6)
    ]
]
```
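The “latest value wins” outcome falls out of merge order: because `_compact` walks SSTables from oldest to newest, later writes simply overwrite earlier ones. A minimal sketch of that merge rule using a plain dict:

```python
# SSTables ordered oldest to newest, mirroring self.sstables
sstables = [
    [("apple", 1), ("banana", 2)],   # sstable_0
    [("banana", 3), ("cherry", 4)],  # sstable_1
    [("apple", 5), ("date", 6)],     # sstable_2
]

merged = {}
for table in sstables:        # oldest first...
    for key, value in table:
        merged[key] = value   # ...so later tables overwrite earlier ones

print(sorted(merged.items()))
# [('apple', 5), ('banana', 3), ('cherry', 4), ('date', 6)]
```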

We also need other methods, such as deleting and closing the database instance:

```python
def delete(self, key: str):
    """Delete a key"""
    with self.lock:
        self.wal.delete(key)
        self.set(key, None)  # Use None as tombstone

def close(self):
    """Ensure all data is persisted to disk"""
    with self.lock:
        if self.memtable.entries:  # If there's data in memtable
            self._flush_memtable()
        self.wal.checkpoint()  # Ensure WAL is up-to-date
```

Here’s how to use what we’ve built:

```python
import random

# Create database in the 'mydb' directory
db = LSMTree("./mydb")

# Store some user data
db.set("user:1", {
    "name": "Alice",
    "email": "[email protected]",
    "age": 30
})

# Read it back
user = db.get("user:1")
print(user['name'])  # Prints: Alice

# Store many items
for i in range(1000):
    db.set(f"item:{i}", {
        "name": f"Item {i}",
        "price": random.randint(1, 100)
    })

# Range query example
print("\nItems 10-15:")
for key, value in db.range_query("item:10", "item:15"):
    print(f"{key}: {value}")
```

Testing Our Database

Let’s put our database through some basic tests to understand its behavior:

```python
from threading import Thread
from time import sleep

def test_basic_operations(db_path):
    db = LSMTree(db_path)

    # Test single key-value
    db.set("test_key", "test_value")
    assert db.get("test_key") == "test_value"

    # Test overwrite
    db.set("test_key", "new_value")
    assert db.get("test_key") == "new_value"

    # Test non-existent key
    assert db.get("missing_key") is None

def test_delete_operations(db_path):
    db = LSMTree(db_path)

    # Test delete existing key
    db.set("key1", "value1")
    assert db.get("key1") == "value1"
    db.delete("key1")
    assert db.get("key1") is None

    # Test delete non-existent key
    db.delete("nonexistent_key")  # Should not raise error

    # Test set after delete
    db.set("key1", "new_value")
    assert db.get("key1") == "new_value"

def test_range_query(db_path):
    db = LSMTree(db_path)

    # Insert test data
    test_data = {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
    for k, v in test_data.items():
        db.set(k, v)

    # Test range query
    results = list(db.range_query("b", "d"))
    assert len(results) == 3
    assert results == [("b", 2), ("c", 3), ("d", 4)]

    # Test empty range
    results = list(db.range_query("x", "z"))
    assert len(results) == 0

def test_concurrent_operations(db_path):
    db = LSMTree(db_path)

    def writer_thread():
        for i in range(100):
            db.set(f"thread_key_{i}", f"thread_value_{i}")
            sleep(0.001)  # Small delay to increase chance of concurrency issues

    # Create multiple writer threads
    threads = [Thread(target=writer_thread) for _ in range(5)]

    # Start all threads
    for thread in threads:
        thread.start()

    # Wait for all threads to complete
    for thread in threads:
        thread.join()

    # Verify all data was written correctly
    for i in range(100):
        assert db.get(f"thread_key_{i}") == f"thread_value_{i}"

def test_recovery(db_path):
    """Test database recovery after a close and reopen"""
    # Create a database instance
    db1 = LSMTree(db_path)

    # Add some data to the database
    db1.set("key1", "value1")
    db1.set("key2", "value2")
    assert db1.get("key1") == "value1"
    assert db1.get("key2") == "value2"

    # Close the database to force a flush
    db1.close()

    # Create a new instance of the database
    db2 = LSMTree(db_path)

    # Verify data is still present after recovery
    assert db2.get("key1") == "value1"
    assert db2.get("key2") == "value2"
    db2.close()
```

Running these tests helps ensure our database is working as expected!

Conclusion

In this first part of our journey to build a database from scratch, we’ve implemented a basic but functional key-value store with several important database concepts:


  • Write-Ahead Logging (WAL) for durability and crash recovery
  • MemTable for fast in-memory operations with sorted data
  • Sorted String Tables (SSTables) for efficient disk storage
  • Log-Structured Merge (LSM) Tree to tie everything together

Our implementation can handle basic operations like setting values, retrieving them, and performing range queries while maintaining data consistency and surviving crashes. By converting random writes into sequential ones, LSM Trees excel at quickly ingesting large amounts of data. This is why databases like RocksDB (Facebook), Apache Cassandra (Netflix, Instagram), and LevelDB (Google) use LSM Trees for their storage engines.

In contrast, traditional B-Tree structures (used by PostgreSQL and MySQL) offer better read performance but may struggle with heavy write loads. We’ll explore B-Tree implementations in future posts to better understand these trade-offs.

Current Limitations

While our database works, it has several limitations that we’ll address in the coming parts:

Storage and Performance:

  • Simple list of SSTables instead of a leveled structure.
  • Inefficient compaction that merges all tables at once.
  • No way to skip unnecessary SSTable reads; scanning all tables for a query is inefficient.

Concurrency:

  • Basic locking that locks the entire operation
  • No support for transactions across multiple operations
  • The compaction process blocks all other operations

In Part 2, we’ll tackle some of these limitations by implementing level-based compaction, bloom filters for faster lookups, and basic transaction support. Stay tuned.

The complete source code for this tutorial is available on GitHub. I encourage you to experiment with the code, try different optimizations, and share your findings in the comments below. You can also subscribe to my personal blog for more frequent updates.

Further Reading

If you want to dive deeper into these concepts before Part 2, here are some resources:
