自動微分のための TypeTrees
TypeTrees とは?
Enzyme のためのメモリレイアウト記述子です。型がメモリ内でどのように構造化されているかを Enzyme に正確に伝えることで、導関数を効率的に計算できるようにします。
構造
#![allow(unused)]
fn main() {
TypeTree(Vec<Type>)
Type {
offset: isize, // バイトオフセット(-1 = どこでも)
size: usize, // バイト単位のサイズ
kind: Kind, // Float、Integer、Pointer など
child: TypeTree // ネストされた構造
}
}
例: fn compute(x: &f32, data: &[f32]) -> f32
入力 0: x: &f32
#![allow(unused)]
fn main() {
TypeTree(vec![Type {
offset: -1, size: 8, kind: Pointer,
child: TypeTree(vec![Type {
offset: 0, size: 4, kind: Float, // 単一の値: オフセット 0 を使用
child: TypeTree::new()
}])
}])
}
入力 1: data: &[f32]
#![allow(unused)]
fn main() {
TypeTree(vec![Type {
offset: -1, size: 8, kind: Pointer,
child: TypeTree(vec![Type {
offset: -1, size: 4, kind: Float, // -1 = すべての要素
child: TypeTree::new()
}])
}])
}
出力: f32
#![allow(unused)]
fn main() {
TypeTree(vec![Type {
offset: 0, size: 4, kind: Float, // 単一のスカラー: オフセット 0 を使用
child: TypeTree::new()
}])
}
なぜ必要か?
- Enzyme は LLVM IR から複雑な型レイアウトを推論できない
- 低速なメモリパターン解析を防ぐ
- ネストされた構造に対する正しい導関数計算を可能にする
- どのバイトが微分可能で、どのバイトがメタデータかを Enzyme に伝える
Enzyme がこの情報を使って行うこと:
TypeTrees なし:
; Enzyme は汎用的な LLVM IR を見る:
define float @distance(ptr %p1, ptr %p2) {
; これらのポインターが何を指しているかを推測する必要がある
; すべてのメモリ操作を低速に解析する
; 最適化の機会を逃す可能性がある
}
TypeTrees あり:
define "enzyme_type"="{[-1]:Float@float}" float @distance(
ptr "enzyme_type"="{[-1]:Pointer, [-1,0]:Float@float}" %p1,
ptr "enzyme_type"="{[-1]:Pointer, [-1,0]:Float@float}" %p2
) {
; Enzyme は正確な型レイアウトを知っている
; 効率的な導関数コードを直接生成できる
}
TypeTrees - オフセットと -1 の解説
Type の構造
#![allow(unused)]
fn main() {
Type {
offset: isize, // この型が始まる場所
size: usize, // この型の大きさ
kind: Kind, // データの種類(Float、Int、Pointer)
child: TypeTree // 内部に含まれるもの(ポインター/コンテナー用)
}
}
オフセット値
通常のオフセット(0、4、8 など)
構造体内の特定のバイト位置
#![allow(unused)]
fn main() {
struct Point {
x: f32, // オフセット 0、サイズ 4
y: f32, // オフセット 4、サイズ 4
id: i32, // オフセット 8、サイズ 4
}
}
&Point の TypeTree(内部表現):
#![allow(unused)]
fn main() {
TypeTree(vec![
Type { offset: 0, size: 4, kind: Float }, // バイト 0 の x
Type { offset: 4, size: 4, kind: Float }, // バイト 4 の y
Type { offset: 8, size: 4, kind: Integer } // バイト 8 の id
])
}
LLVM を生成
"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer}"
オフセット -1(特殊: 「どこでも」)
「このパターンがすべての要素に対して繰り返される」ことを意味します
例 1: 直接配列 [f32; 100](ポインター間接参照なし)
#![allow(unused)]
fn main() {
TypeTree(vec![Type {
offset: -1, // すべての位置
size: 4, // 各 f32 は 4 バイト
kind: Float, // すべての要素は float
}])
}
LLVM を生成: "enzyme_type"="{[-1]:Float@float}"
例 1b: 配列参照 &[f32; 100](ポインター間接参照あり)
#![allow(unused)]
fn main() {
TypeTree(vec![Type {
offset: -1, size: 8, kind: Pointer,
child: TypeTree(vec![Type {
offset: -1, // すべての配列要素
size: 4, // 各 f32 は 4 バイト
kind: Float, // すべての要素は float
}])
}])
}
LLVM を生成: "enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@float}"
オフセット 0,4,8,12...396 を持つ 100 個の個別の Type を列挙する代わりです
例 2: スライス &[i32]
#![allow(unused)]
fn main() {
// スライスデータへのポインター
TypeTree(vec![Type {
offset: -1, size: 8, kind: Pointer,
child: TypeTree(vec![Type {
offset: -1, // すべてのスライス要素
size: 4, // 各 i32 は 4 バイト
kind: Integer
}])
}])
}
LLVM を生成: "enzyme_type"="{[-1]:Pointer, [-1,-1]:Integer}"
例 3: 混合構造
#![allow(unused)]
fn main() {
struct Container {
header: i64, // オフセット 0
data: [f32; 1000], // オフセット 8、ただし要素には -1 を使用
}
}
#![allow(unused)]
fn main() {
TypeTree(vec![
Type { offset: 0, size: 8, kind: Integer }, // header
Type { offset: 8, size: 4000, kind: Pointer,
child: TypeTree(vec![Type {
offset: -1, size: 4, kind: Float // すべての配列要素
}])
}
])
}
重要な違い: 単一の値と配列
単一の値では、精度のためにオフセット 0 を使用します:
&f32はオフセット 0 にちょうど 1 つの f32 値を持つ- -1(「どこでも」)を使用するよりも正確
- 生成結果:
{[-1]:Pointer, [-1,0]:Float@float}
配列では、効率のためにオフセット -1 を使用します:
&[f32; 100]は同じパターンが 100 回繰り返される- -1 を使用することで、100 個の個別のオフセットの列挙を避けられる
- 生成結果:
{[-1]:Pointer, [-1,-1]:Float@float}