过程宏

手工定义图

过程宏要比声明宏要复杂很多,不过无论是哪一种过程宏,本质都是一样的,都涉及要把 输入的 TokenStream 处理成输出的 TokenStream

Cargo.toml添加proc-macro声明

这样,编译器才允许你使用 #[proc_macro] 相关的宏。

[lib]
proc-macro = true

函数宏

  • #[proc_macro]

和macro_rules! 功能类似,但更为强大。

src/lib.rs:定义过程函数宏示例:可以看到,都是处理TokenStream

mod builder;
mod builder_with_attr;
mod raw_builder;

use proc_macro::TokenStream;
use raw_builder::BuilderContext;
use syn::{parse_macro_input, DeriveInput};

#[proc_macro]
pub fn query(input: TokenStream) -> TokenStream {
    // 只有修改代码之后再次编译才会执行
    println!("{:#?}", input);
    "fn hello() { println!(\"Hello world!\"); }"
        .parse()
        .unwrap()
}

#[proc_macro_derive(RawBuilder)]
pub fn derive_raw_builder(input: TokenStream) -> TokenStream {
    // 只有修改代码之后再次编译才会执行
    println!("{:#?}", input);
    BuilderContext::render(input).unwrap().parse().unwrap()
}

#[proc_macro_derive(Builder)]
pub fn derive_builder(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    println!("{:#?}", input);
    builder::BuilderContext::from(input).render().into()
}

#[proc_macro_derive(BuilderWithAttr, attributes(builder))]
pub fn derive_builder_with_attr(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    println!("{:#?}", input);
    builder_with_attr::BuilderContext::from(input)
        .render()
        .into()
}

TokenStream

使用者可以通过 query!(…) 来调用。我们打印传入的 TokenStream, 然后把一段包含在字符串中的代码解析成 TokenStream 返回。

这里可以非常方便地用字符串的 parse() 方法来获得 TokenStream, 是因为 TokenStream 实现了 FromStr trait。

examples/query.rs使用过冲示例

use macros::query;

fn main() {
    // query!(SELECT * FROM users WHERE age > 10);
    query!(SELECT * FROM users u JOIN (SELECT * from profiles p) WHERE u.id = p.id);
    hello()
}
  1. .parse().unwrap(): 字符串自动转为TokenStream类型
cargo run --example query > examples/query_output.txt

运行结果示例:可以看到打印出来的TokenStream

TokenStream [
    Ident {
        ident: "SELECT",
        span: #0 bytes(94..100),
    },
    Punct {
        ch: '*',
        spacing: Alone,
        span: #0 bytes(101..102),
    },
    Ident {
        ident: "FROM",
        span: #0 bytes(103..107),
    },
    Ident {
        ident: "users",
        span: #0 bytes(108..113),
    },
    Ident {
        ident: "u",
        span: #0 bytes(114..115),
    },
    Ident {
        ident: "JOIN",
        span: #0 bytes(116..120),
    },
    Group {
        delimiter: Parenthesis,
        stream: TokenStream [
            Ident {
                ident: "SELECT",
                span: #0 bytes(122..128),
            },
            Punct {
                ch: '*',
                spacing: Alone,
                span: #0 bytes(129..130),
            },
            Ident {
                ident: "from",
                span: #0 bytes(131..135),
            },
            Ident {
                ident: "profiles",
                span: #0 bytes(136..144),
            },
            Ident {
                ident: "p",
                span: #0 bytes(145..146),
            },
        ],
        span: #0 bytes(121..147),
    },
    Ident {
        ident: "WHERE",
        span: #0 bytes(148..153),
    },
    Ident {
        ident: "u",
        span: #0 bytes(154..155),
    },
    Punct {
        ch: '.',
        spacing: Alone,
        span: #0 bytes(155..156),
    },
    Ident {
        ident: "id",
        span: #0 bytes(156..158),
    },
    Punct {
        ch: '=',
        spacing: Alone,
        span: #0 bytes(159..160),
    },
    Ident {
        ident: "p",
        span: #0 bytes(161..162),
    },
    Punct {
        ch: '.',
        spacing: Alone,
        span: #0 bytes(162..163),
    },
    Ident {
        ident: "id",
        span: #0 bytes(163..165),
    },
]
Hello world!

TokenStream是一个Iterator,里面包含一系列的TokenTree

#![allow(unused)]
fn main() {
pub enum TokenTree {
    // 组,如果代码中包含括号,比如{} [] <> () ,那么内部的内容会被分析成一个Group(组)
    Group(Group), 
    // 标识符
    Ident(Ident),
    // 标点符号 
    Punct(Punct),
    // 字面量 
    Literal(Literal), 
}
}

Group Example

use macros::query;

fn main() {
    // query!(SELECT * FROM users WHERE age > 10);
    query!(SELECT * FROM users u JOIN (SELECT * from profiles p) WHERE u.id = p.id);
    hello()
}

派生宏:

#[proc_macro_devive(DeriveMacroName)]

用于结构体(struct)、枚举(enum)、联合(union)类型,可为其实现函数或特征(Trait)

常用派生宏

#[derive(Debug)]

手工实现builder模式

实现效果:链式调用

想到达到链式调用的效果

fn main() {
    let command = Command::builder()
        .executable("cargo".to_owned())
        .args(vec!["build".to_owned(), "--release".to_owned()])
        .env(vec![])
        .build()
        .unwrap();
    assert!(command.current_dir.is_none());

    let command = Command::builder()
        .executable("cargo".to_owned())
        .args(vec!["build".to_owned(), "--release".to_owned()])
        .env(vec![])
        .current_dir("..".to_owned())
        .build()
        .unwrap();
    assert!(command.current_dir.is_some());
    println!("{:?}", command);
}

可以这样定义

/// 这是command.rs的派生宏实现的代码样子
#[allow(dead_code)]
#[derive(Debug)]
pub struct Command {
    executable: String,
    args: Vec<String>,
    env: Vec<String>,
    current_dir: Option<String>,
}

#[derive(Debug, Default)]
pub struct CommandBuilder {
    executable: Option<String>,
    args: Option<Vec<String>>,
    env: Option<Vec<String>>,
    current_dir: Option<String>,
}

impl Command {
    pub fn builder() -> CommandBuilder {
        Default::default()
    }
}

impl CommandBuilder {
    pub fn executable(mut self, v: String) -> Self {
        self.executable = Some(v.to_owned());
        self
    }

    pub fn args(mut self, v: Vec<String>) -> Self {
        self.args = Some(v.to_owned());
        self
    }

    pub fn env(mut self, v: Vec<String>) -> Self {
        self.env = Some(v.to_owned());
        self
    }

    pub fn current_dir(mut self, v: String) -> Self {
        self.current_dir = Some(v.to_owned());
        self
    }

    pub fn build(mut self) -> Result<Command, &'static str> {
        Ok(Command {
            executable: self.executable.take().ok_or("executable must be set")?,
            args: self.args.take().ok_or("args must be set")?,
            env: self.env.take().ok_or("env must be set")?,
            current_dir: self.current_dir.take(),
        })
    }

但是有点繁琐,可以使用派生宏派生出这些代码

派生宏思路

要生成的代码模版: 把输入的 TokenStream 抽取出来,也就是把在 struct 的定义内部,每个域的名字及其类型都抽出来,然后生成对应的方法代码。

impl {{ name }} {
    pub fn builder() -> {{ builder_name }} {
        Default::default()
    }
}

#[derive(Debug, Default)]
pub struct {{ builder_name }} {
    {% for field in fields %}
        {{ field.name }}: Option<{{ field.ty }}>,
    {% endfor %}
}

impl {{ builder_name }} {
    {% for field in fields %}
    pub fn {{ field.name }}(mut self, v: impl Into<{{ field.ty }}>) -> {{ builder_name }} {
        self.{{ field.name }} = Some(v.into());
        self
    }
    {% endfor %}

    pub fn build(self) -> Result<{{ name }}, &'static str> {
        Ok({{ name }} {
            {% for field in fields %}
                {% if field.optional %}
                {{ field.name }}: self.{{ field.name }},
                {% else %}
                {{ field.name }}: self.{{ field.name }}.ok_or("Build failed: missing {{ field.name }}")?,
                {% endif %}
            {% endfor %}
        })
    }
}

  1. 7-12: 这里的 fileds / builder_name 是我们要传入的参数,每个 field 还需要 name 和 ty 两个 属性,分别对应 field 的名字和类型
  2. 25-26: 对于原本是 Option 类型的域,要避免生成 Option

构建对应数据结构

/// 处理 jinja 模板的数据结构,在模板中我们使用了 name / builder_name / fields
#[derive(Template)]
#[template(path = "builder.j2", escape = "none")]
pub struct BuilderContext {
    name: String,
    builder_name: String,
    fields: Vec<Fd>,
}

src/lib.rs: 使用派生宏从TokenStream抽取出想要的信息

对于 derive macro,要使用 proce_macro_derive 这个宏。我们把这个 derive macro 命名为 Builder。

#[proc_macro_derive(RawBuilder)]
pub fn derive_raw_builder(input: TokenStream) -> TokenStream {
    // 只有修改代码之后再次编译才会执行
    println!("{:#?}", input);
    BuilderContext::render(input).unwrap().parse().unwrap()
}

examples/raw_command.rs: 使用这个派生宏抽取

use macros::RawBuilder;

#[allow(dead_code)]
#[derive(Debug, RawBuilder)]
pub struct Command {
    executable: String,
    args: Vec<String>,
    env: Vec<String>,
    current_dir: Option<String>,
}

fn main() {
    let command = Command::builder()
        .executable("cargo".to_owned())
        .args(vec!["build".to_owned(), "--release".to_owned()])
        .env(vec![])
        .build()
        .unwrap();
    assert!(command.current_dir.is_none());

    let command = Command::builder()
        .executable("cargo".to_owned())
        .args(vec!["build".to_owned(), "--release".to_owned()])
        .env(vec![])
        .current_dir("..".to_owned())
        .build()
        .unwrap();
    assert!(command.current_dir.is_some());
    println!("{:?}", command);
}

运行,查看获取的TokenStream

cargo run --example raw_command > examples/raw_command_output.txt
TokenStream [
    Punct {
        ch: '#',
        spacing: Alone,
        span: #0 bytes(25..26),
    },
    Group {
        delimiter: Bracket,
        stream: TokenStream [
            Ident {
                ident: "allow",
                span: #0 bytes(27..32),
            },
            Group {
                delimiter: Parenthesis,
                stream: TokenStream [
                    Ident {
                        ident: "dead_code",
                        span: #0 bytes(33..42),
                    },
                ],
                span: #0 bytes(32..43),
            },
        ],
        span: #0 bytes(26..44),
    },
    Ident {
        ident: "pub",
        span: #0 bytes(74..77),
    },
    Ident {
        ident: "struct",
        span: #0 bytes(78..84),
    },
    Ident {
        ident: "Command",
        span: #0 bytes(85..92),
    },
    Group {
        delimiter: Brace,
        stream: TokenStream [
            Ident {
                ident: "executable",
                span: #0 bytes(99..109),
            },
            Punct {
                ch: ':',
                spacing: Alone,
                span: #0 bytes(109..110),
            },
            Ident {
                ident: "String",
                span: #0 bytes(111..117),
            },
            Punct {
                ch: ',',
                spacing: Alone,
                span: #0 bytes(117..118),
            },
            Ident {
                ident: "args",
                span: #0 bytes(123..127),
            },
            Punct {
                ch: ':',
                spacing: Alone,
                span: #0 bytes(127..128),
            },
            Ident {
                ident: "Vec",
                span: #0 bytes(129..132),
            },
            Punct {
                ch: '<',
                spacing: Alone,
                span: #0 bytes(132..133),
            },
            Ident {
                ident: "String",
                span: #0 bytes(133..139),
            },
            Punct {
                ch: '>',
                spacing: Joint,
                span: #0 bytes(139..140),
            },
            Punct {
                ch: ',',
                spacing: Alone,
                span: #0 bytes(140..141),
            },
            Ident {
                ident: "env",
                span: #0 bytes(146..149),
            },
            Punct {
                ch: ':',
                spacing: Alone,
                span: #0 bytes(149..150),
            },
            Ident {
                ident: "Vec",
                span: #0 bytes(151..154),
            },
            Punct {
                ch: '<',
                spacing: Alone,
                span: #0 bytes(154..155),
            },
            Ident {
                ident: "String",
                span: #0 bytes(155..161),
            },
            Punct {
                ch: '>',
                spacing: Joint,
                span: #0 bytes(161..162),
            },
            Punct {
                ch: ',',
                spacing: Alone,
                span: #0 bytes(162..163),
            },
            Ident {
                ident: "current_dir",
                span: #0 bytes(168..179),
            },
            Punct {
                ch: ':',
                spacing: Alone,
                span: #0 bytes(179..180),
            },
            Ident {
                ident: "Option",
                span: #0 bytes(181..187),
            },
            Punct {
                ch: '<',
                spacing: Alone,
                span: #0 bytes(187..188),
            },
            Ident {
                ident: "String",
                span: #0 bytes(188..194),
            },
            Punct {
                ch: '>',
                spacing: Joint,
                span: #0 bytes(194..195),
            },
            Punct {
                ch: ',',
                spacing: Alone,
                span: #0 bytes(195..196),
            },
        ],
        span: #0 bytes(93..198),
    },
]
Command { executable: "cargo", args: ["build", "--release"], env: [], current_dir: Some("..") }

打印信息说明

  1. 首先有一个 Group,包含了 #[allow(dead_code)] 属性的信息。因为我们现在拿到 的 derive 下的信息,所以所有不属于 #[derive(…)] 的属性,都会被放入 TokenStream 中。

  2. 之后是 pub / struct / Command 三个 ident。

  3. 随后又是一个 Group,包含了每个 field 的信息。我们看到,field 之间用逗号这个 Punct 分隔,field 的名字和类型又是通过冒号这个 Punct 分隔。而类型,可能是一个 Ident,如 String,或者一系列 Ident / Punct,如 Vec / < / String / >。

src/raw_builder.rs: 使用anyhow与askama抽取TokenStream中的信息

我们要做的就是,把这个 TokenStream 中的 struct 名字,以及每个 field 的名字和类型拿出来。 如果类型是 Option,那么把 T 拿出来,把 optional 设置为 true。

use anyhow::Result;
use askama::Template;
use proc_macro::{Ident, TokenStream, TokenTree};
use std::collections::VecDeque;

/// 处理 jinja 模板的数据结构,在模板中我们使用了 name / builder_name / fields
#[derive(Template)]
#[template(path = "builder.j2", escape = "none")]
pub struct BuilderContext {
    name: String,
    builder_name: String,
    fields: Vec<Fd>,
}

/// 描述 struct 的每个 field
#[derive(Debug, Default)]
struct Fd {
    name: String,
    ty: String,
    optional: bool,
}

templates/builder.j2: 上面askama用到的jinja2模版

impl {{ name }} {
    pub fn builder() -> {{ builder_name }} {
        Default::default()
    }
}

#[derive(Debug, Default)]
pub struct {{ builder_name }} {
    {% for field in fields %}
        {{ field.name }}: Option<{{ field.ty }}>,
    {% endfor %}
}

impl {{ builder_name }} {
    {% for field in fields %}
    pub fn {{ field.name }}(mut self, v: impl Into<{{ field.ty }}>) -> {{ builder_name }} {
        self.{{ field.name }} = Some(v.into());
        self
    }
    {% endfor %}

    pub fn build(self) -> Result<{{ name }}, &'static str> {
        Ok({{ name }} {
            {% for field in fields %}
                {% if field.optional %}
                {{ field.name }}: self.{{ field.name }},
                {% else %}
                {{ field.name }}: self.{{ field.name }}.ok_or("Build failed: missing {{ field.name }}")?,
                {% endif %}
            {% endfor %}
        })
    }
}

src/raw_builder.rs: 实现对应抽取方法

impl Fd {
    /// name 和 field 都是通过冒号 Punct 切分出来的 TokenTree 切片
    pub fn new(name: &[TokenTree], ty: &[TokenTree]) -> Self {
        // 把类似 Ident("Option"), Punct('<'), Ident("String"), Punct('>) 的 ty
        // 收集成一个 String 列表,如 vec!["Option", "<", "String", ">"]
        let ty = ty
            .iter()
            .map(|v| match v {
                TokenTree::Ident(n) => n.to_string(),
                TokenTree::Punct(p) => p.as_char().to_string(),
                e => panic!("Expect ident, got {:?}", e),
            })
            .collect::<Vec<_>>();
        // 冒号前最后一个 TokenTree 是 field 的名字
        // 比如:executable: String,
        // 注意这里不应该用 name[0],因为有可能是 pub executable: String
        // 甚至,带 attributes 的 field,
        // 比如:#[builder(hello = world)] pub executable: String
        match name.last() {
            Some(TokenTree::Ident(name)) => {
                // 如果 ty 第 0 项是 Option,那么从第二项取到倒数第一项
                // 取完后上面的例子中的 ty 会变成 ["String"],optiona = true
                let (ty, optional) = if ty[0].as_str() == "Option" {
                    (&ty[2..ty.len() - 1], true)
                } else {
                    (&ty[..], false)
                };
                Self {
                    name: name.to_string(),
                    ty: ty.join(""), // 把 ty join 成字符串
                    optional,
                }
            }
            e => panic!("Expect ident, got {:?}", e),
        }
    }
}

impl BuilderContext {
    /// 从 TokenStream 中提取信息,构建 BuilderContext
    fn new(input: TokenStream) -> Self {
        let (name, input) = split(input);
        let fields = get_struct_fields(input);
        Self {
            builder_name: format!("{}Builder", name),
            name: name.to_string(),
            fields,
        }
    }

    /// 把模板渲染成字符串代码
    pub fn render(input: TokenStream) -> Result<String> {
        let template = Self::new(input);
        Ok(template.render()?)
    }
}

/// 把 TokenStream 分出 struct 的名字,和包含 fields 的 TokenStream
fn split(input: TokenStream) -> (Ident, TokenStream) {
    let mut input = input.into_iter().collect::<VecDeque<_>>();
    // 一直往后找,找到 struct 停下来
    while let Some(item) = input.pop_front() {
        if let TokenTree::Ident(v) = item {
            if v.to_string() == "struct" {
                break;
            }
        }
    }

    // struct 后面,应该是 struct name
    let ident;
    if let Some(TokenTree::Ident(v)) = input.pop_front() {
        ident = v;
    } else {
        panic!("Didn't find struct name");
    }

    // struct 后面可能还有若干 TokenTree,我们不管,一路找到第一个 Group
    let mut group = None;
    for item in input {
        if let TokenTree::Group(g) = item {
            group = Some(g);
            break;
        }
    }

    (ident, group.expect("Didn't find field group").stream())
}

/// 从包含 fields 的 TokenStream 中切出来一个个 Fd
fn get_struct_fields(input: TokenStream) -> Vec<Fd> {
    let input = input.into_iter().collect::<Vec<_>>();
    input
        .split(|v| match v {
            // 先用 ',' 切出来一个个包含 field 所有信息的 &[TokenTree]
            TokenTree::Punct(p) => p.as_char() == ',',
            _ => false,
        })
        .map(|tokens| {
            tokens
                .split(|v| match v {
                    // 再用 ':' 把 &[TokenTree] 切成 [&[TokenTree], &[TokenTree]]
                    // 它们分别对应名字和类型
                    TokenTree::Punct(p) => p.as_char() == ':',
                    _ => false,
                })
                .collect::<Vec<_>>()
        })
        // 正常情况下,应该得到 [&[TokenTree], &[TokenTree]],对于切出来长度不为 2 的统统过滤掉
        .filter(|tokens| tokens.len() == 2)
        // 使用 Fd::new 创建出每个 Fd
        .map(|tokens| Fd::new(tokens[0], tokens[1]))
        .collect()
}

提示:类比理解

可以对着打印出来的 TokenStream 和刚才的分析进行理解。 核心的就是 get_struct_fields() 方法,如果觉得难懂, 可以想想如果要把一个 a=1,b=2 的字符串切成 [[a, 1], [b, 2]] 该怎么做,就很容易理解了。

自动实现:使用syn/quote可以不用自己定义模版

详见上方对比图

过程属性宏: proc_macro_derive(macro_name, attributes(attr_name))

用于属性宏, 用在结构体、字段、函数等地方,为其指定属性等功能, 类似python的计算属性

定义结构体时在某个字段上方使用对应attr_name

#![allow(unused)]
fn main() {
#[allow(dead_code)]
#[derive(Debug, BuilderWithAttr)]
pub struct Command {
    executable: String,
    #[builder(each = "arg")]
    args: Vec<String>,
    #[builder(each = "env", default = "vec![]")]
    env: Vec<String>,
    current_dir: Option<String>,
}
}

使用syn/quote定义属性宏

详见上方对比图