前言

为了深入学习Rust编程的最佳实践,并熟悉成熟且主流的tokio异步运行时,我决定通过学习mini-redis项目来提升自己的Rust高性能编程能力。本文记录了我对该项目的学习过程和理解。

项目地址:https://github.com/tokio-rs/mini-redis

项目架构概览

graph TD
    %% Main components with subgraph structure
    subgraph 核心服务器 ["核心服务器组件"]
        server["server
Redis服务器实现"] shutdown["shutdown
优雅关闭机制"] db["db
键值存储和发布/订阅"] end subgraph 通信层 ["通信层组件"] connection["connection
TCP连接管理"] parse["parse
TCP字节流解析"] frame["frame
Redis协议帧表示"] end subgraph 客户端交互 ["客户端交互组件"] clients["clients
异步和阻塞客户端"] cmd["cmd
Redis命令实现"] end %% Relationship arrows with labels server -->|"初始化和管理"| connection server -->|"创建和维护"| db server -->|"触发"| shutdown connection -->|"发送/接收"| frame parse -->|"生成"| frame clients -->|"建立"| connection clients -->|"发出"| cmd cmd -->|"编码为"| frame db -->|"响应转换为"| frame %% Add clear data flow frame -->|"执行并更新"| db connection -->|"传递到"| parse %% Visual styling classDef core fill:#f9d5e5,stroke:#333,stroke-width:1px classDef comm fill:#eeeeee,stroke:#333,stroke-width:1px classDef client fill:#d0f0c0,stroke:#333,stroke-width:1px class server,shutdown,db core class connection,parse,frame comm class clients,cmd client
  1. 核心服务器组件

    • 🔹 server:Redis服务器核心,启动服务进程,管理连接和请求处理流程
    • 🔹 db:实现键值存储引擎,管理数据结构和发布/订阅功能
    • 🔹 shutdown:处理服务器正常关闭流程,确保数据完整性和连接优雅终止
  2. 通信层组件

    • 🔹 connection:管理TCP连接生命周期,处理网络I/O和事件循环
    • 🔹 parse:将TCP字节流解析为协议格式,处理分包和粘包问题
    • 🔹 frame:Redis协议帧的编码解码器,转换命令与二进制表示
  3. 客户端交互组件

    • 🔹 clients:提供异步和阻塞式客户端API,处理连接池和请求队列
    • 🔹 cmd:实现Redis命令集,处理命令验证、执行和响应生成

服务器启动时初始化各组件,建立连接监听和处理管道。客户端连接请求经由connection组件处理,建立会话。客户端命令经过协议编码,通过连接发送到服务器。服务器解析命令后在db组件中执行,并将结果返回。所有组件共同协作,确保数据流转高效和错误处理完善。

先看一下 mini-redis 的基本功能,具体在 README 里,然后再逐步实现。最简单的读、写、Ping、以及订阅更新的功能。

实现

db 模块

数据结构和后台任务

classDiagram
    class Db {
        +new() Db
        +get(key: String) Option~Bytes~
        +set(key: String, value: Bytes, expiration: Option~Duration~) void
        +subscribe(channel: String) Receiver~Bytes~
        +publish(channel: String, message: Bytes) usize
    }

    class Shared {
        -state: Mutex~State~
        -background_task: Notify
    }

    class State {
        -entries: HashMap~String, Entry~
        -pub_sub: HashMap~String, Sender~Bytes~~
        -expirations: BTreeSet~(Instant, String)~
        -shutdown: bool
    }

    class Entry {
        -data: Bytes
        -expires_at: Option~Instant~
    }

    class BackgroundTask {
        -run() async
        -expire_keys() usize
        -sleep_until_next_expiration() Future
    }

    Db *-- Shared : contains
    Shared *-- State : protects
    State *-- Entry : stores
    State o-- BackgroundTask : triggers

    note for BackgroundTask "后台任务负责清理过期键值对"
    note for Shared "使用Arc包装,允许多线程共享访问"
    note for Entry "存储值和过期时间"

Redis 是一个基于键值对的数据结构服务器,它支持多种类型的值,而且我们可以为每个键设置过期时间,到了这个时间点,如果键还没有被更新,它会被自动从数据库中删除。我们的值全部当作是 Bytes 类型,过期时间设置在 value 里。注意有的值是永不过期的,所以 expires_at 是 Option 类型。

1
2
3
4
5
6
7
8
9
10
/// Entry in the key-value store
#[derive(Debug)]
struct Entry {
/// Stored data
data: Bytes,

/// Instant at which the entry expires and should be removed from the
/// database.
expires_at: Option<Instant>,
}

然后需要维护整个 db 状态的变量,它需要不断的扫描过期时间,删除过期的键值对。所以需要使用有序的 时刻->键 的映射,这里使用 BTreeSet 来实现。

Redis 的 pub/sub 可以设置不同的频道,然后相同的频道可以有多个订阅者,就是一个广播机制。所以就有 频道 -> 广播 的映射。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#[derive(Debug)]
struct State {
/// 键值数据。使用标准的HashMap即可满足需求
entries: HashMap<String, Entry>,

/// 发布/订阅键空间。Redis为键值和发布/订阅使用单独的键空间
pub_sub: HashMap<String, broadcast::Sender<Bytes>>,

/// 跟踪键的TTL
/// 使用BTreeSet按过期时间排序
/// 这允许后台任务遍历此映射以找到下一个过期的值
expirations: BTreeSet<(Instant, String)>,

/// 当Db实例关闭时为true
/// 设置为true时通知后台任务退出
shutdown: bool,
}

触发过期的检查是在后台线程中进行的,一般是2个触发场景,第一个是后台定时触发,第二个是有新的键值对插入,如果时间比之前的都早,这时候需要更新过期时间。所以需要一个 Notify 来通知后台线程。定时器触发一般注意2点:

  1. 没有任务时,线程应该休眠,不要空转。
  2. 如果有任务,等待的时间应该是下一个的过期时间,而不是固定的时间间隔。

读和写

读和写都要加锁,因为读写都会修改 entries,所以需要 Mutex 来保护。读的时候数据要 clone 一份,因为哈希表里还是存储着的。

写的时候,如果指定了过期时间,那么需要更新 expirations,如果之前有过期时间,那么需要删除之前的,然后更新。

Pub/Sub

订阅很简单,创建一个 broadcast::channel,返回给客户端一个 Receiver,然后把 Sender 存储到 pub_sub 里。发布的时候,使用相同的频道,发送消息即可。

Redis协议帧

Redis 协议是一个简单的文本协议,它是基于 TCP 的,所以是字节流。我们需要把字节流解析成 Redis 命令,然后执行,然后把结果序列化成字节流返回给客户端。

RESP 是 Redis 客户端和服务器之间通信的协议,设计简单且易于实现。RESP 支持 5 种基本数据类型:

  1. 简单字符串 (Simple Strings)
    • 格式:+<string>\r\n
    • 例子:+OK\r\n
    • 说明:不能包含换行符。
  2. 错误 (Errors)
    • 格式:-<error message>\r\n
    • 例子:-ERR unknown command 'foobar'\r\n
    • 说明:客户端应将其视为异常。
  3. 整数 (Integers)
    • 格式::<number>\r\n
    • 例子::1000\r\n
    • 说明:64 位有符号整数。
  4. 批量字符串 (Bulk Strings)
    • 格式:$<length>\r\n<data>\r\n
    • 例子:$5\r\nhello\r\n
    • 说明:
      • 可以表示二进制数据。
      • 空字符串:$0\r\n\r\n
      • 空值:$-1\r\n
  5. 数组 (Arrays)
    • 格式:*<number of elements>\r\n<element1>...<elementN>
    • 例子:*2\r\n$5\r\nhello\r\n$5\r\nworld\r\n
    • 说明:
      • 可以包含不同类型的元素。
      • 空数组:*0\r\n
      • 空值数组:*-1\r\n

客户端请求通常使用数组格式发送命令,例如 SET key value 被编码为:
*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n

服务器响应根据命令返回适当的 RESP 数据类型,例如 GET key 可能返回 $5\r\nhello\r\n 或者 $-1\r\n

根据上面的特征,帧解析的时候,首先读取第一个字节确定类型,然后基于类型解析剩余数据。而且这是典型的字符串状态机,所以可以使用状态机模式来实现,逐步构建完整的 RESP 对象。

1
2
3
4
5
6
7
8
9
10
/// A frame in the Redis protocol.
#[derive(Clone, Debug)]
pub enum Frame {
Simple(String),
Error(String),
Integer(u64),
Bulk(Bytes),
Null,
Array(Vec<Frame>),
}

+<string>\r\n-<error message>\r\n 都比较简单,一直找到 \r\n 为止,然后返回 Simple 类型。
如果是整数 :<number>\r\n,则需要额外的判断是否能编码成 u64。
如果是二进制数据 $<length>\r\n<data>\r\n,要注意空字符串和空值的情况。
如果是数组,那么第一个部分是整数格式,然后递归解析每个元素。

连接模块

连接模块实现网络读写Redis帧的功能,基于tokio的异步网络库。一个基本连接需实现:

  • 初始化连接:建立底层网络连接
  • 读取数据:从网络连接读取数据
  • 写入数据:向网络连接写入数据
  • 处理数据:解析和处理接收到的数据
  • 维护连接状态:监控连接状态,处理各种事件

我们使用BufWriter而非直接使用TcpStream,它维护内部缓冲区,优化系统交互,减少系统调用。同时使用BytesMut维护读取但未处理的数据,实现零拷贝提升性能。

因为 TCP 协议只保证字节流的顺序,不保证数据包的边界。已经读取但是还没有处理的数据,我们需要维护一个缓冲区,而不是频繁访问系统的TCP缓冲区。常见的做法是使用 Vec 来维护,但是 BytesMut 实现了零拷贝,只有在实际需要修改的时候才会拷贝,在读取、合并等操作的时候,只是移动指针,所以性能更好。

1
2
3
4
5
6
7
8
9
10
#[derive(Debug)]
pub struct Connection {
// The `TcpStream`. It is decorated with a `BufWriter`, which provides write
// level buffering. The `BufWriter` implementation provided by Tokio is
// sufficient for our needs.
stream: BufWriter<TcpStream>,

// The buffer for reading frames.
buffer: BytesMut,
}

当读到一个数据流的时候,要尝试解析成一个完整的 RESP 帧,如果解析成功,就返回一个 Frame,
然后把剩余的数据放回 buffer 里。具体解析的时候,使用 Cursor 来读取 buffer,方便设置 position,处理字节流。如果构成一个完整的 RESP 帧,就要记得 advance,把已经处理的字节去掉,这段内存空出来。此时 buffer 又开始从 0 开始。

如果解析失败,就继续读取,直到解析成功。考虑特殊情况,如果读取字节为0,说明对端关闭了连接,如果此时buffer还存在数据,就说明是异常断开了。

写入的时候,需要把Frame序列化成字节流,然后写入到 stream 里,要注意写入的字节流可能比较大,所以要分步写入,直到全部写入完成。最后记得 flush,把缓冲区的数据写入到系统的 TCP 缓冲区。

服务器模块

这个模块就开始汇总前面的模块,实现一个完整的 Redis 服务器。
一般都是一个很大的结构体,然后实现一些需要对外部暴露的方法,比如启动、关闭、处理连接等。内部的各个子结构体,就是各个功能模块,接着在这个大结构体里逐渐启动。大结构体里,除了各个模块的实例,还有一些共享的数据,比如配置、日志、计数器,尤其是一些控制信号,用于协调各个模块的工作,比如停止服务的顺序。这个设计模式叫做:中介者模式。

然后一个模块有更新,比如协议更新,出现了并存的实例,那么这个局部可以使用外观模式,把这些实例隐藏起来,对外暴露一个接口,这个接口可以根据配置,选择不同的实例。

控制信号需要 2 个,一个用于通知连接退出,因为 redis 服务器可能有多个客户端连接,每个连接都有自己的 TCP 连接进来,
当服务器退出的时候,需要通知所有的连接退出。
另一个是通知后台线程退出,比如数据库的部分,它在不断的扫描过期时间,删除过期的键值对,这个是一个后台线程,需要通知它退出。

另外一个控制信号是并发控制,限制服务器接收的 TCP 连接数。

1
2
3
4
5
6
7
8
9
struct Listener {
db_holder: DbDropGuard,

listener: TcpListener,

limit_connections: Arc<Semaphore>,
notify_shutdown: broadcast::Sender<()>,
shutdown_complete_tx: mpsc::Sender<()>,
}

一般启动服务器时,要考虑控制信号,还有一些程序的初始化,比如日志,
一些rust运行时的参数,比如线程数,这些都是全局的,需要在启动的时候初始化。

所以最外层的结构体,一般是这样的,shutdown 收到之后,会退出这个select:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
tokio::select! {
res = server.run() => {
// If an error is received here, accepting connections from the TCP
// listener failed multiple times and the server is giving up and
// shutting down.
//
// Errors encountered when handling individual connections do not
// bubble up to this point.
if let Err(err) = res {
error!(cause = %err, "failed to accept");
}
}
_ = shutdown => {
// The shutdown signal has been received.
info!("shutting down");
}
}

整个服务器是个大循环,首先需要信号量有空闲,才允许新的连接进来,
然后监听新的连接,把连接和后台服务模块打包进 Handler,这样启动一个异步的任务。
这是个非常常用的程序设计思路,把一个任务的启动和关闭,都封装在一个结构体里,这样可以方便的控制任务的生命周期。
而且也方便在另外一个线程里启动这个任务。

这里通过封装 TcpListener 自定义了 accept,实现了指数退避策略,如果 accept 失败,就等待一段时间再尝试,这样可以减少系统调用,提高性能。
一般 accept 失败都是系统内部错误。

这里就涉及到克隆的规则了,因为我们希望每个连接都共享数据库实例。rust的clone 分成了深拷贝和浅拷贝,shallow clone 会增加引用计数,deep clone 会复制整个对象。这里我们需要的是浅拷贝,所以使用 Arc 来包装数据库实例。

对于结构体这种复合类型,其克隆行为遵循以下规则:

  1. 默认派生的Clone实现 (#[derive(Clone)]):会递归地克隆结构体中的每个字段,每个字段的克隆行为取决于该字段类型自己的Clone实现,最终结果是结构体中所有字段都被克隆。
  2. 手动实现的Clone:可以自定义任何克隆行为,可以选择性地克隆某些字段或使用不同的克隆策略

📌 例1:所有字段进行深克隆的结构体

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
#[derive(Clone)]
struct Person {
name: String,
age: u32,
hobbies: Vec<String>,
}

fn main() {
let person1 = Person {
name: String::from("Alice"),
age: 30,
hobbies: vec![String::from("Reading"), String::from("Hiking")],
};

let person2 = person1.clone();
}

这里Person的克隆会创建一个完全独立的副本,因为所有字段(Stringu32Vec<String>)都实现了深克隆。

📌 例2:包含引用计数的结构体(混合克隆行为)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
use std::sync::Arc;

#[derive(Clone)]
struct Document {
title: String, // 深克隆
content: String, // 深克隆
shared_metadata: Arc<Metadata>, // 浅克隆(只增加引用计数)
}

#[derive(Clone)]
struct Metadata {
author: String,
created_at: u64,
tags: Vec<String>,
}

fn main() {
let metadata = Arc::new(Metadata {
author: String::from("Bob"),
created_at: 1616161616,
tags: vec![String::from("important")],
});

let doc1 = Document {
title: String::from("Report"),
content: String::from("Some content..."),
shared_metadata: metadata,
};

let doc2 = doc1.clone();

// doc2的title和content是深克隆的独立副本
// 但doc1和doc2共享同一个metadata (Arc<Metadata>)
}

这里Document的克隆是混合行为:

  • titlecontent字段是深克隆(完整副本)
  • shared_metadata字段是浅克隆(引用计数增加)

每个 Handler 里除了必要的资源,还有退出信号、每个连接的信号。只有当所有连接的信号都关闭,服务器才会退出。
当每个连接收到退出信号,就会开始关闭连接,也就是 Handler 任务。

任务的主要内容就是读取 Frame,一定会停留到收到一个完整的 Frame,才进行下一步。
如果出现异常,就会 Err,如果客户端政策终止连接返回None。
接着 Frame 会被解析成命令,然后执行,执行的结果会被序列化成 Frame,然后写入到连接里。

这里用 enum 来表示命令,然后用 match 来处理,这是 Rust 的模式匹配,很好的实践,能选择不同的成员,这样各自单独实现同名的命令,这就类似接口了。

另外每个命令有不同的 Frame 组装形式和返回值形式。

这里就涉及到 rust 的模块组织了,

  1. 一个模块就是一个文件,文件名就是模块名,文件里的内容就是模块的内容。
  2. 一个模块可以包含多个结构体、枚举、函数等。
  3. 一个文件里可以包含多个模块,这些模块可以是私有的,也可以是公有的。跨文件实际上就是隐含的把文件名当作模块名,然后在其他文件里引用。目录也是同理,目录名就是模块名。

模块组织有 mod.rs 和 与目录同名的rs文件两种做法,Rust 2018推荐之后推荐后者,但是实际上两者差别不大,只是文件命名位置不同而已。

1
2
3
4
5
6
src/
├── main.rs
└── models/
├── mod.rs
├── user.rs
└── product.rs
1
2
3
4
5
6
src/
├── main.rs
├── models.rs
└── models/
├── user.rs
└── product.rs

区别只是,mod.rs 挪动到了同级目录的与目录同名的rs文件里,你不用动里面的内容,rust会自动的去目录中找对应文件。

🔥 Rust 导入还有技巧,可以先导入,在设置Pub哪些。

1
2
3
4
5
mod get;
pub use get::Get;

mod publish;
pub use publish::Publish;

客户端模块

Connection 是最基本的,包括 TCPStream 和 Buffer。接着用户命令行输入组装成对应命令,命令再序列化成 Frame,然后写入到 Connection 里。
接着阻塞,等待服务器返回 frame。但是要注意超时,这里是如果发送成功后,没有考虑对方超时。我们注释掉返回写入的逻辑,会发现客户端卡住了。后面我们自己来修改。

Ping->Pong, 都是简单的字符串,所以直接返回即可。
Get->Value,发送的是 bulk 数组,返回Simple、Bulk、Null都有可能。
Set->OK,返回的是 Simple。

订阅会麻烦一些,字符串列表表示要订阅的多个频道。
服务端首先要记录客户端订阅了哪些 channel,并且为每个 channel 创建一个 stream。stream 是等待后续消息的流,消息通道,这个通道是rx,接收信息。而服务端会在状态数据里记录对应的tx,等待有其他客户端Publish消息到对应channel,然后广播给rx。

这个实现有个好处,rx 是从状态数据库来的,每个客户端都持有对应的部分,那么就可以做到状态数据库向tx发送消息,然后所有的客户端都能收到消息。这就是广播机制。

这个 rx 的类型很有意思,loop 与 yield 一起用,这就是生成器,但是加上 async_stream::stream! 就是异步的生成器,这个是通过 async-stream 包提供的,可以实现异步的流。不过同步代码不会用 yield,而是实现一个 Iterator 的 next 方法。所以这个语法几乎是创建异步流的标准写法。
每次收到消息,就 yield 出去,然后等待下一次消息,退出就返回 None,这和 stream 完全相同。那么持有这个 loop 的变量,就可以实现 stream 的功能。
try_stream! 是 async_stream 提供的另一个重要宏,它与 stream! 类似,但专门用于处理可能出错的场景。它返回的流中的元素类型是 Result<T, E>,并且在宏内部可以使用 ? 操作符进行错误传播。这对于像网络操作、IO 读写等可能失败的异步操作特别有用。
当我们使用 try_stream! 时,内部的错误会被自动包装成 Result 类型并向上传播。例如:

1
2
3
4
5
6
7
8
9
10
11
12
13
let mut rx = db.subscribe(channel_name.clone());

// Subscribe to the channel.
let rx = Box::pin(async_stream::stream! {
loop {
match rx.recv().await {
Ok(msg) => yield msg,
...
}
}
});
// Track subscription in this client's subscription set.
subscriptions.insert(channel_name.clone(), rx); // &mut StreamMap<String, Messages>

在底层实现上,async_stream 通过巧妙的方式解决了异步状态管理的问题。它使用线程本地存储来保存状态,避免了手动实现 Stream 特性时需要使用 unsafe 代码处理自引用结构的复杂性。这使得开发者可以用看起来像同步代码的方式(使用 async/await 和 yield)来编写异步流,而不必关心底层的异步状态机实现细节。

还有个有意思的地方,就是 stream 的类型 Box::pin(...),这创建一个堆分配的 stream,这是因为 async_stream::stream! 创建的 stream 类型和大小是不确定的,编译时才产生的。

这样封装一层Box,就可以当作一个固定大小的类型来使用了。记住,Rust 是强类型语言,每个变量都有固定的类型和大小,无法确定大小的类型一定是在堆上,而且用栈上的指针引用,这样才能保证内存安全。

这里的 pin 也是有学问的,异步本质上是创建了一个 Future,表示还没有完成的计算,是一个特殊 trait。在编译的时候,每个异步任务转换成状态机,每个状态就是异步任务的一个等待点。这个状态机可能包括对数据的引用,当这个 Future 被调度到其他线程,内存位置就可能变化,导致引用失效。所以需要 pin 保证内存位置不变,这样就可以安全的在多线程之间传递 Future。

Box::pin 是标准库提供的方法,用于创建堆分配的、被钉住的值。tokio::pin! 是 tokio 提供的宏,用于在栈上钉住值。最重要的区别在于tokio::pin! 分配在栈上,更高效,无需堆分配,但是生命周期受限于当前作用域,无法跨函数边界传递。

什么时候需要Pin呢,最常见的就是 stream,因为 stream 本身就是为了等待异步任务的结果,所以它本身就是一个Future。

还有就是自引用的结构,比如

1
2
3
4
5
struct SelfReferential {
data: String,
// This would normally be unsafe without Pin
reference: *const String,
}

下面函数的返回值类型也是 rust 的重要知识点。 前面的 rx 推理出的的类型是 Pin<Box<AsyncStream<Bytes, impl Future<Output = ()>>>>,实际定义的类型是 Pin<Box<dyn Stream<Item = Bytes> + Send>>,下面函数返回的类型是 impl Stream<Item = crate::Result<Message>>,我们发现了不用的写法 <Box<dyn _>>impl _,这里就是动态分发和静态分发的区别,动态分发是在运行时确定类型,可以返回不同的具体类型实现(多态),有轻微的性能开销,用Box实现返回值类型大小固定(就是一个指针大小)。而静态分发是在编译时确定类型,只能返回单一具体类型,性能更好。into_stream 用静态分发的原因是它专门处理固定的 Message 类型。而前面动态分发是因为需要处理多种可能的 Stream 实现,而这些实现在编译时可能并不确定。

1
2
3
4
5
6
7
pub fn into_stream(mut self) -> impl Stream<Item = crate::Result<Message>> {
try_stream! {
while let Some(message) = self.next_message().await? {
yield message;
}
}
}

客户端的虽有订阅的rx都存储在 subscriptions = StreamMap::new(); 里了,就是前面提到的 str->rx 的映射。他的好处是,任何一个rx收到消息,subscriptions.next() 就会返回这个消息,不用自己手动的写个 select 来等待多个rx。

另外用户可能再次订阅更多的频道,那么服务端继续增加 channels,然后触发增加 rx。

如果用户要取消订阅,那么服务端就要删除对应的 rx,这里有个问题,就是 rx 是异步的,可能正在等待消息,这时候删除了,就会导致 rx 无法接收到消息,这里我们直接中断rx,这是符合预期的,因为用户取消订阅,就不应该再接收到消息了。

我们来看 publish 的实现,是怎么发送给 channel 对应的 tx 的。这个操作就相当简单了,因为只管发布就可以了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
async fn subscribe_to_channel(
channel_name: String,
subscriptions: &mut StreamMap<String, Messages>,
db: &Db,
dst: &mut Connection,
) -> crate::Result<()> {
let mut rx = db.subscribe(channel_name.clone());

// Subscribe to the channel.
let rx = Box::pin(async_stream::stream! {
loop {
match rx.recv().await {
Ok(msg) => yield msg,
// If we lagged in consuming messages, just resume.
Err(broadcast::error::RecvError::Lagged(_)) => {},
Err(_) => break,
}
}
});

// Track subscription in this client's subscription set.
subscriptions.insert(channel_name.clone(), rx);

// Respond with the successful subscription
let response = make_subscribe_frame(channel_name, subscriptions.len());
dst.write_frame(&response).await?;

Ok(())
}

改进

改动见: https://github.com/learnerLj/mini-redis/commit/dd18c65b17347e1986efc2df30f9d9f73f595086

接下来我打算在这个项目的基础上进行一些扩展。这里有个关键,就是不会清理掉过期的 broadcast::Sender。为了解决这个问题,我们首先考虑 Sender 是否知晓自己有rx呢?这个是不可能的,因为没有直接的方法来检测,只能发送消息去尝试。

1
pub_sub: HashMap<String, broadcast::Sender<Bytes>>,

那我我们增加这么个机制。下面2个条件都会清理掉对应的 Sender。

  1. 通过 TCP 连接的状态来管理订阅的有效性。如果客户端断开连接,服务器会自动检测到并清理相应的资源。注意心跳机制没有作用,这里是TCP连接,一般心跳机制是用于UDP,检验某些记录是否还有效。

  2. Sender增加一个字段,是rx的计数。对应掉线或者主动取消订阅,都会减少计数,如果没有对应的 rx,就清理掉对应的 Sender。这个是为了防止客户端取消订阅,但是服务端没有收到消息,导致的内存泄漏。

具体说,订阅时,要更新对应字段。重复订阅已经存在的 channel,也要刷新活跃时间。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
Entry::Occupied(e) =>
{
let channel_state = e.into_mut();
channel_state.receiver_count += 1;
channel_state.last_activity = Instant::now();
channel_state.sender.subscribe()

},
Entry::Vacant(e) => {
// No broadcast channel exists yet, so create one.
//
// The channel is created with a capacity of `1024` messages. A
// message is stored in the channel until **all** subscribers
// have seen it. This means that a slow subscriber could result
// in messages being held indefinitely.
//
// When the channel's capacity fills up, publishing will result
// in old messages being dropped. This prevents slow consumers
// from blocking the entire system.
let (tx, rx) = broadcast::channel(1024);
let channel_state = ChannelState {
sender: tx,
receiver_count: 1,
last_activity: Instant::now(),
};
e.insert(channel_state);
rx
}

取消订阅的时候是类似的,但是由于之前 db 模块没有这部份逻辑,我们需要从 unsubscribe 命令,一直改到 db。之前只是从 subscriptions 里删除这个 channel->stream 映射,这样subscriptions.next() 就不会从删除了的rx返回数据了。但是这次我们要db里也改。

这里有个有趣的点,参数db 没有mut,但是实际订阅是修改了 db的状态的,这是否矛盾了呢?没有,因为 Mutex 是Refcell的多线程版本,也就是也实现了内部可变行。Db 结构体包含一个 Arc,而 Shared 结构体包含一个 Mutexlet mut state = self.shared.state.lock().unwrap(); 就自动的变成可变的了。这里就是很好的例子,我们的Db不想从头到尾都是可变参数传进去,到了 state 字段就可变了。

1
2
3
4
5
6
7
async fn subscribe_to_channel(
channel_name: String,
subscriptions: &mut StreamMap<String, Messages>,
db: &Db,
dst: &mut Connection,
) -> crate::Result<()> {
let mut rx = db.subscribe(channel_name.clone());

db 增加一个实现,当所有订阅被取消了,就要删除这个订阅。

1
2
3
4
5
6
7
8
9
10
11
12
pub(crate) fn unsubscribe(&self, key: &str) {
let mut state = self.shared.state.lock().unwrap();

if let Some(channel_state) = state.pub_sub.get_mut(key) {
channel_state.receiver_count -= 1;
channel_state.last_activity = Instant::now();

if channel_state.receiver_count == 0 {
state.pub_sub.remove(key);
}
}
}

但是也可能存在一些客户端,已经掉线了,但是订阅还存在,那需要在掉线的时候,增加取消订阅的逻辑。中断时 dst.read_frame() 会返回 none,那么所有订阅都应该删除,通知server对应计数改变。

1
2
3
4
5
6
7
8
9
10
11
res = dst.read_frame() => {
let frame = match res? {
Some(frame) => frame,
// This happens if the remote client has disconnected.
None => {
subscriptions.keys().for_each(|channel_name| {
db.unsubscribe(channel_name);
});
return Ok(())
}
};

为了验证我们的订阅管理是否真的起效,我们增加 debug trace。接着把 mod.rs 的模式,我们改成现代的模式。

学习项目结构

还有项目的布局,我们发现 example 文件夹下,也能导入本地的 mini_redis,这说明寻找依赖库时,Cargo 首先会读取项目根目录下的 Cargo.toml 文件,它会检查 [package] 部分的 name 字段,这定义了项目/库的名称。当在代码中使用 use mini_redis 这样的导入语句时,Cargo 会首先检查这个名称是否与当前项目名称匹配
如果匹配,它会优先使用当前项目的库代码,如果不匹配,才会去查找外部依赖。

我们还发现了运行命令 cargo run --example sub 结构特殊,这是运行一级目录下的 sub模块,如果是sub.rs单个文件,就是里面的main函数。如果是一个文件夹,那么就是运行整个模块,一般入口 example/sub/main.rs

这里的项目层级,只要你在项目目录内,就不会受到当前路径的影响,你可以在项目的任何地方运行,会自动寻找。但是 Rust 对于大型项目,有一个"工作空间"(Workspace),那么你就需要指定项目了。在 workspace 根目录运行:需要指定包 cargo run --package project-a --example sub

1
2
3
4
5
6
7
8
workspace/
├── Cargo.toml # 工作空间配置
├── project-a/
│ ├── Cargo.toml
│ └── examples/
│ └── sub.rs
└── project-b/
└── Cargo.toml

如何判断是一个"工作空间"(Workspace)呢?根目录有一个主 Cargo.toml,定义工作空间和成员项目。比如reth项目,cargo build、cargo test 等命令时设置了默认包,

1
2
3
4
5
6
7
8
9
[workspace]
members = [
"bin/reth-bench/",
"bin/reth/",
"crates/chain-state/",
...
]
efault-members = ["bin/reth"]
exclude = ["book/sources", "book/cli"]

项目里还有 bin目录,用来用上各个模块的功能,完成这个工具,一般也是软件的入口。这里是因为有多个软件,客户端和服务端的原因。如果只有一个,一般寻找 src/main.rs 或者 src/bin/main.rs 作为默认入口。

1
2
3
4
5
6
7
[[bin]]
name = "mini-redis-cli"
path = "src/bin/cli.rs"

[[bin]]
name = "mini-redis-server"
path = "src/bin/server.rs"

对于简单的项目,用这样的布局足够了。workspace就单独学习。

再看项目导出的部分,都在 src/lib.rs 中。每个子模块导出的函数,就在mod.rs或者与目录同名的rs文件。

学习测试

下面是典型的测试,对于用到 async 的函数,都要用 #[tokio::test] 代替 #[test]

1
2
3
4
5
6
7
8
9
10
/// A PING PONG test without message provided.
/// It should return "PONG".
#[tokio::test]
async fn ping_pong_without_message() {
let (addr, _) = start_server().await;
let mut client = Client::connect(addr).await.unwrap();

let pong = client.ping(None).await.unwrap();
assert_eq!(b"PONG", &pong[..]);
}

测试的断言,除了 assert_eq! 还有 #[should_panic],标记在测试函数上,表示需要触发panic。

有时候需要便边写代码边测试,但是我们又不希望测试用例也编译进程序,那么创建一个 test模块,然后 #[cfg(test)] 标记它,这样仅在测试构建时包含的代码块。

1
2
3
4
5
6
7
8
9
#[cfg(test)]
mod tests {
use super::*;

#[test]
fn it_works() {
assert_eq!(2 + 2, 4);
}
}

要运行所有测试,直接 cargo test,会寻找项目中所有标记为测试的代码,而不仅仅是 tests 目录下的测试。标记为测试是指:

  1. 标记为 #[test] 的函数,这些函数必须返回 () 并且不能接受任何参数。
  2. 标记为 #[tokio::test] 或其他自定义测试宏的函数。
  3. #[cfg(test)] 模块中的测试函数。

如果你只想运行 tests 目录下的测试,单个模块 cargo test --test。如果不在 tests 目录下,那么就按照模块顺序去找,比如 cargo test crate::models::user::test_user_validation。也可以简单地提供测试函数的名称,如果它在项目中是唯一的 cargo test test_user_validation。这里的名字都是可以正则匹配的。

如果你需要看到测试中的 println! 输出,可以添加 – --nocapture 参数。比如 cargo test --test server key_value_get_set -- --nocapture

--show-output 参数只会显示失败测试的输出,而且会在测试结果之后整齐地显示。这样可以更容易地将测试结果与输出区分开。通过测试就不会有输出了。

测试里除了 println!,还可以使用 dbg!(),会自动打印表达式和计算后的值。比如 dbg!(addr); 返回 [tests/server.rs:15:5] addr = 127.0.0.1:57813

命令行工具

Rust 中最常见的命令行参数解析库是 clap。我们看 redis client 的设计,

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#[derive(Parser, Debug)]
#[command(
name = "mini-redis-cli",
version,
author,
about = "Issue Redis commands"
)]
struct Cli {
#[clap(subcommand)]
command: Command,

#[arg(id = "hostname", long, default_value = "127.0.0.1")]
host: String,

#[arg(long, default_value_t = DEFAULT_PORT)]
port: u16,
}

一个子命令(通过 Command 枚举表示),#[clap(subcommand)] 会告诉 clap 这是个子命令,加上他继承Subcommand,就会自动生成解析的代码。子命令里的参数字段,也可以增加标签。这个自动生成功能,极大的简化了命令解析的过程。
默认参数可以是用 --字段名,或者用字段位置对应。

1
2
3
4
5
6
7
8
#[derive(Subcommand, Debug)]
enum Command {
Ping { msg: Option<Bytes> },
Get { key: String },
Set { key: String, value: Bytes, expires: Option<Duration> },
Publish { channel: String, message: Bytes },
Subscribe { channels: Vec<String> },
}