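Resolving two-way indirect jumps controlled by multiple conditions at the BN IL level

Sample

While analyzing libxyass.so, I ran into br jumps whose target is controlled by several csel instructions.

In the assembly view (left figure), you can see that the br is preceded by three csel conditional selects.

The csel instruction has the form csel dest, src1, src2, cond: if cond holds, dest = src1; otherwise dest = src2.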
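A small Python model makes the csel semantics concrete. It is only an illustration: the register roles and constants below are made up and are not taken from libxyass.so. It shows that however many csel instructions test the same flag, the br can only end up at two addresses:

```python
def csel(cond: bool, src1: int, src2: int) -> int:
    """Model of AArch64 `csel dest, src1, src2, cond`: dest = src1 if cond holds, else src2."""
    return src1 if cond else src2


def br_target(flag_ne: bool) -> int:
    # Hypothetical values: three csel instructions that all test the same "ne" flag.
    x9 = csel(flag_ne, 0x10, 0x20)           # selected offset
    x10 = csel(flag_ne, 0x1000, 0x2000)      # selected base
    x11 = csel(flag_ne, x10 + x9, x10 - x9)  # final br target
    return x11


# Two flag values give at most two distinct br targets, however many csel are chained.
targets = {br_target(flag) for flag in (True, False)}
print(sorted(hex(t) for t in targets))  # ['0x1010', '0x1fe0']
```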

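All three selects test the same condition, ne, which suggests that this obfuscation was produced by transforming a single if/else statement.

The MLIL view on the right makes this even clearer.

BN represents every condition as a cond expression, which is easier to analyze than the assembly condition codes: conditional instructions that differ at the assembly level may turn into the same condition expression once lifted to MLIL.

This indirect jump involves three conds: x23 != 0, x28 != 0 and x27 != 0. Looking further back at these variables shows that all three are initialized from the same value, var_4b1.

So the three conditions are in fact equivalent, each one testing var_4b1 != 0, and the jump has only two possible targets: TrueAddr when the condition holds and FalseAddr when it does not.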
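As a quick sanity check of that reasoning, the sketch below models the three conditions as copies of one value (var_4b1 is the only name taken from the sample; the test values are arbitrary). Of the 8 true/false combinations that three independent conditions could produce, only the two all-equal ones are reachable:

```python
from itertools import product


def cond_vector(var_4b1: int) -> tuple:
    # As in the sample: x23, x28 and x27 are all initialized from var_4b1.
    x23 = x28 = x27 = var_4b1
    return (x23 != 0, x28 != 0, x27 != 0)


all_combos = set(product((False, True), repeat=3))         # 8 combinations in principle
reachable = {cond_vector(v) for v in (0, 1, 0x7F, 0xFF)}  # only the all-equal ones occur
print(len(all_combos), sorted(reachable))
```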

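Working at the MLIL level

I chose to work at the MLIL level because patching assembly involves too many concerns: with complex obfuscation you can easily run out of room for the patch or break the program's control flow or data flow. Editing the IL directly lets you add as many instructions as you need to get the behavior you want, and MLIL is also easier to read and reason about.

The IL layer also exposes variables in SSA form, which is convenient for data-flow tracking.

SSA (static single assignment) means that every variable is assigned exactly once; each assignment creates a new version of the variable.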
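A toy illustration of SSA renaming for straight-line code. This is a simplification: real SSA construction also needs phi nodes at join points, which this sketch does not handle.

```python
def to_ssa(stmts):
    """Rename straight-line assignments into SSA form.

    stmts is a list of (dest, [sources]). Every assignment creates a new
    version dest#n, and every use refers to the latest version (#0 means
    the value live on entry). Branches and phi nodes are not handled.
    """
    version = {}
    out = []
    for dest, srcs in stmts:
        # Rename uses first, with the versions current *before* this assignment.
        uses = [f"{s}#{version.get(s, 0)}" for s in srcs]
        version[dest] = version.get(dest, 0) + 1
        out.append((f"{dest}#{version[dest]}", uses))
    return out


prog = [("x8", ["x0"]), ("x8", ["x8", "x1"]), ("x9", ["x8"])]
for dest, uses in to_ssa(prog):
    print(dest, "=", uses)
# x8#1 = ['x0#0']
# x8#2 = ['x8#1', 'x1#0']
# x9#1 = ['x8#2']
```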

Variables in BN

[Underlying variable]
Variable

[SSA variable]
SSAVariable

[IL expression nodes]
MediumLevelILVar / MediumLevelILVarSsa
MediumLevelILSetVar / MediumLevelILSetVarSsa
MediumLevelILVarPhi

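MediumLevelILVar reads a regular (non-SSA) variable, MediumLevelILVarSsa reads an SSA variable, MediumLevelILSetVar assigns to a regular variable, and MediumLevelILSetVarSsa assigns to an SSA variable; MediumLevelILVarPhi defines an SSA variable through a phi function.

Approach

Resolving this indirect jump comes down to two problems:

  • Compute the jump target for the case where the condition is true and for the case where it is false.
  • Rewrite the MLIL view so that the multiple conditions are merged into a single one, with the two targets placed on that condition's TrueAddr and FalseAddr.

Implementation

0x1

Map every SSA variable to the expression (or phi node) that defines it, and build indexes from addresses to instructions.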

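ssa_defs = {}           # SSA variable -> defining expression (or the phi node itself)
addr_insn_ssa = {}      # address -> instruction in SSA form
addr_insn_non_ssa = {}  # address -> instruction in non-SSA form

for bb in mlil_ssa.basic_blocks:
    for insn in bb:
        addr_insn_ssa[insn.address] = insn
        if insn.operation == Op.MLIL_SET_VAR_SSA:
            ssa_defs[insn.dest] = insn.src
        elif insn.operation == Op.MLIL_VAR_PHI:
            ssa_defs[insn.dest] = insn

for bb in mlil_non_ssa.basic_blocks:
    for insn in bb:
        addr_insn_non_ssa[insn.address] = insn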

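ssa_defs records each SSA variable's defining expression; addr_insn_ssa and addr_insn_non_ssa map instruction addresses to instruction objects in the SSA and non-SSA forms, respectively.

Walk the MLIL SSA form: for a plain assignment such as x8#3 = expr, record ssa_defs[x8#3] = expr; for a phi function such as x8#4 = phi(x8#2, x8#3), record ssa_defs[x8#4] = phi(x8#2, x8#3).

A phi function is a special construct in SSA form that merges the different versions of a variable arriving along different control-flow paths: at a point where control flow joins, the variable's value may come from any of several predecessors.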
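The ssa_defs loop can be mimicked outside Binary Ninja with two stand-in classes. SetVarSSA and VarPhi are hypothetical mock types for this sketch, not BN's classes. The point is only that a plain assignment stores its source expression, while a phi stores the whole node:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SetVarSSA:  # mock of MLIL_SET_VAR_SSA: dest = src
    dest: str
    src: object


@dataclass(frozen=True)
class VarPhi:  # mock of MLIL_VAR_PHI: dest = phi(src...)
    dest: str
    src: tuple


def build_ssa_defs(insns):
    defs = {}
    for insn in insns:
        if isinstance(insn, SetVarSSA):
            defs[insn.dest] = insn.src  # plain assignment: keep the source expression
        elif isinstance(insn, VarPhi):
            defs[insn.dest] = insn      # phi: keep the node so every source stays visible
    return defs


insns = [
    SetVarSSA("x8#2", ("const", 0x1000)),
    SetVarSSA("x8#3", ("const", 0x2000)),
    VarPhi("x8#4", ("x8#2", "x8#3")),
]
defs = build_ssa_defs(insns)
print(defs["x8#4"].src)  # ('x8#2', 'x8#3')
```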

0x2

Normalize MLIL expressions into a key so that later on, different condition expressions can be compared for equivalence.

MAX_DEPTH = 16  # recursion limit when following SSA definitions

UNARY_KEY_OPS = {Op.MLIL_ZX: "zx", Op.MLIL_SX: "sx", Op.MLIL_LOW_PART: "low_part"}
BINARY_KEY_OPS = {
    Op.MLIL_ADD,
    Op.MLIL_SUB,
    Op.MLIL_AND,
    Op.MLIL_OR,
    Op.MLIL_XOR,
    Op.MLIL_LSL,
    Op.MLIL_LSR,
    Op.MLIL_ASR,
    Op.MLIL_MUL,
}
CMP_KEY_OPS = {
    Op.MLIL_CMP_E,
    Op.MLIL_CMP_NE,
    Op.MLIL_CMP_SLT,
    Op.MLIL_CMP_ULT,
    Op.MLIL_CMP_SLE,
    Op.MLIL_CMP_ULE,
    Op.MLIL_CMP_SGE,
    Op.MLIL_CMP_UGE,
    Op.MLIL_CMP_SGT,
    Op.MLIL_CMP_UGT,
}


def _safe_const(expr):
    return getattr(expr, "constant", None)


def var_key(v, memo=None, visiting=None, depth=0):
    memo = {} if memo is None else memo
    visiting = set() if visiting is None else visiting
    k0 = ("varssa", str(v))
    if k0 in memo:
        return memo[k0]
    if depth > MAX_DEPTH:
        return ("varssa-depth", str(v))
    if str(v) in visiting:  # definition cycle (e.g. through a loop phi)
        return ("varssa-rec", str(v))
    node = ssa_defs.get(v)
    if node is None:  # no definition inside this function
        return ("varssa-undef", str(v))
    visiting.add(str(v))
    try:
        if getattr(node, "operation", None) == Op.MLIL_VAR_PHI:
            src_keys = tuple(
                sorted(str(var_key(s, memo, visiting, depth + 1)) for s in node.src)
            )
            out = ("phi", src_keys)
        else:
            out = expr_key(node, memo, visiting, depth + 1)
        memo[k0] = out
        return out
    finally:
        visiting.discard(str(v))


def expr_key(expr, memo=None, visiting=None, depth=0):
    if expr is None:
        return None
    memo = {} if memo is None else memo
    visiting = set() if visiting is None else visiting
    if depth > MAX_DEPTH:
        return ("expr-depth", str(expr))
    op = expr.operation
    if op in CMP_KEY_OPS:
        return (
            "cmp",
            op,
            expr_key(expr.left, memo, visiting, depth + 1),
            expr_key(expr.right, memo, visiting, depth + 1),
            expr.size,
        )
    if op == Op.MLIL_VAR_SSA:
        # follow the definition instead of using the variable name
        return var_key(expr.src, memo, visiting, depth + 1)
    if op == Op.MLIL_VAR:
        return ("var", str(expr.src), expr.size)
    if op in (Op.MLIL_CONST, Op.MLIL_CONST_PTR, Op.MLIL_EXTERN_PTR):
        return ("const", _safe_const(expr), expr.size)
    if op in UNARY_KEY_OPS:
        return (
            UNARY_KEY_OPS[op],
            expr_key(expr.src, memo, visiting, depth + 1),
            expr.size,
        )
    if op in BINARY_KEY_OPS:
        return (
            "binop",
            op,
            expr_key(expr.left, memo, visiting, depth + 1),
            expr_key(expr.right, memo, visiting, depth + 1),
            expr.size,
        )
    if op in (Op.MLIL_LOAD, Op.MLIL_LOAD_SSA):
        return ("load", expr_key(expr.src, memo, visiting, depth + 1), expr.size)
    return (op, str(expr), getattr(expr, "size", None))

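MLIL can be thought of as an AST-like tree: each node is an expr, its operation is the node type, and its operands are the child nodes.

The expr passed in is a MediumLevelILInstruction or one of its concrete subclasses, such as MediumLevelILCmpNe, MediumLevelILVarSsa, MediumLevelILConst or MediumLevelILAdd.

expr_key then matches expr by type and breaks it down into tuples with a uniform structure. One extra step matters here: on MLIL_VAR_SSA it does not put the variable name into the key, but keeps following that variable's definition until it reaches a non-SSA variable or an expression it cannot trace any further. Conditions that use different variable names but share the same underlying expression structure therefore land in the same group. A depth limit and a visiting set guard against infinite recursion.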
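The same normalization idea can be shown on a toy expression format that uses tuples instead of MLIL objects. The format and the norm function are illustrative and not BN API. Three conditions on differently named registers that are all copies of var_4b1 collapse into one key, and a definition cycle terminates instead of recursing forever:

```python
def norm(expr, defs, visiting=frozenset(), depth=0, max_depth=16):
    """Normalize a toy expression into a comparable key.

    expr is ("var", name) | ("const", value) | (op, left, right).
    Variables are replaced by the key of their definition, so differently
    named variables with the same definition produce the same key.
    """
    if depth > max_depth:
        return ("depth", repr(expr))
    kind = expr[0]
    if kind == "const":
        return expr
    if kind == "var":
        name = expr[1]
        if name in visiting:  # definition cycle: stop here
            return ("rec", name)
        if name not in defs:  # nothing more to follow
            return ("undef", name)
        return norm(defs[name], defs, visiting | {name}, depth + 1, max_depth)
    op, left, right = expr
    return (op,
            norm(left, defs, visiting, depth + 1, max_depth),
            norm(right, defs, visiting, depth + 1, max_depth))


defs = {
    "x23#1": ("var", "var_4b1#0"),
    "x28#1": ("var", "var_4b1#0"),
    "x27#1": ("var", "var_4b1#0"),
}
conds = [("ne", ("var", r), ("const", 0)) for r in ("x23#1", "x28#1", "x27#1")]
keys = {norm(c, defs) for c in conds}
print(keys)  # {('ne', ('undef', 'var_4b1#0'), ('const', 0))}
```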

0x3

Locating the indirect jump

UNRESOLVED_TAG_TYPE = "Unresolved Indirect Control Flow"
UNRESOLVED_TAG_DATA = {
    "Jump to computed location set with invalid target",
    "Jump to unhandled/unknown possible value set",
}

def find_unresolved_jump_sites(func, mlil_func):
    hits = []
    for bb in mlil_func.basic_blocks:
        for insn in bb:
            try:
                tags = func.get_tags_at(insn.address, auto=None)
            except Exception:
                tags = []
            for tag in tags:
                tname = getattr(getattr(tag, "type", None), "name", "")
                tdata = str(getattr(tag, "data", ""))
                if (
                    tname == UNRESOLVED_TAG_TYPE
                    and tdata in UNRESOLVED_TAG_DATA
                    and insn.operation == Op.MLIL_JUMP
                ):
                    hits.append((insn.address, insn, tname, tdata))
                    break
    return hits


jump_hits = find_unresolved_jump_sites(f, mlil_ssa)
if not jump_hits:
    raise Exception("no MLIL_JUMP found")
JMP_ADDR, jump_insn_ssa, _, _ = jump_hits[0]
JMP_ADDR, jump_insn_ssa, _, _ = jump_hits[0]

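BN attaches a tag to indirect jumps whose targets it could not resolve.

So the indirect jump to handle can be found by walking the MLIL and looking for that tag: whenever an instruction carries one of the target tags and is an MLIL_JUMP, its address is recorded. If the function contains several matching jumps, only the first unresolved jump found is handled for now.

0x4

Choosing main_cond from all if conditions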

cond_groups = {}
cond_first_addr = {}
for bb in mlil_ssa.basic_blocks:
    for insn in bb:
        if insn.operation == Op.MLIL_IF:
            k = expr_key(insn.condition)
            cond_groups.setdefault(k, []).append(insn.condition)
            cond_first_addr.setdefault(k, insn.address)
main_cond_key = max(cond_groups, key=lambda k: len(cond_groups[k]))
main_cond_addr = cond_first_addr[main_cond_key]

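The keys produced by expr_key are used to group the conditions, so structurally identical conditions fall into the same group, and the address where each group first appears is recorded. If there are several groups, the one that occurs most often becomes main_cond, and its first address is used as main_cond's address.

main_cond is the condition used later, when the multiple conditions are merged into one.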
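A condensed sketch of the grouping step, assuming the keys have already been computed (the addresses and key names below are invented). Because Python dicts keep insertion order and max() returns the first maximal element, ties go to the group seen first:

```python
def choose_main_cond(if_conds):
    """if_conds: list of (address, key) for every if, in program order.

    Returns (key, first_address) of the most frequent key.
    """
    groups = {}
    first_addr = {}
    for addr, key in if_conds:
        groups.setdefault(key, []).append(addr)
        first_addr.setdefault(key, addr)
    if not groups:
        raise ValueError("no if conditions")
    main_key = max(groups, key=lambda k: len(groups[k]))
    return main_key, first_addr[main_key]


ifs = [(0x1000, "k_other"), (0x1010, "k_main"), (0x1020, "k_main"), (0x1030, "k_main")]
print(choose_main_cond(ifs))  # ('k_main', 4112)
```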

0x5

Computing the values of SSA variables on the true and false paths of main_cond

def make_tf(t, f):
    return None if t is None or f is None else t if t == f else {"t": t, "f": f}

def eval_var_tf(v, memo):
    k = ("var", str(v))
    if k in memo:
        return memo[k]
    node = ssa_defs.get(v)
    if node is None:
        return None
    # placeholder: a cycle through a loop phi yields None instead of infinite recursion
    memo[k] = None
    if getattr(node, "operation", None) == Op.MLIL_VAR_PHI and len(node.src) == 2:
        # assumption: src[0] comes from the true branch, src[1] from the false branch
        tv = eval_var_tf(node.src[0], memo)
        fv = eval_var_tf(node.src[1], memo)
        out = (
            None
            if tv is None or fv is None
            else make_tf(tf_pick_branch(tv, True), tf_pick_branch(fv, False))
        )
    else:
        out = eval_expr_tf(node, memo)
    memo[k] = out
    return out

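If a variable is defined by a phi function with two sources, the first source is taken by default as the value on the true path and the second as the value on the false path. This assumption holds in most cases, which can also be seen from the CFG.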
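The phi rule and the t/f representation can be sketched in isolation. make_tf mirrors the article's helper; pick and eval_phi are illustrative names, and the true/false operand order is the same assumption the text makes:

```python
def make_tf(t, f):
    """Collapse equal values; None if either path is unknown (same rule as the article)."""
    if t is None or f is None:
        return None
    return t if t == f else {"t": t, "f": f}


def pick(v, on_true):
    """Value of v on one path; a plain value is the same on both paths."""
    if isinstance(v, dict):
        return v["t"] if on_true else v["f"]
    return v


def eval_phi(src_true, src_false):
    # Assumption taken from the text: phi source 0 arrives from the true-branch
    # predecessor, source 1 from the false-branch predecessor.
    return make_tf(pick(src_true, True), pick(src_false, False))


print(eval_phi(0x1000, 0x2000))       # {'t': 4096, 'f': 8192}
print(eval_phi(0x10, 0x10))           # 16
print(eval_phi({"t": 1, "f": 2}, 3))  # {'t': 1, 'f': 3}
```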

main_cond 的真假条件绑定到 true/false 分支路径

def eval_condition_expr_tf(expr, memo):
    if expr_key(expr) == main_cond_key:
        return {"t": True, "f": False}
    if expr.operation in CMP_MAP:
        a = eval_expr_tf(expr.left, memo)
        b = eval_expr_tf(expr.right, memo)
        return tf_cmp(a, b, CMP_MAP[expr.operation])
    return None

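If the condition expression being evaluated normalizes to exactly main_cond_key, its value is simply defined as True on the true path and False on the false path.

0x6

Variables that are not defined by a phi are evaluated recursively by the evaluator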

import operator

def as_tf_pair(v):
    return (v["t"], v["f"]) if is_tf(v) else (v, v)

def tf_binop(a, b, fn, size):
    at, af = as_tf_pair(a)
    bt, bf = as_tf_pair(b)
    if None in (at, af, bt, bf):
        return None
    return make_tf(fn(at, bt) & mask(size), fn(af, bf) & mask(size))

def tf_cmp(a, b, pred):
    at, af = as_tf_pair(a)
    bt, bf = as_tf_pair(b)
    if None in (at, af, bt, bf):
        return None
    return make_tf(bool(pred(at, bt)), bool(pred(af, bf)))

def tf_bool_to_int(v):
    if v is None:
        return None
    if isinstance(v, bool):
        return 1 if v else 0
    if is_tf(v):
        return make_tf(1 if v["t"] else 0, 1 if v["f"] else 0)
    return None

CMP_MAP = {
    Op.MLIL_CMP_E: operator.eq,
    Op.MLIL_CMP_NE: operator.ne,
    Op.MLIL_CMP_ULT: operator.lt,
    Op.MLIL_CMP_ULE: operator.le,
    Op.MLIL_CMP_UGT: operator.gt,
    Op.MLIL_CMP_UGE: operator.ge,
}
BINOP_MAP = {
    Op.MLIL_ADD: operator.add,
    Op.MLIL_SUB: operator.sub,
    Op.MLIL_AND: operator.and_,
    Op.MLIL_OR: operator.or_,
    Op.MLIL_XOR: operator.xor,
    Op.MLIL_LSL: operator.lshift,
    Op.MLIL_LSR: operator.rshift,
    Op.MLIL_MUL: operator.mul,
}

def eval_expr_tf(expr, memo):
    if expr is None:
        return None
    op = expr.operation
    if op in (Op.MLIL_CONST, Op.MLIL_CONST_PTR, Op.MLIL_EXTERN_PTR):
        return expr.constant & mask(expr.size)
    if op == Op.MLIL_VAR_SSA:
        return eval_var_tf(expr.src, memo)
    if op in BINOP_MAP:
        return tf_binop(
            eval_expr_tf(expr.left, memo),
            eval_expr_tf(expr.right, memo),
            BINOP_MAP[op],
            expr.size,
        )
    if op == Op.MLIL_ASR:

        def _asr(x, y):
            # arithmetic shift: reinterpret x as signed before shifting
            bits = expr.size * 8
            sb = 1 << (bits - 1)
            x &= (1 << bits) - 1
            if x & sb:
                x -= 1 << bits
            return x >> y

        return tf_binop(
            eval_expr_tf(expr.left, memo),
            eval_expr_tf(expr.right, memo),
            _asr,
            expr.size,
        )
    if op in CMP_MAP:
        return tf_bool_to_int(eval_condition_expr_tf(expr, memo))
    if op in (Op.MLIL_LOAD, Op.MLIL_LOAD_SSA):
        ptrs = eval_expr_tf(expr.src, memo)
        pt, pf = as_tf_pair(ptrs)
        if pt is None or pf is None:
            return None
        try:
            return make_tf(read_u(pt, expr.size), read_u(pf, expr.size))
        except Exception:
            return None
    if op == Op.MLIL_ZX:
        return tf_map(eval_expr_tf(expr.src, memo), lambda v: v & mask(expr.size))
    if op == Op.MLIL_SX:
        return tf_map(
            eval_expr_tf(expr.src, memo), lambda v: sext(v, expr.src.size, expr.size)
        )
    if op == Op.MLIL_LOW_PART:
        return tf_map(eval_expr_tf(expr.src, memo), lambda v: v & mask(expr.size))
    return None

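This is the core evaluator. The two mapping tables translate BN's comparison and arithmetic operations into Python functions, and every value is tracked as a t/f pair, so while the expression is expanded, the results for the true and false paths are propagated all the way through. As written, CMP_MAP only covers equality and unsigned comparisons, so a signed comparison that is not main_cond itself evaluates to None.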
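To see the whole evaluation pipeline in one place, here is a self-contained toy version that works on tuples instead of MLIL objects. Everything in it (the expression format, env, the jump-table memory and its contents) is hypothetical. It only reproduces the idea of forcing main_cond to true/false and carrying a value pair through arithmetic and a table load:

```python
import operator

MASK64 = (1 << 64) - 1
OPS = {"add": operator.add, "sub": operator.sub, "mul": operator.mul,
       "eq": operator.eq, "ne": operator.ne}


def evaluate(expr, env, memory, main_cond):
    """Evaluate a toy expression as (value_on_true_path, value_on_false_path).

    expr is ("const", v) | ("var", name) | ("load", addr_expr) | (op, a, b) with op in OPS.
    env maps a variable to its defining expression or to ("phi", true_src, false_src).
    memory stands in for reads from the binary. A comparison structurally equal
    to main_cond is forced to (1, 0). Returns None when something is unknown.
    There is no cycle guard, so this is only for acyclic toy inputs.
    """
    kind = expr[0]
    if kind == "const":
        return (expr[1], expr[1])
    if kind == "var":
        d = env.get(expr[1])
        if d is None:
            return None
        if d[0] == "phi":
            # Same assumption as the article: first source = true path, second = false path.
            t = evaluate(d[1], env, memory, main_cond)
            f = evaluate(d[2], env, memory, main_cond)
            return None if t is None or f is None else (t[0], f[1])
        return evaluate(d, env, memory, main_cond)
    if expr == main_cond:
        return (1, 0)
    if kind == "load":
        p = evaluate(expr[1], env, memory, main_cond)
        if p is None or p[0] not in memory or p[1] not in memory:
            return None
        return (memory[p[0]], memory[p[1]])
    fn = OPS.get(kind)
    if fn is None:
        return None
    a = evaluate(expr[1], env, memory, main_cond)
    b = evaluate(expr[2], env, memory, main_cond)
    if a is None or b is None:
        return None
    return (int(fn(a[0], b[0])) & MASK64, int(fn(a[1], b[1])) & MASK64)


main_cond = ("ne", ("var", "var_4b1"), ("const", 0))
env = {
    "x23": ("ne", ("var", "var_4b1"), ("const", 0)),  # same structure as main_cond
    "idx": ("mul", ("var", "x23"), ("const", 8)),
    "slot": ("add", ("const", 0x5000), ("var", "idx")),
    "target": ("add", ("load", ("var", "slot")), ("const", 0x100000)),
}
memory = {0x5000: 0x1FE0, 0x5008: 0x1010}  # hypothetical jump-table entries
t, f = evaluate(("var", "target"), env, memory, main_cond)
print(hex(t), hex(f))  # 0x101010 0x101fe0
```

Here var_4b1 itself is never resolved; the evaluation succeeds only because x23's definition matches main_cond structurally, which is the role main_cond_key plays in eval_condition_expr_tf.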

0x7

Computing the two targets of the jump

memo = {}
result_raw = eval_expr_tf(jump_insn_ssa.dest, memo)
if isinstance(result_raw, int):
    result = {"t": result_raw, "f": result_raw}
elif is_tf(result_raw):
    result = {"t": result_raw.get("t"), "f": result_raw.get("f")}
else:
    raise Exception(f"unexpected result type: {result_raw!r}")

jump_insn_ssa.dest is the destination expression of that MLIL_JUMP.

0x8

Adding the two branch targets to the indirect jump

from dataclasses import dataclass


@dataclass
class resolved_data:
    cond_key: object
    cond_addr: int
    trueAddr: int
    falseAddr: int

resolved_map = {
    JMP_ADDR: resolved_data(main_cond_key, main_cond_addr, result["t"], result["f"])
}

def apply_user_indirect_branches(func, resolved_map):
    for jmp_addr, item in resolved_map.items():
        try:
            func.set_user_indirect_branches(
                jmp_addr,
                [(func.arch, item.trueAddr), (func.arch, item.falseAddr)],
                func.arch,
            )
        except Exception as e:
            print(f"[!] failed {hex(jmp_addr)}: {e}")

apply_user_indirect_branches(f, resolved_map)
bv.update_analysis_and_wait()

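resolved_map stores the information needed to rewrite the jump.

set_user_indirect_branches() tells BN that, from the user's point of view, this indirect jump has only the listed targets.

Passing the two branches trueAddr and falseAddr lets BN treat this jump as resolvable control flow with exactly two targets in subsequent analysis.

bv.update_analysis_and_wait() makes BN reanalyze.

0x9

Collecting a map from condition addresses to condition expressions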

def build_addr_to_non_ssa_if_condition(old_mlil):
    return {
        insn.address: insn.condition
        for bb in old_mlil.basic_blocks
        for insn in bb
        if insn.operation == Op.MLIL_IF
    }

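main_cond_addr was chosen in mlil_ssa, but the new if_expr is built at the MLIL level, i.e. in non-SSA form, so we need a map from the addresses of MLIL-level conditional statements to their condition expressions.

0xA

Rewriting the BN MLIL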

def convert_jump_to_if(func, resolved_map):
    old_mlil = func.mlil
    cond_by_addr = build_addr_to_non_ssa_if_condition(old_mlil)
    new_func = MediumLevelILFunction(func.arch, low_level_il=func.llil)
    new_func.prepare_to_copy_function(old_mlil)
    for old_block in old_mlil.basic_blocks:
        new_func.prepare_to_copy_block(old_block)
        for instr_idx in range(old_block.start, old_block.end):
            instr = old_mlil[instr_idx]
            item = resolved_map.get(instr.address)
            if item is None:
                new_func.append(instr.copy_to(new_func), instr.source_location)
                continue
            try:
                cond_expr = cond_by_addr.get(item.cond_addr)
                if cond_expr is None:
                    raise RuntimeError(f"cannot find IF cond @ {hex(item.cond_addr)}")
                indirect_branches = func.get_indirect_branches_at(instr.address)
                if len(indirect_branches) != 2:
                    raise RuntimeError(f"indirect branches !=2 @ {hex(instr.address)}")
                label_t, label_f = MediumLevelILLabel(), MediumLevelILLabel()
                for branch in indirect_branches:
                    if branch.dest_addr == item.trueAddr:
                        label_t.operand = instr.targets[branch.dest_addr]
                    if branch.dest_addr == item.falseAddr:
                        label_f.operand = instr.targets[branch.dest_addr]
                if (
                    getattr(label_t, "operand", None) is None
                    or getattr(label_f, "operand", None) is None
                ):
                    raise RuntimeError(f"cannot bind labels @ {hex(instr.address)}")
                new_func.append(
                    new_func.if_expr(
                        cond_expr.copy_to(new_func),
                        label_t,
                        label_f,
                        instr.source_location,
                    ),
                    instr.source_location,
                )
            except Exception as e:
                print(f"[!] rewrite failed @ {hex(instr.address)}: {e}")
                new_func.append(instr.copy_to(new_func), instr.source_location)
    new_func.finalize()
    new_func.generate_ssa_form()
    return new_func


new_mlil = convert_jump_to_if(f, resolved_map)
print("[+] MLIL rewritten")

point 1

First prepare the old MLIL and the condition table, then create a new MLIL function and copy all the old instructions into it, block by block.

point 2

cond_by_addr = build_addr_to_non_ssa_if_condition(old_mlil)
item = resolved_map.get(instr.address)
cond_expr = cond_by_addr.get(item.cond_addr)

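This prepares the condition expression: resolved_map stores main_cond's address, and cond_by_addr maps MLIL condition addresses to condition expressions, so together they give main_cond's condition expression.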

point 3

indirect_branches = func.get_indirect_branches_at(instr.address)
if len(indirect_branches) != 2:
    raise RuntimeError(f"indirect branches !=2 @ {hex(instr.address)}")

The two branches were created earlier, so under normal circumstances both targets are returned here.

point 4

label_t, label_f = MediumLevelILLabel(), MediumLevelILLabel()
for branch in indirect_branches:
    if branch.dest_addr == item.trueAddr:
        label_t.operand = instr.targets[branch.dest_addr]
    if branch.dest_addr == item.falseAddr:
        label_f.operand = instr.targets[branch.dest_addr]

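Turning the result into an if statement is done through labels, so the two newly created labels are bound directly to the branch targets here (instr.targets maps each target address to the corresponding instruction index).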

point 5

new_func.append(
    new_func.if_expr(
        cond_expr.copy_to(new_func),
        label_t,
        label_f,
        instr.source_location,
    ),
    instr.source_location,
)

This constructs a new if instruction that replaces the original jump(xxx) with if (cond_expr) true_label else false_label; if the rewrite fails, the original jump is kept.
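Without the BN plumbing, the rewrite logic reduces to a copy loop with one substitution and a fallback. In this sketch instructions are plain strings and labels are names; Insn and rewrite_jumps are hypothetical helpers, not BN API:

```python
from dataclasses import dataclass


@dataclass
class Insn:
    addr: int
    text: str  # toy stand-in for an MLIL instruction


def rewrite_jumps(insns, resolved, cond_by_addr, label_of):
    """Copy insns, replacing each resolved jump with an if.

    resolved: jump address -> (cond_addr, true_addr, false_addr)
    cond_by_addr: address of an existing if -> its condition text
    label_of: target address -> label name in the new function
    If anything is missing, the original jump is kept (as convert_jump_to_if does).
    """
    out = []
    for insn in insns:
        item = resolved.get(insn.addr)
        if item is None:
            out.append(insn)  # untouched instruction: copy as-is
            continue
        cond_addr, t, f = item
        cond = cond_by_addr.get(cond_addr)
        if cond is None or t not in label_of or f not in label_of:
            out.append(insn)  # fallback: keep the original jump
            continue
        out.append(Insn(insn.addr, f"if ({cond}) then {label_of[t]} else {label_of[f]}"))
    return out


insns = [Insn(0x100, "if (x23 != 0) then 1 else 2"), Insn(0x200, "jump(x11)")]
resolved = {0x200: (0x100, 0x1010, 0x1FE0)}
cond_by_addr = {0x100: "x23 != 0"}
labels = {0x1010: "label_true", 0x1FE0: "label_false"}
for insn in rewrite_jumps(insns, resolved, cond_by_addr, labels):
    print(hex(insn.addr), insn.text)
```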

Full code

# Run from the Binary Ninja Python console, which provides bv and current_function.
from binaryninja import *
from binaryninja.mediumlevelil import (
    MediumLevelILFunction,
    MediumLevelILLabel,
    MediumLevelILOperation as Op,
)
from dataclasses import dataclass
import operator

MAX_DEPTH = 16
UNRESOLVED_TAG_TYPE = "Unresolved Indirect Control Flow"
UNRESOLVED_TAG_DATA = {
    "Jump to computed location set with invalid target",
    "Jump to unhandled/unknown possible value set",
}


@dataclass
class resolved_data:
    cond_key: object
    cond_addr: int
    trueAddr: int
    falseAddr: int


f = current_function
if f is None:
    raise Exception("no current_function")
mlil_ssa = f.mlil.ssa_form
mlil_non_ssa = f.mlil
if mlil_ssa is None or mlil_non_ssa is None:
    raise Exception("no mlil ssa/non-ssa")

ssa_defs = {}
for bb in mlil_ssa.basic_blocks:
    for insn in bb:
        if insn.operation == Op.MLIL_SET_VAR_SSA:
            ssa_defs[insn.dest] = insn.src
        elif insn.operation == Op.MLIL_VAR_PHI:
            ssa_defs[insn.dest] = insn


def mask(size):
    return (1 << (size * 8)) - 1


def sext(v, src_size, dst_size):
    bits = src_size * 8
    s = 1 << (bits - 1)
    v &= (1 << bits) - 1
    return (v - (1 << bits) if v & s else v) & mask(dst_size)


def read_u(addr, size):
    data = bv.read(addr, size)
    if data is None or len(data) != size:
        raise Exception(f"read failed @ {hex(addr)} size={size}")
    return int.from_bytes(data, "little")


def fmt_val(v):
    if v is None:
        return "None"
    if isinstance(v, bool):
        return str(v)
    if isinstance(v, int):
        return hex(v)
    if isinstance(v, dict):
        return "{t: %s,f: %s}" % (fmt_val(v.get("t")), fmt_val(v.get("f")))
    return str(v)


def is_tf(v):
    return isinstance(v, dict) and ("t" in v or "f" in v)


def make_tf(t, f):
    return None if t is None or f is None else t if t == f else {"t": t, "f": f}


def as_tf_pair(v):
    return (v["t"], v["f"]) if is_tf(v) else (v, v)


def tf_pick_branch(v, b):
    return v["t"] if is_tf(v) and b else v["f"] if is_tf(v) else v


def tf_map(v, fn):
    vt, vf = as_tf_pair(v)
    return None if vt is None or vf is None else make_tf(fn(vt), fn(vf))


def tf_binop(a, b, fn, size):
    at, af = as_tf_pair(a)
    bt, bf = as_tf_pair(b)
    if None in (at, af, bt, bf):
        return None
    return make_tf(fn(at, bt) & mask(size), fn(af, bf) & mask(size))


def tf_cmp(a, b, pred):
    at, af = as_tf_pair(a)
    bt, bf = as_tf_pair(b)
    if None in (at, af, bt, bf):
        return None
    return make_tf(bool(pred(at, bt)), bool(pred(af, bf)))


def tf_bool_to_int(v):
    if v is None:
        return None
    if isinstance(v, bool):
        return 1 if v else 0
    if is_tf(v):
        return make_tf(1 if v["t"] else 0, 1 if v["f"] else 0)
    return None


UNARY_KEY_OPS = {Op.MLIL_ZX: "zx", Op.MLIL_SX: "sx", Op.MLIL_LOW_PART: "low_part"}
BINARY_KEY_OPS = {
    Op.MLIL_ADD,
    Op.MLIL_SUB,
    Op.MLIL_AND,
    Op.MLIL_OR,
    Op.MLIL_XOR,
    Op.MLIL_LSL,
    Op.MLIL_LSR,
    Op.MLIL_ASR,
    Op.MLIL_MUL,
}
CMP_KEY_OPS = {
    Op.MLIL_CMP_E,
    Op.MLIL_CMP_NE,
    Op.MLIL_CMP_SLT,
    Op.MLIL_CMP_ULT,
    Op.MLIL_CMP_SLE,
    Op.MLIL_CMP_ULE,
    Op.MLIL_CMP_SGE,
    Op.MLIL_CMP_UGE,
    Op.MLIL_CMP_SGT,
    Op.MLIL_CMP_UGT,
}


def _safe_const(expr):
    return getattr(expr, "constant", None)


def var_key(v, memo=None, visiting=None, depth=0):
    memo = {} if memo is None else memo
    visiting = set() if visiting is None else visiting
    k0 = ("varssa", str(v))
    if k0 in memo:
        return memo[k0]
    if depth > MAX_DEPTH:
        return ("varssa-depth", str(v))
    if str(v) in visiting:
        return ("varssa-rec", str(v))
    node = ssa_defs.get(v)
    if node is None:
        return ("varssa-undef", str(v))
    visiting.add(str(v))
    try:
        if getattr(node, "operation", None) == Op.MLIL_VAR_PHI:
            src_keys = tuple(
                sorted(str(var_key(s, memo, visiting, depth + 1)) for s in node.src)
            )
            out = ("phi", src_keys)
        else:
            out = expr_key(node, memo, visiting, depth + 1)
        memo[k0] = out
        return out
    finally:
        visiting.discard(str(v))


def expr_key(expr, memo=None, visiting=None, depth=0):
    if expr is None:
        return None
    memo = {} if memo is None else memo
    visiting = set() if visiting is None else visiting
    if depth > MAX_DEPTH:
        return ("expr-depth", str(expr))
    op = expr.operation
    if op in CMP_KEY_OPS:
        return (
            "cmp",
            op,
            expr_key(expr.left, memo, visiting, depth + 1),
            expr_key(expr.right, memo, visiting, depth + 1),
            expr.size,
        )
    if op == Op.MLIL_VAR_SSA:
        return var_key(expr.src, memo, visiting, depth + 1)
    if op == Op.MLIL_VAR:
        return ("var", str(expr.src), expr.size)
    if op in (Op.MLIL_CONST, Op.MLIL_CONST_PTR, Op.MLIL_EXTERN_PTR):
        return ("const", _safe_const(expr), expr.size)
    if op in UNARY_KEY_OPS:
        return (
            UNARY_KEY_OPS[op],
            expr_key(expr.src, memo, visiting, depth + 1),
            expr.size,
        )
    if op in BINARY_KEY_OPS:
        return (
            "binop",
            op,
            expr_key(expr.left, memo, visiting, depth + 1),
            expr_key(expr.right, memo, visiting, depth + 1),
            expr.size,
        )
    if op in (Op.MLIL_LOAD, Op.MLIL_LOAD_SSA):
        return ("load", expr_key(expr.src, memo, visiting, depth + 1), expr.size)
    return (op, str(expr), getattr(expr, "size", None))


def find_unresolved_jump_sites(func, mlil_func):
    hits = []
    for bb in mlil_func.basic_blocks:
        for insn in bb:
            try:
                tags = func.get_tags_at(insn.address, auto=None)
            except Exception:
                tags = []
            for tag in tags:
                tname = getattr(getattr(tag, "type", None), "name", "")
                tdata = str(getattr(tag, "data", ""))
                if (
                    tname == UNRESOLVED_TAG_TYPE
                    and tdata in UNRESOLVED_TAG_DATA
                    and insn.operation == Op.MLIL_JUMP
                ):
                    hits.append((insn.address, insn, tname, tdata))
                    break
    return hits


jump_hits = find_unresolved_jump_sites(f, mlil_ssa)
if not jump_hits:
    raise Exception("no MLIL_JUMP found")
JMP_ADDR, jump_insn_ssa, _, _ = jump_hits[0]

cond_groups = {}
cond_first_addr = {}
for bb in mlil_ssa.basic_blocks:
    for insn in bb:
        if insn.operation == Op.MLIL_IF:
            k = expr_key(insn.condition)
            cond_groups.setdefault(k, []).append(insn.condition)
            cond_first_addr.setdefault(k, insn.address)
main_cond_key = max(cond_groups, key=lambda k: len(cond_groups[k]))
main_cond_addr = cond_first_addr[main_cond_key]


CMP_MAP = {
    Op.MLIL_CMP_E: operator.eq,
    Op.MLIL_CMP_NE: operator.ne,
    Op.MLIL_CMP_ULT: operator.lt,
    Op.MLIL_CMP_ULE: operator.le,
    Op.MLIL_CMP_UGT: operator.gt,
    Op.MLIL_CMP_UGE: operator.ge,
}
BINOP_MAP = {
    Op.MLIL_ADD: operator.add,
    Op.MLIL_SUB: operator.sub,
    Op.MLIL_AND: operator.and_,
    Op.MLIL_OR: operator.or_,
    Op.MLIL_XOR: operator.xor,
    Op.MLIL_LSL: operator.lshift,
    Op.MLIL_LSR: operator.rshift,
    Op.MLIL_MUL: operator.mul,
}


def eval_var_tf(v, memo):
    k = ("var", str(v))
    if k in memo:
        return memo[k]
    node = ssa_defs.get(v)
    if node is None:
        return None
    memo[k] = None  # placeholder that breaks cycles through loop phis
    if getattr(node, "operation", None) == Op.MLIL_VAR_PHI and len(node.src) == 2:
        # assumption: src[0] comes from the true branch, src[1] from the false branch
        tv = eval_var_tf(node.src[0], memo)
        fv = eval_var_tf(node.src[1], memo)
        out = (
            None
            if tv is None or fv is None
            else make_tf(tf_pick_branch(tv, True), tf_pick_branch(fv, False))
        )
    else:
        out = eval_expr_tf(node, memo)
    memo[k] = out
    return out


def eval_condition_expr_tf(expr, memo):
    if expr_key(expr) == main_cond_key:
        return {"t": True, "f": False}
    if expr.operation in CMP_MAP:
        a = eval_expr_tf(expr.left, memo)
        b = eval_expr_tf(expr.right, memo)
        return tf_cmp(a, b, CMP_MAP[expr.operation])
    return None


def eval_expr_tf(expr, memo):
    if expr is None:
        return None
    op = expr.operation
    if op in (Op.MLIL_CONST, Op.MLIL_CONST_PTR, Op.MLIL_EXTERN_PTR):
        return expr.constant & mask(expr.size)
    if op == Op.MLIL_VAR_SSA:
        return eval_var_tf(expr.src, memo)
    if op in BINOP_MAP:
        return tf_binop(
            eval_expr_tf(expr.left, memo),
            eval_expr_tf(expr.right, memo),
            BINOP_MAP[op],
            expr.size,
        )
    if op == Op.MLIL_ASR:

        def _asr(x, y):
            # arithmetic shift: reinterpret x as signed before shifting
            bits = expr.size * 8
            sb = 1 << (bits - 1)
            x &= (1 << bits) - 1
            if x & sb:
                x -= 1 << bits
            return x >> y

        return tf_binop(
            eval_expr_tf(expr.left, memo),
            eval_expr_tf(expr.right, memo),
            _asr,
            expr.size,
        )
    if op in CMP_MAP:
        return tf_bool_to_int(eval_condition_expr_tf(expr, memo))
    if op in (Op.MLIL_LOAD, Op.MLIL_LOAD_SSA):
        ptrs = eval_expr_tf(expr.src, memo)
        pt, pf = as_tf_pair(ptrs)
        if pt is None or pf is None:
            return None
        try:
            return make_tf(read_u(pt, expr.size), read_u(pf, expr.size))
        except Exception:
            return None
    if op == Op.MLIL_ZX:
        return tf_map(eval_expr_tf(expr.src, memo), lambda v: v & mask(expr.size))
    if op == Op.MLIL_SX:
        return tf_map(
            eval_expr_tf(expr.src, memo), lambda v: sext(v, expr.src.size, expr.size)
        )
    if op == Op.MLIL_LOW_PART:
        return tf_map(eval_expr_tf(expr.src, memo), lambda v: v & mask(expr.size))
    return None


memo = {}
result_raw = eval_expr_tf(jump_insn_ssa.dest, memo)
if isinstance(result_raw, int):
    result = {"t": result_raw, "f": result_raw}
elif is_tf(result_raw):
    result = {"t": result_raw.get("t"), "f": result_raw.get("f")}
else:
    raise Exception(f"unexpected result type: {result_raw!r}")

resolved_map = {
    JMP_ADDR: resolved_data(main_cond_key, main_cond_addr, result["t"], result["f"])
}


def apply_user_indirect_branches(func, resolved_map):
    for jmp_addr, item in resolved_map.items():
        try:
            func.set_user_indirect_branches(
                jmp_addr,
                [(func.arch, item.trueAddr), (func.arch, item.falseAddr)],
                func.arch,
            )
        except Exception as e:
            print(f"[!] failed {hex(jmp_addr)}: {e}")


apply_user_indirect_branches(f, resolved_map)
bv.update_analysis_and_wait()


def build_addr_to_non_ssa_if_condition(old_mlil):
    return {
        insn.address: insn.condition
        for bb in old_mlil.basic_blocks
        for insn in bb
        if insn.operation == Op.MLIL_IF
    }


def convert_jump_to_if(func, resolved_map):
    old_mlil = func.mlil
    cond_by_addr = build_addr_to_non_ssa_if_condition(old_mlil)
    new_func = MediumLevelILFunction(func.arch, low_level_il=func.llil)
    new_func.prepare_to_copy_function(old_mlil)
    for old_block in old_mlil.basic_blocks:
        new_func.prepare_to_copy_block(old_block)
        for instr_idx in range(old_block.start, old_block.end):
            instr = old_mlil[instr_idx]
            item = resolved_map.get(instr.address)
            if item is None:
                new_func.append(instr.copy_to(new_func), instr.source_location)
                continue
            try:
                cond_expr = cond_by_addr.get(item.cond_addr)
                if cond_expr is None:
                    raise RuntimeError(f"cannot find IF cond @ {hex(item.cond_addr)}")
                indirect_branches = func.get_indirect_branches_at(instr.address)
                if len(indirect_branches) != 2:
                    raise RuntimeError(f"indirect branches !=2 @ {hex(instr.address)}")
                label_t, label_f = MediumLevelILLabel(), MediumLevelILLabel()
                for branch in indirect_branches:
                    if branch.dest_addr == item.trueAddr:
                        label_t.operand = instr.targets[branch.dest_addr]
                    if branch.dest_addr == item.falseAddr:
                        label_f.operand = instr.targets[branch.dest_addr]
                if (
                    getattr(label_t, "operand", None) is None
                    or getattr(label_f, "operand", None) is None
                ):
                    raise RuntimeError(f"cannot bind labels @ {hex(instr.address)}")
                new_func.append(
                    new_func.if_expr(
                        cond_expr.copy_to(new_func),
                        label_t,
                        label_f,
                        instr.source_location,
                    ),
                    instr.source_location,
                )
            except Exception as e:
                print(f"[!] rewrite failed @ {hex(instr.address)}: {e}")
                new_func.append(instr.copy_to(new_func), instr.source_location)
    new_func.finalize()
    new_func.generate_ssa_form()
    return new_func


new_mlil = convert_jump_to_if(f, resolved_map)
print("[+] MLIL rewritten")


def dump_result(main_cond_key, main_cond_addr, result):
    print("[+] main condition group:", main_cond_key, "count=", len(cond_groups[main_cond_key]))
    print(f"[+] main_cond_key : {main_cond_key}")
    print(f"[+] main_cond_addr: {hex(main_cond_addr)}")
    cond_expr = None
    for bb in mlil_ssa.basic_blocks:
        for insn in bb:
            if insn.address == main_cond_addr:
                cond_expr = insn.condition
                break
        if cond_expr:
            break

    print(f"[+] cond_expr : {cond_expr}")

    print(f"[+] TRUE -> {hex(result['t'])}")
    print(f"[+] FALSE -> {hex(result['f'])}")

dump_result(main_cond_key, main_cond_addr, result)

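Results

After running the script once

The assembly view shows that execution now continues at the computed target, but a new unresolved indirect jump appears further down, and its tag is different (which is why two tag strings are listed above).

MLIL view

HLIL view

Running the script again

This time the remaining indirect jumps are resolved as well.

It also turns out that this is a control-flow-flattened function with several dispatcher state machines nested inside one another.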
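Running the script repeatedly amounts to a fixpoint iteration: each pass resolves the jumps currently visible, reanalysis may expose new ones, and you stop once nothing new appears. Below is a minimal model of that loop. list_unresolved is a hypothetical callback that stands in for "apply, reanalyze, list the unresolved jumps", and every pending jump is assumed to be resolvable:

```python
def resolve_until_fixpoint(list_unresolved, max_rounds=10):
    """Resolve everything pending, reanalyze, repeat until nothing new appears.

    list_unresolved(resolved) returns the set of jump addresses that are
    unresolved once the jumps in `resolved` have been fixed.
    Returns (resolved_addresses, number_of_rounds_that_resolved_something).
    """
    resolved = set()
    for round_no in range(max_rounds):
        pending = list_unresolved(resolved) - resolved
        if not pending:
            return resolved, round_no
        resolved |= pending
    return resolved, max_rounds


def fake_analysis(resolved):
    # Toy model of the sample: fixing the jump at 0x200 exposes another one at 0x300.
    found = {0x200}
    if 0x200 in resolved:
        found.add(0x300)
    return found


print(resolve_until_fixpoint(fake_analysis))
```

The toy analysis needs two productive rounds, matching the two runs described above, and max_rounds keeps a misbehaving resolver from looping forever.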

workflow

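Overview

Besides running the script repeatedly as above, a more automated option is to write a custom BN workflow plugin.

A Workflow is BN's analysis pipeline. It consists of multiple Activities executed according to their dependencies (a DAG), progressively lifting binary code from assembly to intermediate representations at different levels such as LLIL, MLIL and HLIL.

An Activity is one concrete analysis step: essentially a piece of code that can be inserted into BN's analysis pipeline to process or modify the IL or the analysis state.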
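The dependency ordering of activities can be illustrated with Python's graphlib. This is not BN's scheduler, and the activity names are invented; the sketch only shows what inserting an activity before HLIL generation means for the execution order:

```python
from graphlib import TopologicalSorter


def schedule(deps, insert_before=None, new_activity=None):
    """Return an order in which every activity runs after its dependencies.

    deps maps an activity name to the set of activities it depends on.
    If insert_before is given, new_activity is spliced in right before it:
    new_activity takes over insert_before's dependencies, and insert_before
    then depends on new_activity.
    """
    deps = {name: set(req) for name, req in deps.items()}
    if insert_before is not None:
        deps[new_activity] = deps[insert_before]
        deps[insert_before] = {new_activity}
    return list(TopologicalSorter(deps).static_order())


# Hypothetical activity names, loosely modelled on the lifting stages.
pipeline = {"lift_llil": set(), "lift_mlil": {"lift_llil"}, "generate_hlil": {"lift_mlil"}}
print(schedule(pipeline, "generate_hlil", "solve_two_branch_jump"))
# ['lift_llil', 'lift_mlil', 'solve_two_branch_jump', 'generate_hlil']
```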

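BN also provides an official API that lets users customize Workflows to suit their needs.

For this sample, the overall approach is much the same as above: prepare data -> compute t/f values -> rewrite MLIL -> register the workflow.

Implementation

Parts that are the same as the implementation above are not repeated here; this section focuses on what is implemented differently or has been optimized.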

0x1

class solve_two_branch_jump_handler:
    @dataclass
    class resolved_data:
        cond: ExpressionIndex
        trueAddr: int
        falseAddr: int

    def __init__(self, ctx: AnalysisContext):
        self.ctx = ctx
        self.resolved: dict[int, solve_two_branch_jump_handler.resolved_data] = {}

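All execution state is wrapped in a solve_two_branch_jump_handler class: self.ctx holds the current workflow context, and self.resolved holds the result for each unresolved jump. This makes it possible to solve many unresolved jumps in one batch first and then rewrite the MLIL in a single pass.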

0x2

for item in func.unresolved_indirect_branches:
    unsolved_addr = item[1]
    try:
        jump = mlil[mlil.get_instruction_start(unsolved_addr)].ssa_form
        if not isinstance(jump, MediumLevelILJump) and not isinstance(jump, MediumLevelILJumpTo):
            raise RuntimeError("not jump {}".format(hex(jump.address)))

        memo = {}
        result_raw = self.eval_expr_tf(jump.dest, memo, ssa_defs, main_cond_key, bv)

        if isinstance(result_raw, int):
            res = {"t": result_raw, "f": result_raw}
        elif self.is_tf(result_raw):
            res = {"t": result_raw.get("t"), "f": result_raw.get("f")}
        else:
            raise RuntimeError("calc fail {}".format(hex(jump.address)))

        if res["t"] is None or res["f"] is None:
            raise RuntimeError("calc result contains None {}".format(hex(jump.address)))

        if isinstance(jump, MediumLevelILJump) or self.check_manual_update(func, unsolved_addr, res):
            print(hex(unsolved_addr), jump, func.mlil.get_expr(main_cond_expr_index), res)
            func.set_user_indirect_branches(
                unsolved_addr,
                [(func.arch, res["t"]), (func.arch, res["f"])],
                func.arch,
            )

        self.resolved[unsolved_addr] = solve_two_branch_jump_handler.resolved_data(
            cond=main_cond_expr_index,
            trueAddr=res["t"],
            falseAddr=res["f"],
        )
    except Exception as e:
        # log the failure and move on to the next unresolved jump
        print(f"[!] {hex(unsolved_addr)}: {e}")

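Instead of scanning for specific tags, the jump addresses now come from unresolved_indirect_branches, and instead of handling one jump per run, all unresolved indirect jumps are collected at once and solved one by one.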

0x3

def build_ssa_defs(self, mlil_ssa: MediumLevelILFunction):
    ssa_defs = {}
    for bb in mlil_ssa.basic_blocks:
        for insn in bb:
            if insn.operation == Op.MLIL_SET_VAR_SSA:
                ssa_defs[insn.dest] = insn.src
            elif insn.operation == Op.MLIL_VAR_PHI:
                ssa_defs[insn.dest] = insn
    return ssa_defs

def choose_main_cond(self, mlil_ssa: MediumLevelILFunction):
    ssa_defs = self.build_ssa_defs(mlil_ssa)

    cond_groups = {}
    cond_first_addr = {}

    for bb in mlil_ssa.basic_blocks:
        for insn in bb:
            if insn.operation == Op.MLIL_IF:
                k = self.expr_key(insn.condition, ssa_defs)
                cond_groups.setdefault(k, []).append(insn.condition)
                cond_first_addr.setdefault(k, insn.address)

    if not cond_groups:
        raise RuntimeError("no MLIL_IF found")

    main_cond_key = max(cond_groups, key=lambda k: len(cond_groups[k]))
    main_cond_addr = cond_first_addr[main_cond_key]

    main_cond_expr_index = None
    for bb in mlil_ssa.basic_blocks:
        for insn in bb:
            if insn.address == main_cond_addr and insn.operation == Op.MLIL_IF:
                main_cond_expr_index = insn.condition.non_ssa_form.expr_index
                break
        if main_cond_expr_index is not None:
            break

    if main_cond_expr_index is None:
        raise RuntimeError(f"cannot locate non-ssa expr for main cond @ {hex(main_cond_addr)}")

    return main_cond_key, main_cond_addr, main_cond_expr_index, ssa_defs

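Building the map from SSA variables to their defining expressions and choosing main_cond are now wrapped in functions so they can be reused.

0x4

The resolved_data structure has changed. With the old structure, rebuilding the MLIL required yet another map to get the condition expression at the non-SSA MLIL level.

It now records the index of the cond expression directly, so no awkward conversion is needed when it is used later.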

@dataclass
class resolved_data:
    cond: ExpressionIndex
    trueAddr: int
    falseAddr: int

When choosing main_cond:

main_cond_expr_index = insn.condition.non_ssa_form.expr_index

When rebuilding the MLIL:

```python
old_mlil.get_expr(self.resolved[instr.address].cond).copy_to(new_func)
```
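A toy model of why storing an index is convenient: in BN an IL expression can be fetched from its function by ExpressionIndex via get_expr, so recording the non-SSA index at solve time turns the later lookup into a single call. ToyIL, ResolvedData and the address 0x5678 below are hypothetical; the class only mimics the index-to-expression lookup, not BN's real types.

```python
# Toy stand-in for an IL function whose expressions are addressed by integer index.
from dataclasses import dataclass
from typing import List


@dataclass
class ToyIL:
    exprs: List[str]

    def get_expr(self, index: int) -> str:
        return self.exprs[index]


@dataclass
class ResolvedData:
    cond: int  # index into the non-SSA pool, captured once at solve time
    trueAddr: int
    falseAddr: int


non_ssa = ToyIL(["x8 = [var_4b1].b", "if (x23 != 0) ...", "x23 != 0"])
resolved = {0x5678: ResolvedData(cond=2, trueAddr=0x1040, falseAddr=0x1080)}

# At rebuild time the condition is one lookup; no SSA -> non-SSA mapping is needed.
print(non_ssa.get_expr(resolved[0x5678].cond))
```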

0x5

Every call to set_user_indirect_branches triggers a BN reanalysis, so once the code runs as a workflow activity it can easily end up in an endless loop. A check was therefore added:

```python
def check_manual_update(self, func, addr, res):
    auto_branch = func.get_indirect_branches_at(addr)
    if len(auto_branch) != 2:
        return False
    if auto_branch[0].dest_addr != res["t"] and auto_branch[0].dest_addr != res["f"]:
        return True
    if auto_branch[1].dest_addr != res["t"] and auto_branch[1].dest_addr != res["f"]:
        return True
    return False
```

At the call site:

```python
if isinstance(jump, MediumLevelILJump) or self.check_manual_update(func, unsolved_addr, res):
    func.set_user_indirect_branches(...)
```

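If the jump is an MLIL_JUMP (MediumLevelILJump), the user branches are simply set. If it is an MLIL_JUMP_TO (MediumLevelILJumpTo), BN already has a set of target addresses, so its existing indirect branches are first compared with the computed result, and BN's current branch information is overwritten only when BN reports exactly two targets and at least one of them is not among the computed ones.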
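The guard can be reasoned about with a small simulation. The sketch assumes, as a simplification, that every write of user branches schedules exactly one more analysis pass and that after the first write BN presents the jump as MLIL_JUMP_TO. needs_update mirrors check_manual_update on plain address lists; simulate is a hypothetical driver, not BN code.

```python
from typing import List


def needs_update(existing: List[int], t: int, f: int) -> bool:
    """Mirror of check_manual_update: rewrite only when BN already has exactly
    two targets and at least one of them is not in {t, f}."""
    if len(existing) != 2:
        return False
    return any(d not in (t, f) for d in existing)


def simulate(t: int, f: int, guarded: bool, max_passes: int = 10) -> int:
    """Count analysis passes; each write of user branches causes another pass."""
    existing: List[int] = []  # a plain MLIL_JUMP starts with no known targets
    is_plain_jump = True
    passes = 0
    while passes < max_passes:
        passes += 1
        if guarded:
            write = is_plain_jump or needs_update(existing, t, f)
        else:
            write = True
        if not write:
            break  # nothing written, so no reanalysis is scheduled
        existing = [t, f]  # stands in for set_user_indirect_branches(...)
        is_plain_jump = False  # assumption: BN now shows an MLIL_JUMP_TO
    return passes


print(simulate(0x1040, 0x1080, guarded=True))
print(simulate(0x1040, 0x1080, guarded=False))
```

With the guard the analysis settles on the second pass; without it every pass rewrites the branches and the loop only stops at the artificial max_passes limit.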

0x6

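An address validity check was added so that implausible addresses are not written back into the analysis: a target must be 4-byte aligned and lie inside an executable segment.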

```python
def check_addr_vaild(self, addr, bv):
    # seg.end is one past the last byte, so the upper bound is exclusive
    for seg in bv.segments:
        if seg.executable and seg.start <= addr < seg.end and addr % 4 == 0:
            return True
    return False
```

The check runs before the jump is rewritten:

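```python
if (
    resolved is None
    or not self.check_addr_vaild(resolved.trueAddr, ctx.function.view)
    or not self.check_addr_vaild(resolved.falseAddr, ctx.function.view)
):
    old_mlil.source_function.add_tag("Bugs", "need manual analyze jump", instr.address, auto=True)
```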

If the current jump has no valid solution, a Bugs tag is added at its address so it can be analyzed manually later.
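A standalone version of the validity check, using a hypothetical Seg dataclass in place of binaryninja's Segment. The end bound is treated as exclusive here, and the 4-byte alignment rule assumes AArch64 code as in this sample.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Seg:  # hypothetical stand-in for binaryninja's Segment
    start: int
    end: int  # exclusive upper bound
    executable: bool


def addr_valid(addr: int, segments: List[Seg]) -> bool:
    # AArch64 instructions are 4-byte aligned and must live in executable memory
    return addr % 4 == 0 and any(
        s.executable and s.start <= addr < s.end for s in segments
    )


segs = [Seg(0x0, 0x1000, False), Seg(0x1000, 0x5000, True)]
for a in (0x1040, 0x1042, 0x0800, 0x5000):
    print(hex(a), addr_valid(a, segs))
```

Only 0x1040 passes: 0x1042 is misaligned, 0x800 sits in a non-executable segment, and 0x5000 is one past the end of the code segment.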

0x7

```python
wf = Workflow("").clone("solve_two_branch_jump")
wf.register_activity(Activity(...))
wf.insert("core.function.generateHighLevelIL", ["solve_two_branch_jump.activity"])
wf.register()
```

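This registers a new workflow activity and inserts it into the function analysis pipeline just before HLIL generation (core.function.generateHighLevelIL).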

Full code

```python
import json
import operator
from dataclasses import dataclass

from binaryninja import (
    Workflow,
    AnalysisContext,
    Activity,
    MediumLevelILFunction,
    MediumLevelILConstPtr,
    MediumLevelILConstData,
    MediumLevelILConst,
    MediumLevelILJump,
    MediumLevelILJumpTo,
    MediumLevelILVar,
    MediumLevelILLabel,
    RegisterValueType,
    BinaryView,
    Function,
    ExpressionIndex,
)
from binaryninja.mediumlevelil import MediumLevelILOperation as Op


mlil_const = MediumLevelILConstPtr | MediumLevelILConstData | MediumLevelILConst

MAX_DEPTH = 16


class solve_two_branch_jump_handler:
    @dataclass
    class resolved_data:
        cond: ExpressionIndex  # index of the condition expr in the non-SSA MLIL
        trueAddr: int
        falseAddr: int

    def __init__(self, ctx: AnalysisContext):
        self.ctx = ctx
        self.resolved: dict[int, solve_two_branch_jump_handler.resolved_data] = {}

    # ---------- integer helpers ----------
    def mask(self, size: int) -> int:
        return (1 << (size * 8)) - 1

    def sext(self, v: int, src_size: int, dst_size: int) -> int:
        bits = src_size * 8
        s = 1 << (bits - 1)
        v &= (1 << bits) - 1
        return (v - (1 << bits) if v & s else v) & self.mask(dst_size)

    def read_u(self, bv: BinaryView, addr: int, size: int) -> int:
        data = bv.read(addr, size)
        if data is None or len(data) != size:
            raise RuntimeError(f"read failed @ {hex(addr)} size={size}")
        return int.from_bytes(data, "little")

    # ---------- TF values: a plain int, or {"t": value if cond true, "f": value if cond false} ----------
    def is_tf(self, v):
        return isinstance(v, dict) and ("t" in v or "f" in v)

    def make_tf(self, t, f):
        return None if t is None or f is None else t if t == f else {"t": t, "f": f}

    def as_tf_pair(self, v):
        return (v["t"], v["f"]) if self.is_tf(v) else (v, v)

    def tf_pick_branch(self, v, b):
        return v["t"] if self.is_tf(v) and b else v["f"] if self.is_tf(v) else v

    def tf_map(self, v, fn):
        vt, vf = self.as_tf_pair(v)
        return None if vt is None or vf is None else self.make_tf(fn(vt), fn(vf))

    def tf_binop(self, a, b, fn, size):
        at, af = self.as_tf_pair(a)
        bt, bf = self.as_tf_pair(b)
        if None in (at, af, bt, bf):
            return None
        return self.make_tf(fn(at, bt) & self.mask(size), fn(af, bf) & self.mask(size))

    def tf_cmp(self, a, b, pred):
        at, af = self.as_tf_pair(a)
        bt, bf = self.as_tf_pair(b)
        if None in (at, af, bt, bf):
            return None
        return self.make_tf(bool(pred(at, bt)), bool(pred(af, bf)))

    def tf_bool_to_int(self, v):
        if v is None:
            return None
        if isinstance(v, bool):
            return 1 if v else 0
        if self.is_tf(v):
            return self.make_tf(1 if v["t"] else 0, 1 if v["f"] else 0)
        return None

    def _safe_const(self, expr):
        return getattr(expr, "constant", None)

    # ---------- SSA definitions and normalised expression keys ----------
    def build_ssa_defs(self, mlil_ssa: MediumLevelILFunction):
        ssa_defs = {}
        for bb in mlil_ssa.basic_blocks:
            for insn in bb:
                if insn.operation == Op.MLIL_SET_VAR_SSA:
                    ssa_defs[insn.dest] = insn.src
                elif insn.operation == Op.MLIL_VAR_PHI:
                    ssa_defs[insn.dest] = insn
        return ssa_defs

    def var_key(self, v, ssa_defs, memo=None, visiting=None, depth=0):
        memo = {} if memo is None else memo
        visiting = set() if visiting is None else visiting
        k0 = ("varssa", str(v))

        if k0 in memo:
            return memo[k0]
        if depth > MAX_DEPTH:
            return ("varssa-depth", str(v))
        if str(v) in visiting:
            return ("varssa-rec", str(v))

        node = ssa_defs.get(v)
        if node is None:
            return ("varssa-undef", str(v))

        visiting.add(str(v))
        try:
            if getattr(node, "operation", None) == Op.MLIL_VAR_PHI:
                src_keys = tuple(
                    sorted(str(self.var_key(s, ssa_defs, memo, visiting, depth + 1)) for s in node.src)
                )
                out = ("phi", src_keys)
            else:
                out = self.expr_key(node, ssa_defs, memo, visiting, depth + 1)
            memo[k0] = out
            return out
        finally:
            visiting.discard(str(v))

    def expr_key(self, expr, ssa_defs, memo=None, visiting=None, depth=0):
        if expr is None:
            return None

        memo = {} if memo is None else memo
        visiting = set() if visiting is None else visiting

        if depth > MAX_DEPTH:
            return ("expr-depth", str(expr))

        UNARY_KEY_OPS = {
            Op.MLIL_ZX: "zx",
            Op.MLIL_SX: "sx",
            Op.MLIL_LOW_PART: "low_part",
        }
        BINARY_KEY_OPS = {
            Op.MLIL_ADD,
            Op.MLIL_SUB,
            Op.MLIL_AND,
            Op.MLIL_OR,
            Op.MLIL_XOR,
            Op.MLIL_LSL,
            Op.MLIL_LSR,
            Op.MLIL_ASR,
            Op.MLIL_MUL,
        }
        CMP_KEY_OPS = {
            Op.MLIL_CMP_E,
            Op.MLIL_CMP_NE,
            Op.MLIL_CMP_SLT,
            Op.MLIL_CMP_ULT,
            Op.MLIL_CMP_SLE,
            Op.MLIL_CMP_ULE,
            Op.MLIL_CMP_SGE,
            Op.MLIL_CMP_UGE,
            Op.MLIL_CMP_SGT,
            Op.MLIL_CMP_UGT,
        }

        op = expr.operation

        if op in CMP_KEY_OPS:
            return (
                "cmp",
                op,
                self.expr_key(expr.left, ssa_defs, memo, visiting, depth + 1),
                self.expr_key(expr.right, ssa_defs, memo, visiting, depth + 1),
                expr.size,
            )

        if op == Op.MLIL_VAR_SSA:
            return self.var_key(expr.src, ssa_defs, memo, visiting, depth + 1)

        if op == Op.MLIL_VAR:
            return ("var", str(expr.src), expr.size)

        if op in (Op.MLIL_CONST, Op.MLIL_CONST_PTR, Op.MLIL_EXTERN_PTR):
            return ("const", self._safe_const(expr), expr.size)

        if op in UNARY_KEY_OPS:
            return (
                UNARY_KEY_OPS[op],
                self.expr_key(expr.src, ssa_defs, memo, visiting, depth + 1),
                expr.size,
            )

        if op in BINARY_KEY_OPS:
            return (
                "binop",
                op,
                self.expr_key(expr.left, ssa_defs, memo, visiting, depth + 1),
                self.expr_key(expr.right, ssa_defs, memo, visiting, depth + 1),
                expr.size,
            )

        if op in (Op.MLIL_LOAD, Op.MLIL_LOAD_SSA):
            return ("load", self.expr_key(expr.src, ssa_defs, memo, visiting, depth + 1), expr.size)

        return (op, str(expr), getattr(expr, "size", None))

    def choose_main_cond(self, mlil_ssa: MediumLevelILFunction):
        ssa_defs = self.build_ssa_defs(mlil_ssa)

        cond_groups = {}
        cond_first_addr = {}

        for bb in mlil_ssa.basic_blocks:
            for insn in bb:
                if insn.operation == Op.MLIL_IF:
                    k = self.expr_key(insn.condition, ssa_defs)
                    cond_groups.setdefault(k, []).append(insn.condition)
                    cond_first_addr.setdefault(k, insn.address)

        if not cond_groups:
            raise RuntimeError("no MLIL_IF found")

        # the normalised condition shared by the most MLIL_IF instructions is the main condition
        main_cond_key = max(cond_groups, key=lambda k: len(cond_groups[k]))
        main_cond_addr = cond_first_addr[main_cond_key]

        main_cond_expr_index = None
        for bb in mlil_ssa.basic_blocks:
            for insn in bb:
                if insn.address == main_cond_addr and insn.operation == Op.MLIL_IF:
                    main_cond_expr_index = insn.condition.non_ssa_form.expr_index
                    break
            if main_cond_expr_index is not None:
                break

        if main_cond_expr_index is None:
            raise RuntimeError(f"cannot locate non-ssa expr for main cond @ {hex(main_cond_addr)}")

        return main_cond_key, main_cond_addr, main_cond_expr_index, ssa_defs

    # ---------- evaluate expressions under cond == true / cond == false ----------
    def eval_var_tf(self, v, memo, ssa_defs, main_cond_key, bv):
        k = ("var", str(v))
        if k in memo:
            return memo[k]
        memo[k] = None  # provisional value: breaks cycles through loop phis

        node = ssa_defs.get(v)
        if node is None:
            return None

        if getattr(node, "operation", None) == Op.MLIL_VAR_PHI and len(node.src) == 2:
            # heuristic: src[0] is taken as the value on the true path, src[1] on the false path
            tv = self.eval_var_tf(node.src[0], memo, ssa_defs, main_cond_key, bv)
            fv = self.eval_var_tf(node.src[1], memo, ssa_defs, main_cond_key, bv)
            out = (
                None
                if tv is None or fv is None
                else self.make_tf(self.tf_pick_branch(tv, True), self.tf_pick_branch(fv, False))
            )
        else:
            out = self.eval_expr_tf(node, memo, ssa_defs, main_cond_key, bv)

        memo[k] = out
        return out

    def eval_condition_expr_tf(self, expr, memo, ssa_defs, main_cond_key, bv):
        CMP_MAP = {
            Op.MLIL_CMP_E: operator.eq,
            Op.MLIL_CMP_NE: operator.ne,
            Op.MLIL_CMP_ULT: operator.lt,
            Op.MLIL_CMP_ULE: operator.le,
            Op.MLIL_CMP_UGT: operator.gt,
            Op.MLIL_CMP_UGE: operator.ge,
        }

        # anything equivalent to the main condition is true on one path and false on the other
        if self.expr_key(expr, ssa_defs) == main_cond_key:
            return {"t": True, "f": False}

        if expr.operation in CMP_MAP:
            a = self.eval_expr_tf(expr.left, memo, ssa_defs, main_cond_key, bv)
            b = self.eval_expr_tf(expr.right, memo, ssa_defs, main_cond_key, bv)
            return self.tf_cmp(a, b, CMP_MAP[expr.operation])

        return None

    def eval_expr_tf(self, expr, memo, ssa_defs, main_cond_key, bv):
        if expr is None:
            return None

        CMP_MAP = {
            Op.MLIL_CMP_E: operator.eq,
            Op.MLIL_CMP_NE: operator.ne,
            Op.MLIL_CMP_ULT: operator.lt,
            Op.MLIL_CMP_ULE: operator.le,
            Op.MLIL_CMP_UGT: operator.gt,
            Op.MLIL_CMP_UGE: operator.ge,
        }
        BINOP_MAP = {
            Op.MLIL_ADD: operator.add,
            Op.MLIL_SUB: operator.sub,
            Op.MLIL_AND: operator.and_,
            Op.MLIL_OR: operator.or_,
            Op.MLIL_XOR: operator.xor,
            Op.MLIL_LSL: operator.lshift,
            Op.MLIL_LSR: operator.rshift,
            Op.MLIL_MUL: operator.mul,
        }

        op = expr.operation

        if op in (Op.MLIL_CONST, Op.MLIL_CONST_PTR, Op.MLIL_EXTERN_PTR):
            return expr.constant & self.mask(expr.size)

        if op == Op.MLIL_VAR_SSA:
            return self.eval_var_tf(expr.src, memo, ssa_defs, main_cond_key, bv)

        if op in BINOP_MAP:
            return self.tf_binop(
                self.eval_expr_tf(expr.left, memo, ssa_defs, main_cond_key, bv),
                self.eval_expr_tf(expr.right, memo, ssa_defs, main_cond_key, bv),
                BINOP_MAP[op],
                expr.size,
            )

        if op == Op.MLIL_ASR:
            def _asr(x, y):
                bits = expr.size * 8
                x &= (1 << bits) - 1
                if x & (1 << (bits - 1)):  # reinterpret as signed before shifting
                    x -= 1 << bits
                return x >> y

            return self.tf_binop(
                self.eval_expr_tf(expr.left, memo, ssa_defs, main_cond_key, bv),
                self.eval_expr_tf(expr.right, memo, ssa_defs, main_cond_key, bv),
                _asr,
                expr.size,
            )

        if op in CMP_MAP:
            return self.tf_bool_to_int(
                self.eval_condition_expr_tf(expr, memo, ssa_defs, main_cond_key, bv)
            )

        if op in (Op.MLIL_LOAD, Op.MLIL_LOAD_SSA):
            ptrs = self.eval_expr_tf(expr.src, memo, ssa_defs, main_cond_key, bv)
            pt, pf = self.as_tf_pair(ptrs)
            if pt is None or pf is None:
                return None
            try:
                return self.make_tf(self.read_u(bv, pt, expr.size), self.read_u(bv, pf, expr.size))
            except Exception:
                return None

        if op == Op.MLIL_ZX:
            return self.tf_map(
                self.eval_expr_tf(expr.src, memo, ssa_defs, main_cond_key, bv),
                lambda v: v & self.mask(expr.size),
            )

        if op == Op.MLIL_SX:
            return self.tf_map(
                self.eval_expr_tf(expr.src, memo, ssa_defs, main_cond_key, bv),
                lambda v: self.sext(v, expr.src.size, expr.size),
            )

        if op == Op.MLIL_LOW_PART:
            return self.tf_map(
                self.eval_expr_tf(expr.src, memo, ssa_defs, main_cond_key, bv),
                lambda v: v & self.mask(expr.size),
            )

        return None

    # ---------- guards ----------
    def check_manual_update(self, func: Function, addr: int, res: dict):
        auto_branch = func.get_indirect_branches_at(addr)
        if len(auto_branch) != 2:
            return False
        if auto_branch[0].dest_addr != res["t"] and auto_branch[0].dest_addr != res["f"]:
            return True
        if auto_branch[1].dest_addr != res["t"] and auto_branch[1].dest_addr != res["f"]:
            return True
        return False

    def check_addr_vaild(self, addr: int, bv: BinaryView):
        # seg.end is exclusive; AArch64 code must be 4-byte aligned
        for seg in bv.segments:
            if seg.executable and seg.start <= addr < seg.end and addr % 4 == 0:
                return True
        return False

    # ---------- pass 1: compute targets and tell BN about them ----------
    def connect_basic_block(self, func: Function):
        mlil = func.mlil
        bv = func.view
        if mlil is None or bv is None:
            return
        mlil_ssa = mlil.ssa_form
        if mlil_ssa is None:
            return

        main_cond_key, main_cond_addr, main_cond_expr_index, ssa_defs = self.choose_main_cond(mlil_ssa)

        for item in func.unresolved_indirect_branches:
            unsolved_addr = item[1]
            try:
                jump = mlil[mlil.get_instruction_start(unsolved_addr)].ssa_form
                if not isinstance(jump, MediumLevelILJump) and not isinstance(jump, MediumLevelILJumpTo):
                    raise RuntimeError("not jump {}".format(hex(jump.address)))

                memo = {}
                result_raw = self.eval_expr_tf(jump.dest, memo, ssa_defs, main_cond_key, bv)

                if isinstance(result_raw, int):
                    res = {"t": result_raw, "f": result_raw}
                elif self.is_tf(result_raw):
                    res = {"t": result_raw.get("t"), "f": result_raw.get("f")}
                else:
                    raise RuntimeError("calc fail {}".format(hex(jump.address)))

                if res["t"] is None or res["f"] is None:
                    raise RuntimeError("calc result contains None {}".format(hex(jump.address)))

                if isinstance(jump, MediumLevelILJump) or self.check_manual_update(func, unsolved_addr, res):
                    print(hex(unsolved_addr), jump, func.mlil.get_expr(main_cond_expr_index), res)
                    func.set_user_indirect_branches(
                        unsolved_addr,
                        [(func.arch, res["t"]), (func.arch, res["f"])],
                        func.arch,
                    )

                self.resolved[unsolved_addr] = solve_two_branch_jump_handler.resolved_data(
                    cond=main_cond_expr_index,
                    trueAddr=res["t"],
                    falseAddr=res["f"],
                )

            except Exception as e:
                print(hex(unsolved_addr), e)
                continue

    # ---------- pass 2: rebuild MLIL, turning each solved jump into an if ----------
    def convert_jump_to_if(self, ctx: AnalysisContext):
        new_func = MediumLevelILFunction(ctx.function.arch, low_level_il=ctx.llil)
        old_mlil = ctx.function.mlil
        new_func.prepare_to_copy_function(old_mlil)

        for old_block in old_mlil:
            new_func.prepare_to_copy_block(old_block)
            for instr_idx in range(old_block.start, old_block.end):
                instr = old_mlil[instr_idx]

                if (
                    isinstance(instr, MediumLevelILJumpTo)
                    and isinstance(instr.dest, MediumLevelILVar)
                    and instr.get_possible_reg_values(instr.dest.var.storage).type
                    != RegisterValueType.ConstantValue
                ):
                    # drop stale auto tags, then re-tag jumps that have no valid solution
                    for tag in old_mlil.source_function.get_tags_at(instr.address, auto=True):
                        old_mlil.source_function.remove_auto_address_tags_of_type(instr.address, tag.type.name)

                    try:
                        resolved = self.resolved.get(instr.address)
                        if (
                            resolved is None
                            or not self.check_addr_vaild(resolved.trueAddr, ctx.function.view)
                            or not self.check_addr_vaild(resolved.falseAddr, ctx.function.view)
                        ):
                            old_mlil.source_function.add_tag("Bugs", "need manual analyze jump", instr.address, auto=True)
                            print("add bug tag at {}".format(hex(instr.address)))
                    except Exception as e:
                        old_mlil.source_function.add_tag("Bugs", "need manual analyze jump", instr.address, auto=True)
                        print("add bug tag at {}".format(hex(instr.address)))
                        print(e)

                if isinstance(instr, MediumLevelILJumpTo) and instr.address in self.resolved:
                    try:
                        label_t = MediumLevelILLabel()
                        label_f = MediumLevelILLabel()
                        indirect_branches = ctx.function.get_indirect_branches_at(instr.address)
                        if len(indirect_branches) != 2:
                            raise RuntimeError(
                                "indirect branches len!=2 {} {}".format(hex(instr.address), indirect_branches)
                            )

                        for branch in indirect_branches:
                            if branch.dest_addr == self.resolved[instr.address].trueAddr:
                                label_t.operand = instr.targets[branch.dest_addr]
                            if branch.dest_addr == self.resolved[instr.address].falseAddr:
                                label_f.operand = instr.targets[branch.dest_addr]

                        # replace the jump with: if (main cond) goto label_t else goto label_f
                        if_expr = new_func.if_expr(
                            old_mlil.get_expr(self.resolved[instr.address].cond).copy_to(new_func),
                            label_t,
                            label_f,
                            old_mlil[instr_idx].source_location,
                        )
                        new_func.append(if_expr, old_mlil[instr_idx].source_location)

                    except Exception as e:
                        print(e)
                        new_func.append(old_mlil[instr_idx].copy_to(new_func), old_mlil[instr_idx].source_location)
                else:
                    new_func.append(old_mlil[instr_idx].copy_to(new_func), old_mlil[instr_idx].source_location)

        new_func.finalize()
        new_func.generate_ssa_form()
        ctx.mlil = new_func

    def run(self):
        self.connect_basic_block(self.ctx.function)
        self.convert_jump_to_if(self.ctx)


def install_solve_two_branch_jump_handler(ctx: AnalysisContext):
    handler = solve_two_branch_jump_handler(ctx)
    handler.run()


wf = Workflow("").clone("solve_two_branch_jump")
wf.register_activity(
    Activity(
        configuration=json.dumps({
            "name": "solve_two_branch_jump.activity",
            "title": "solve_two_branch_jump",
            "description": "solve_two_branch_jump",
            "eligibility": {
                "auto": {
                    "default": True
                }
            }
        }),
        action=lambda context: install_solve_two_branch_jump_handler(context),
    )
)
wf.insert("core.function.generateHighLevelIL", ["solve_two_branch_jump.activity"])
wf.register()
print("[+] workflow registered: solve_two_branch_jump")
```
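The core of the full code is evaluating each expression as a pair of values: one under the assumption that the main condition is true and one assuming it is false (the {"t": ..., "f": ...} dicts). Below is a stripped-down, pure-Python sketch of that idea. make_tf, pair, binop, load and phi are simplified re-implementations rather than the plugin's methods, and the jump-table layout (base + table[cond << 3]) is a made-up example, not the pattern found in libxyass.so.

```python
import operator
from typing import Callable, Dict, Optional, Tuple, Union

TF = Union[int, Dict[str, int], None]  # plain int, {"t": .., "f": ..}, or unknown


def make_tf(t: Optional[int], f: Optional[int]) -> TF:
    if t is None or f is None:
        return None
    return t if t == f else {"t": t, "f": f}


def pair(v: TF) -> Tuple[Optional[int], Optional[int]]:
    return (v["t"], v["f"]) if isinstance(v, dict) else (v, v)


def binop(a: TF, b: TF, fn: Callable[[int, int], int], size: int = 8) -> TF:
    (at, af), (bt, bf) = pair(a), pair(b)
    if None in (at, af, bt, bf):
        return None
    m = (1 << (size * 8)) - 1  # wrap like a fixed-width register
    return make_tf(fn(at, bt) & m, fn(af, bf) & m)


def load(ptr: TF, memory: Dict[int, int]) -> TF:
    pt, pf = pair(ptr)
    if pt is None or pf is None:
        return None
    return make_tf(memory.get(pt), memory.get(pf))


def phi(src_true: TF, src_false: TF) -> TF:
    # same heuristic as eval_var_tf: src[0] is assumed to be the true path
    return make_tf(pair(src_true)[0], pair(src_false)[1])


# Hypothetical dispatch: target = 0x1000 + table[(cond != 0) << 3]
memory = {0x8000: 0x200, 0x8008: 0x100}  # table[0] for false, table[1] for true
cond = make_tf(1, 0)  # the main condition as 0/1
off = binop(cond, 3, operator.lshift)  # {"t": 8, "f": 0}
ptr = binop(0x8000, off, operator.add)  # {"t": 0x8008, "f": 0x8000}
entry = load(ptr, memory)  # {"t": 0x100, "f": 0x200}
target = binop(entry, 0x1000, operator.add)
print({k: hex(v) for k, v in target.items()})
```

The resulting pair {"t": 0x1100, "f": 0x1200} plays the role of res in connect_basic_block, i.e. the true and false targets handed to set_user_indirect_branches.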

Usage

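Put this script in BN's plugins directory and restart BN. Then, in the view of an obfuscated function, right-click and choose Edit Function Properties, where the custom workflow can be applied.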

The result is the same as shown above.

References

Deobfuscating with binaryninja workflows 0x1 - two-branch indirect jumps

Binary Ninja Workflows


Solving multi-condition two-branch indirect jumps at the BN IL level
http://example.com/2026/04/01/solve_two_branch_jump/
Author
Eleven
Published
April 1, 2026
License