#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) 2018 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools to interact with BPF programs."""

19import abc
20import collections
21import struct
22
# This comes from syscall(2). Most architectures only support passing 6 args to
# syscalls, but ARM supports passing 7.
MAX_SYSCALL_ARGUMENTS = 7

# The following fields were copied from <linux/bpf_common.h>:

# Instruction classes (0 through 7, in header order).
(BPF_LD, BPF_LDX, BPF_ST, BPF_STX, BPF_ALU, BPF_JMP, BPF_RET,
 BPF_MISC) = range(8)

# LD/LDX fields.
# Size
BPF_W, BPF_H, BPF_B = 0x00, 0x08, 0x10
# Mode
BPF_IMM, BPF_ABS, BPF_IND, BPF_MEM, BPF_LEN, BPF_MSH = (0x00, 0x20, 0x40,
                                                       0x60, 0x80, 0xa0)

# JMP fields.
BPF_JA, BPF_JEQ, BPF_JGT, BPF_JGE, BPF_JSET = 0x00, 0x10, 0x20, 0x30, 0x40

# Source
BPF_K, BPF_X = 0x00, 0x08

BPF_MAXINSNS = 4096

# The following fields were copied from <linux/seccomp.h>:

# Filter return actions.
SECCOMP_RET_KILL_PROCESS = 0x80000000
SECCOMP_RET_KILL_THREAD = 0x00000000
SECCOMP_RET_TRAP = 0x00030000
SECCOMP_RET_ERRNO = 0x00050000
SECCOMP_RET_TRACE = 0x7ff00000
SECCOMP_RET_LOG = 0x7ffc0000
SECCOMP_RET_ALLOW = 0x7fff0000

# Masks splitting a filter return value into its action and its data
# (e.g. the errno of SECCOMP_RET_ERRNO).
SECCOMP_RET_ACTION_FULL = 0xffff0000
SECCOMP_RET_DATA = 0x0000ffff


def arg_offset(arg_index, hi=False):
    """Return the BPF_LD|BPF_W|BPF_ABS addressing-friendly register offset.

    The input buffer starts with two 32-bit words and one 64-bit word
    (16 bytes), followed by one 64-bit slot per syscall argument. |hi| selects
    the upper 32-bit half of the argument, which sits 4 bytes into the slot
    (i.e. little-endian layout is assumed).
    """
    header_bytes = 4 + 4 + 8
    slot_bytes = 8
    half_slot = slot_bytes // 2
    return header_bytes + slot_bytes * arg_index + half_slot * hi


def simulate(instructions, arch, syscall_number, *args):
    """Simulate a BPF program with the given arguments.

    The syscall arguments are zero-padded (or truncated) to
    MAX_SYSCALL_ARGUMENTS. They are packed after the syscall number, the
    architecture and a zeroed 64-bit word; that buffer is what
    BPF_LD|BPF_W|BPF_ABS instructions read from.

    Only the instructions this module emits are understood: absolute 32-bit
    loads, BPF_JA, the BPF_K forms of BPF_JEQ/BPF_JGT/BPF_JGE/BPF_JSET, and
    BPF_RET.

    Returns:
      A tuple (cost, action_name) where cost is the number of instructions
      executed, or (cost, 'ERRNO', errno) for an errno return.

    Raises:
      Exception: on an unknown instruction or return value, or when
        execution runs past the end of the program.
    """
    padding = (0, ) * (MAX_SYSCALL_ARGUMENTS - len(args))
    args = (args + padding)[:MAX_SYSCALL_ARGUMENTS]
    input_memory = struct.pack('IIQ' + 'Q' * MAX_SYSCALL_ARGUMENTS,
                               syscall_number, arch, 0, *args)

    # Conditional-jump predicates, keyed by the full opcode.
    branch_tests = {
        BPF_JMP | BPF_JEQ | BPF_K: lambda acc, k: acc == k,
        BPF_JMP | BPF_JGT | BPF_K: lambda acc, k: acc > k,
        BPF_JMP | BPF_JGE | BPF_K: lambda acc, k: acc >= k,
        BPF_JMP | BPF_JSET | BPF_K: lambda acc, k: acc & k != 0,
    }
    # Return values that map to a fixed action name. None of them match the
    # SECCOMP_RET_ERRNO mask, so checking them first cannot shadow ERRNO.
    fixed_actions = {
        SECCOMP_RET_KILL_PROCESS: 'KILL_PROCESS',
        SECCOMP_RET_KILL_THREAD: 'KILL_THREAD',
        SECCOMP_RET_TRAP: 'TRAP',
        SECCOMP_RET_TRACE: 'TRACE',
        SECCOMP_RET_LOG: 'LOG',
        SECCOMP_RET_ALLOW: 'ALLOW',
    }

    accumulator = 0
    pc = 0
    executed = 0
    while pc < len(instructions):
        insn = instructions[pc]
        pc += 1
        executed += 1
        code = insn.code
        if code == BPF_LD | BPF_W | BPF_ABS:
            word = input_memory[insn.k:insn.k + 4]
            accumulator = struct.unpack('I', word)[0]
        elif code == BPF_JMP | BPF_JA | BPF_K:
            pc += insn.k
        elif code in branch_tests:
            taken = branch_tests[code](accumulator, insn.k)
            pc += insn.jt if taken else insn.jf
        elif code == BPF_RET:
            action = fixed_actions.get(insn.k)
            if action is not None:
                return (executed, action)
            if (insn.k & SECCOMP_RET_ACTION_FULL) == SECCOMP_RET_ERRNO:
                return (executed, 'ERRNO', insn.k & SECCOMP_RET_DATA)
            raise Exception('unknown return %#x' % insn.k)
        else:
            raise Exception('unknown instruction %r' % (insn, ))
    raise Exception('out-of-bounds')


class SockFilter(collections.namedtuple('SockFilter', 'code jt jf k')):
    """A representation of struct sock_filter (one classic-BPF instruction).

    Fields: |code| is the 16-bit opcode, |jt|/|jf| are the 8-bit relative jump
    offsets for the true/false outcome of a conditional jump, and |k| is the
    32-bit immediate.
    """

    __slots__ = ()

    def encode(self):
        """Return this instruction packed with the native 'HBBI' layout."""
        code, jt, jf, k = self
        return struct.pack('HBBI', code, jt, jf, k)


class AbstractBlock(abc.ABC):
    """Base class of every node in the block DAG (visitor pattern)."""

    def __init__(self):
        super().__init__()

    @abc.abstractmethod
    def accept(self, visitor):
        """Dispatch |visitor| over this block.

        The concrete blocks in this module let their successor blocks accept
        |visitor| first and then call visitor.visit(self).
        """


class BasicBlock(AbstractBlock):
    """A concrete AbstractBlock holding an already-compiled instruction list."""

    def __init__(self, instructions):
        super().__init__()
        self._instructions = instructions

    def accept(self, visitor):
        """Visit this block; it has no successor blocks."""
        visitor.visit(self)

    @property
    def instructions(self):
        """The instruction list this block was constructed with (not copied)."""
        return self._instructions

    @property
    def opcodes(self):
        """All instructions encoded back-to-back as one bytes object."""
        encoded = [instruction.encode() for instruction in self._instructions]
        return b''.join(encoded)

    def __eq__(self, o):
        # Two BasicBlocks (of any subclass) are equal when their instruction
        # lists are equal; anything else compares unequal.
        return (isinstance(o, BasicBlock)
                and self._instructions == o._instructions)

    # Defining __eq__ already makes instances unhashable; spell it out.
    __hash__ = None


class KillProcess(BasicBlock):
    """A BasicBlock whose single instruction returns KILL_PROCESS."""

    def __init__(self):
        ret = SockFilter(BPF_RET, 0, 0, SECCOMP_RET_KILL_PROCESS)
        super().__init__([ret])


class KillThread(BasicBlock):
    """A BasicBlock whose single instruction returns KILL_THREAD."""

    def __init__(self):
        ret = SockFilter(BPF_RET, 0, 0, SECCOMP_RET_KILL_THREAD)
        super().__init__([ret])


class Trap(BasicBlock):
    """A BasicBlock whose single instruction returns TRAP."""

    def __init__(self):
        ret = SockFilter(BPF_RET, 0, 0, SECCOMP_RET_TRAP)
        super().__init__([ret])


class Trace(BasicBlock):
    """A BasicBlock whose single instruction returns TRACE."""

    def __init__(self):
        ret = SockFilter(BPF_RET, 0, 0, SECCOMP_RET_TRACE)
        super().__init__([ret])


class Log(BasicBlock):
    """A BasicBlock whose single instruction returns LOG."""

    def __init__(self):
        ret = SockFilter(BPF_RET, 0, 0, SECCOMP_RET_LOG)
        super().__init__([ret])


class ReturnErrno(BasicBlock):
    """A BasicBlock that unconditionally returns the specified errno.

    Only the low 16 bits of |errno| (SECCOMP_RET_DATA) are encoded in the
    return value; the unmasked value is kept in self.errno.
    """

    def __init__(self, errno):
        action = SECCOMP_RET_ERRNO | (errno & SECCOMP_RET_DATA)
        super().__init__([SockFilter(BPF_RET, 0, 0, action)])
        self.errno = errno


class Allow(BasicBlock):
    """A BasicBlock whose single instruction returns ALLOW."""

    def __init__(self):
        ret = SockFilter(BPF_RET, 0, 0, SECCOMP_RET_ALLOW)
        super().__init__([ret])


class ValidateArch(AbstractBlock):
    """An AbstractBlock that validates the architecture before |next_block|.

    The block only records its successor; the actual architecture check is
    generated when the DAG is flattened.
    """

    def __init__(self, next_block):
        super().__init__()
        self.next_block = next_block

    def accept(self, visitor):
        """Let |next_block| accept |visitor| first, then visit this block."""
        successor = self.next_block
        successor.accept(visitor)
        visitor.visit(self)


class SyscallEntry(AbstractBlock):
    """An abstract block that represents a syscall comparison in a DAG.

    When flattened, the loaded value is tested against |syscall_number| with
    the BPF jump opcode |op| (BPF_JEQ by default). Execution then continues
    at |jt| when the test holds and at |jf| otherwise.
    """

    def __init__(self, syscall_number, jt, jf, *, op=BPF_JEQ):
        super().__init__()
        self.op = op
        self.syscall_number = syscall_number
        self.jt = jt
        self.jf = jf

    def accept(self, visitor):
        """Let both successors accept |visitor|, then visit this block."""
        self.jt.accept(visitor)
        self.jf.accept(visitor)
        visitor.visit(self)

    # SyscallEntries have no meaningful order. These are defined (always
    # returning False) only so that tuples containing SyscallEntries can be
    # compared without raising TypeError when the other elements tie.
    # Previously each of these methods was defined twice in this class.
    def __lt__(self, o):
        return False

    def __gt__(self, o):
        return False


class WideAtom(AbstractBlock):
    """An AbstractBlock that represents a 32-bit wide atom.

    Attributes:
      arg_offset: byte offset (as computed by arg_offset()) of the 32-bit
        word to load.
      op: BPF jump opcode used for the comparison.
      value: immediate the loaded word is compared against.
      jt: block to continue at when the comparison holds.
      jf: block to continue at otherwise.
    """

    def __init__(self, arg_offset, op, value, jt, jf):
        super().__init__()
        self.arg_offset = arg_offset
        self.op = op
        self.value = value
        self.jt = jt
        self.jf = jf

    def accept(self, visitor):
        """Let both successors accept |visitor|, then visit this block."""
        for successor in (self.jt, self.jf):
            successor.accept(visitor)
        visitor.visit(self)


class Atom(AbstractBlock):
    """An AbstractBlock that represents an atom (a simple comparison operation).

    Compares syscall argument |arg_index| against |value| and continues at
    |jt| when the comparison holds, |jf| otherwise.

    |op| may be one of the operator strings '==', '!=', '>', '<=', '>=', '<',
    '&' or 'in'. These are normalized into a BPF jump opcode:

      * '!=', '<=' and '<' use the opposite BPF test (BPF_JEQ, BPF_JGT,
        BPF_JGE) with |jt| and |jf| swapped.
      * 'in' (the argument may only have bits within the |value| mask)
        becomes a BPF_JSET test against the 64-bit complement of |value|,
        with |jt| and |jf| swapped.

    |op| may also be one of BPF_JEQ, BPF_JGT, BPF_JGE or BPF_JSET, i.e. the
    normalized opcode stored on an existing Atom. In that case |op|, |value|,
    |jt| and |jf| are used verbatim. This lets an Atom be rebuilt from another
    Atom's attributes, as CopyingVisitor.visitAtom does. Previously that path
    raised 'Unknown operator'.

    Raises:
      Exception: if |op| is not a recognized operator.
    """

    def __init__(self, arg_index, op, value, jt, jf):
        super().__init__()
        if op in (BPF_JEQ, BPF_JGT, BPF_JGE, BPF_JSET):
            # Already normalized: the value and branch targets were adjusted
            # when the source Atom was built, so they must not be transformed
            # a second time.
            pass
        elif op == '==':
            op = BPF_JEQ
        elif op == '!=':
            op = BPF_JEQ
            jt, jf = jf, jt
        elif op == '>':
            op = BPF_JGT
        elif op == '<=':
            op = BPF_JGT
            jt, jf = jf, jt
        elif op == '>=':
            op = BPF_JGE
        elif op == '<':
            op = BPF_JGE
            jt, jf = jf, jt
        elif op == '&':
            op = BPF_JSET
        elif op == 'in':
            op = BPF_JSET
            # The mask is negated, so the comparison will be true when the
            # argument includes a flag that wasn't listed in the original
            # (non-negated) mask. This would be the failure case, so we switch
            # |jt| and |jf|.
            value = (~value) & ((1 << 64) - 1)
            jt, jf = jf, jt
        else:
            raise Exception('Unknown operator %s' % op)

        self.arg_index = arg_index
        self.op = op
        self.jt = jt
        self.jf = jf
        self.value = value

    def accept(self, visitor):
        """Let both successors accept |visitor|, then visit this block."""
        self.jt.accept(visitor)
        self.jf.accept(visitor)
        visitor.visit(self)


class AbstractVisitor(abc.ABC):
    """An abstract visitor over the block DAG.

    visit() routes each block to the visit<Type> method of the first matching
    type below; subclasses must implement every one of those methods.
    """

    def process(self, block):
        """Run this visitor over |block| and return |block| itself."""
        block.accept(self)
        return block

    def visit(self, block):
        """Dispatch |block| to the matching visit<Type> method.

        The BasicBlock subclasses are tested before BasicBlock itself, so
        e.g. an Allow block is not handled as a generic BasicBlock.

        Raises:
          Exception: if |block| is of none of the known types.
        """
        handlers = (
            (KillProcess, 'visitKillProcess'),
            (KillThread, 'visitKillThread'),
            (Trap, 'visitTrap'),
            (ReturnErrno, 'visitReturnErrno'),
            (Trace, 'visitTrace'),
            (Log, 'visitLog'),
            (Allow, 'visitAllow'),
            (BasicBlock, 'visitBasicBlock'),
            (ValidateArch, 'visitValidateArch'),
            (SyscallEntry, 'visitSyscallEntry'),
            (WideAtom, 'visitWideAtom'),
            (Atom, 'visitAtom'),
        )
        for block_type, method_name in handlers:
            if isinstance(block, block_type):
                getattr(self, method_name)(block)
                return
        raise Exception('Unknown block type: %r' % block)

    @abc.abstractmethod
    def visitKillProcess(self, block):
        """Handle a KillProcess block."""

    @abc.abstractmethod
    def visitKillThread(self, block):
        """Handle a KillThread block."""

    @abc.abstractmethod
    def visitTrap(self, block):
        """Handle a Trap block."""

    @abc.abstractmethod
    def visitReturnErrno(self, block):
        """Handle a ReturnErrno block."""

    @abc.abstractmethod
    def visitTrace(self, block):
        """Handle a Trace block."""

    @abc.abstractmethod
    def visitLog(self, block):
        """Handle a Log block."""

    @abc.abstractmethod
    def visitAllow(self, block):
        """Handle an Allow block."""

    @abc.abstractmethod
    def visitBasicBlock(self, block):
        """Handle a generic BasicBlock."""

    @abc.abstractmethod
    def visitValidateArch(self, block):
        """Handle a ValidateArch block."""

    @abc.abstractmethod
    def visitSyscallEntry(self, block):
        """Handle a SyscallEntry block."""

    @abc.abstractmethod
    def visitWideAtom(self, block):
        """Handle a WideAtom block."""

    @abc.abstractmethod
    def visitAtom(self, block):
        """Handle an Atom block."""


class CopyingVisitor(AbstractVisitor):
    """A visitor that copies Blocks.

    Copies are memoized by the id() of the original block. A block reachable
    through several paths is therefore copied exactly once, and the copy keeps
    the same sharing. Because accept() visits successors before the block
    itself, the copies of a block's successors already exist when the block
    is copied.
    """

    def __init__(self):
        # Maps id(original block) -> its copy.
        self._mapping = {}

    def process(self, block):
        """Return a copy of the DAG rooted at |block|."""
        self._mapping = {}
        block.accept(self)
        return self._mapping[id(block)]

    def _memoize_copy(self, block, make_copy):
        """Record make_copy() as the copy of |block| unless one exists."""
        if id(block) not in self._mapping:
            self._mapping[id(block)] = make_copy()

    def visitKillProcess(self, block):
        self._memoize_copy(block, KillProcess)

    def visitKillThread(self, block):
        self._memoize_copy(block, KillThread)

    def visitTrap(self, block):
        self._memoize_copy(block, Trap)

    def visitReturnErrno(self, block):
        self._memoize_copy(block, lambda: ReturnErrno(block.errno))

    def visitTrace(self, block):
        self._memoize_copy(block, Trace)

    def visitLog(self, block):
        self._memoize_copy(block, Log)

    def visitAllow(self, block):
        self._memoize_copy(block, Allow)

    def visitBasicBlock(self, block):
        self._memoize_copy(block, lambda: BasicBlock(block.instructions))

    def visitValidateArch(self, block):
        # ValidateArch's only constructor argument is |next_block|; it has no
        # |arch| attribute. The old ValidateArch(block.arch, ...) call always
        # raised.
        self._memoize_copy(
            block, lambda: ValidateArch(self._mapping[id(block.next_block)]))

    def visitSyscallEntry(self, block):
        self._memoize_copy(
            block, lambda: SyscallEntry(block.syscall_number,
                                        self._mapping[id(block.jt)],
                                        self._mapping[id(block.jf)],
                                        op=block.op))

    def visitWideAtom(self, block):
        self._memoize_copy(
            block, lambda: WideAtom(block.arg_offset, block.op, block.value,
                                    self._mapping[id(block.jt)],
                                    self._mapping[id(block.jf)]))

    def visitAtom(self, block):
        self._memoize_copy(
            block, lambda: Atom(block.arg_index, block.op, block.value,
                                self._mapping[id(block.jt)],
                                self._mapping[id(block.jf)]))


class LoweringVisitor(CopyingVisitor):
    """A visitor that lowers Atoms into WideAtoms.

    Every other block is copied as-is. An Atom's comparison is expressed
    with 32-bit loads. On 32-bit architectures only the low half of the
    argument is compared. On 64-bit architectures the comparison is split
    into tests of the high half (arg_offset(..., True)) and the low half of
    the argument.
    """

    def __init__(self, *, arch):
        super().__init__()
        self._bits = arch.bits

    def visitAtom(self, block):
        if id(block) in self._mapping:
            return

        lo_value = block.value & 0xFFFFFFFF
        hi_value = (block.value >> 32) & 0xFFFFFFFF
        on_true = self._mapping[id(block.jt)]
        on_false = self._mapping[id(block.jf)]
        hi_offset = arg_offset(block.arg_index, True)

        lo_block = WideAtom(arg_offset(block.arg_index, False), block.op,
                            lo_value, on_true, on_false)

        if self._bits == 32:
            lowered = lo_block
        elif block.op in (BPF_JGE, BPF_JGT):
            # hi_1,lo_1 <op> hi_2,lo_2
            #   <=> hi_1 > hi_2 || (hi_1 == hi_2 && lo_1 <op> lo_2)
            if hi_value == 0:
                # When hi_2 is 0, a failed |hi_1 > 0| already implies
                # |hi_1 == 0|, so the equality test can be skipped.
                after_hi_fail = lo_block
            else:
                after_hi_fail = WideAtom(hi_offset, BPF_JEQ, hi_value,
                                         lo_block, on_false)
            lowered = WideAtom(hi_offset, BPF_JGT, hi_value, on_true,
                               after_hi_fail)
        elif block.op == BPF_JSET:
            # hi_1,lo_1 & hi_2,lo_2  <=>  (hi_1 & hi_2) || (lo_1 & lo_2)
            if hi_value == 0:
                # |hi_1 & 0| is never true: only the low half matters.
                lowered = lo_block
            else:
                lowered = WideAtom(hi_offset, block.op, hi_value, on_true,
                                   lo_block)
        else:
            assert block.op == BPF_JEQ, block.op
            # hi_1,lo_1 == hi_2,lo_2  <=>  hi_1 == hi_2 && lo_1 == lo_2
            lowered = WideAtom(hi_offset, block.op, hi_value, lo_block,
                               on_false)

        self._mapping[id(block)] = lowered


class FlatteningVisitor:
    """A visitor that flattens a DAG of Block objects.

    Blocks must be visited successors-first, which is the order the blocks'
    accept() methods use. Each visited block's instructions are prepended to
    the program, so every jump to an already-emitted block is a non-negative
    forward offset.
    """

    def __init__(self, *, arch, kill_action):
        self._kill_action = kill_action
        self._instructions = []
        self._arch = arch
        # Maps id(block) -> -len(program) right after |block| was emitted;
        # adding the current program length yields the distance to |block|.
        self._offsets = {}

    @property
    def result(self):
        """The flattened program as a BasicBlock."""
        return BasicBlock(self._instructions)

    def _distance(self, block):
        """Return the instruction count from the current program start to
        the first instruction of the already-emitted |block|."""
        distance = self._offsets[id(block)] + len(self._instructions)
        assert distance >= 0
        return distance

    def _emit_load_arg(self, offset):
        """Return a 32-bit absolute load of the input word at |offset|."""
        return [SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, offset)]

    def _emit_jmp(self, op, value, jt_distance, jf_distance):
        """Return a conditional jump on |value| with opcode |op|.

        The jt/jf fields of an instruction are compared against 0x100 here,
        so distances that do not fit are reached through extra BPF_JA
        instructions.
        """
        if jt_distance < 0x100 and jf_distance < 0x100:
            return [
                SockFilter(BPF_JMP | op | BPF_K, jt_distance, jf_distance,
                           value),
            ]
        if jt_distance + 1 < 0x100:
            return [
                SockFilter(BPF_JMP | op | BPF_K, jt_distance + 1, 0, value),
                SockFilter(BPF_JMP | BPF_JA, 0, 0, jf_distance),
            ]
        if jf_distance + 1 < 0x100:
            return [
                SockFilter(BPF_JMP | op | BPF_K, 0, jf_distance + 1, value),
                SockFilter(BPF_JMP | BPF_JA, 0, 0, jt_distance),
            ]
        return [
            SockFilter(BPF_JMP | op | BPF_K, 0, 1, value),
            SockFilter(BPF_JMP | BPF_JA, 0, 0, jt_distance + 1),
            SockFilter(BPF_JMP | BPF_JA, 0, 0, jf_distance),
        ]

    def _emit_validate_arch(self, block):
        """Return the architecture check guarding |block.next_block|.

        Loads the architecture word (offset 4) and compares it with
        self._arch.arch_nr. On a mismatch, the kill action runs. On a match,
        the kill action is skipped, the syscall number (offset 0) is
        reloaded, and execution continues at |next_block|.
        """
        kill_instructions = self._kill_action.instructions
        to_next_block = self._distance(block.next_block)
        instructions = [
            SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, 4),
            # Jump over the whole kill action, landing on the reload below.
            # (The old |distance(next_block) + 1| was only right for a
            # one-instruction kill action immediately before |next_block|.)
            SockFilter(BPF_JMP | BPF_JEQ | BPF_K, len(kill_instructions), 0,
                       self._arch.arch_nr),
        ] + kill_instructions + [
            SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, 0),
        ]
        if to_next_block:
            # The reload falls through to the current start of the program;
            # jump ahead when |next_block| was emitted further away.
            instructions.append(
                SockFilter(BPF_JMP | BPF_JA, 0, 0, to_next_block))
        return instructions

    def visit(self, block):
        """Prepend the instructions for |block| unless already emitted.

        Raises:
          Exception: if |block| is of an unsupported type.
        """
        if id(block) in self._offsets:
            return

        if isinstance(block, BasicBlock):
            instructions = block.instructions
        elif isinstance(block, ValidateArch):
            instructions = self._emit_validate_arch(block)
        elif isinstance(block, SyscallEntry):
            instructions = self._emit_jmp(block.op, block.syscall_number,
                                          self._distance(block.jt),
                                          self._distance(block.jf))
        elif isinstance(block, WideAtom):
            instructions = (
                self._emit_load_arg(block.arg_offset) + self._emit_jmp(
                    block.op, block.value, self._distance(block.jt),
                    self._distance(block.jf)))
        else:
            raise Exception('Unknown block type: %r' % block)

        self._instructions = instructions + self._instructions
        self._offsets[id(block)] = -len(self._instructions)
        return