egg 浅析
主要分析egraphs-good也就是egg这个库的实现机制.因为最近发现适配到基于relay的ir中存在一些问题,因此还是需要仔细研究一下他的实现细节.
1. Language
Language应该就是代表的是enode, 由一个op以及若干children组成.
他的构建机制每个dsl中自己实现from_op
函数.
pub struct Symbol(u32);
pub struct Id(u32);
pub struct SymbolLang {
/// The operator for an enode
pub op: Symbol,
/// The enode's children `Id`s
pub children: Vec<Id>,
}
impl FromOp for SymbolLang {
type Error = Infallible;
fn from_op(op: &str, children: Vec<Id>) -> Result<Self, Self::Error> {
Ok(Self {
op: op.into(),
children,
})
}
}SymbolLang
中op对应的类型是symbol (u32)
,他的内部维护了一个string
hashset,然后调用op.into()
从hashset中取得对应的index作为symbol.
这里的children的类型是id (u32)
, 本意是表示eclass的id,
但如果没有加入egraph之前实际上是共用symobl的值.
2. RecExpr
RecExpr表示是一组由用户定义的language组成的递归的expression,
比如我构建输入a + b
, 那么此时RecExpr的nodes由
[+, a, b]
,[a]
,[b]
三个language节点组成.
即保存了输入表达式下的所有 language node.
pub struct RecExpr<L> { |
他是通过递归的parser构建的,每次解析一个 language node然后放入到RecExpr中去.
impl<L: FromOp> FromStr for RecExpr<L> { |
3. EClass
EClass定义.
pub struct EClass<L, D> { |
4. EGraph
pub struct EGraph<L: Language, N: Analysis<L>> { |
5. Rewrite
pub struct Rewrite<L, N> { |
5.1 rebuild
egraph 首先添加进入后都是没有clean的,所以需要rebuild一次
5.1.2 收集class_by_op用于类型匹配.
egg的添加是,遍历这个eclass中所有的enode,然后enode把他所属的eclass
id存入discriminant
的key中.
let mut add = |n: &L| { |
因为rust的enum是可以提供完全不同结构的类型,
因此discriminant
就是映射他的结构类型到int key,
他的好处就是你可以添加很多很多不同类型的ir,这样基于类型的enode匹配就很简单的从字典中获取一个入口eclass开始匹配即可.
虽然如下所示,App可能存在很多个enode,但是至少能从类型上消除很大一部分的候选了.
[DEBUG egg::egraph] Add : App([93, 94]) class id : 54 into key : Discriminant(5) |
5.1 Rewrite Marco
通过一个rewrite!
的宏,将lhs,rhs构造成两部分.
pub struct Rewrite<L, N> {
/// The name of the rewrite.
pub name: Symbol,
/// 可以是从expr构建/ 也可以是自定义的匹配的方式
pub searcher: Arc<dyn Searcher<L, N> + Sync + Send>,
/// 获得对应的结果, 可以是expr也可以自定义构建
pub applier: Arc<dyn Applier<L, N> + Sync + Send>,
}
5.2 Pattern Match
这里的匹配是调用rewriter的search进行搜索.
首先这里的searcher是从字符串构造,
首先通过字符串解析为PatternAst
pub type PatternAst<L> = RecExpr<ENodeOrVar<L>>;
fn from_str(s: &str) -> Result<Self, Self::Err> {
PatternAst::from_str(s).map(Self::from)
}PatternAst
构造出新的Pattern对象. impl<L: Language> Pattern<L> {
/// Creates a new pattern from the given pattern ast.
pub fn new(ast: PatternAst<L>) -> Self {
let ast = ast.compact();
let program = machine::Program::compile_from_pat(&ast);
Pattern { ast, program }
}
/// Returns a list of the [`Var`]s in this pattern.
pub fn vars(&self) -> Vec<Var> {
let mut vars = vec![];
for n in self.ast.as_ref() {
if let ENodeOrVar::Var(v) = n {
if !vars.contains(v) {
vars.push(*v)
}
}
}
vars
}
}
struct Machine {
reg: Vec<Id>,
// a buffer to re-use for lookups
lookup: Vec<Id>,
}
pub struct Program<L> {
instructions: Vec<Instruction<L>>,
subst: Subst,
}
pub(crate) fn compile_from_pat(pattern: &PatternAst<L>) -> Self {
let program = Compiler::new(pattern).compile();
log::debug!("Compiled {:?} to {:?}", pattern.as_ref(), program);
program
}
pattern的search是将自身的pattern ast转换为Enode,然后使用discriminant获取这个enode的op的类型,再从egraph中寻找所有的elass进行下一步匹配.
但是我还是没懂discriminant
是怎么获取的key是怎样的.官方文档上说此函数返回值只关心enum的类型,而不关心具体的值,这个就很奇怪了.
pub fn search(&self, egraph: &EGraph<L, N>) -> Vec<SearchMatches<L>> { |