blob: bd7007e0923c8ccdcb1d7d8b0e777a4b3aeca7d0 [file] [log] [blame]
Luis Hector Chavez1ac9eca2018-12-04 21:28:52 -08001#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3#
4# Copyright (C) 2018 The Android Open Source Project
5#
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17"""Tools to interact with BPF programs."""
18
19import abc
20import collections
21import struct
22
# Upper bound on syscall arguments, per syscall(2): most architectures pass at
# most 6 arguments, but ARM can pass 7.
MAX_SYSCALL_ARGUMENTS = 7

# ---------------------------------------------------------------------------
# Values mirrored from <linux/bpf_common.h>.
# ---------------------------------------------------------------------------

# Instruction classes.
BPF_LD = 0x00
BPF_LDX = 0x01
BPF_ST = 0x02
BPF_STX = 0x03
BPF_ALU = 0x04
BPF_JMP = 0x05
BPF_RET = 0x06
BPF_MISC = 0x07

# LD/LDX operand size.
BPF_W = 0x00
BPF_H = 0x08
BPF_B = 0x10

# LD/LDX addressing mode.
BPF_IMM = 0x00
BPF_ABS = 0x20
BPF_IND = 0x40
BPF_MEM = 0x60
BPF_LEN = 0x80
BPF_MSH = 0xa0

# JMP operations.
BPF_JA = 0x00
BPF_JEQ = 0x10
BPF_JGT = 0x20
BPF_JGE = 0x30
BPF_JSET = 0x40

# Operand source.
BPF_K = 0x00
BPF_X = 0x08

BPF_MAXINSNS = 4096

# ---------------------------------------------------------------------------
# Values mirrored from <linux/seccomp.h>.
# ---------------------------------------------------------------------------

SECCOMP_RET_KILL_PROCESS = 0x80000000
SECCOMP_RET_KILL_THREAD = 0x00000000
SECCOMP_RET_TRAP = 0x00030000
SECCOMP_RET_ERRNO = 0x00050000
SECCOMP_RET_TRACE = 0x7ff00000
SECCOMP_RET_LOG = 0x7ffc0000
SECCOMP_RET_ALLOW = 0x7fff0000

# Masks that split a filter return value into its action (upper 16 bits) and
# its action-specific data (lower 16 bits, e.g. the errno for ERRNO).
SECCOMP_RET_ACTION_FULL = 0xffff0000
SECCOMP_RET_DATA = 0x0000ffff
77
78
def arg_offset(arg_index, hi=False):
    """Return the byte offset of a syscall argument word for BPF_LD|BPF_W|BPF_ABS.

    Arguments are 8 bytes wide and start after a 16-byte header (two 4-byte
    fields followed by one 8-byte field).  A 32-bit load can only fetch half
    of an argument, so |hi| selects which half: when true, half an argument
    width (4 bytes) is added to the offset of argument |arg_index|.
    """
    header_size = 4 + 4 + 8
    argument_size = 8
    half_size = argument_size // 2
    argument_start = header_size + argument_size * arg_index
    return argument_start + half_size * hi
84
85
def simulate(instructions, arch, syscall_number, *args):
    """Interpret a BPF program against a synthetic syscall record.

    The input memory is packed as the syscall number, the architecture, a
    zero 64-bit field, and MAX_SYSCALL_ARGUMENTS 64-bit arguments (missing
    arguments are zero-filled, surplus ones are dropped).

    Only 32-bit absolute loads, unconditional jumps, the JEQ/JGT/JGE/JSET
    conditional jumps against a constant, and BPF_RET are understood.

    Returns:
        (cost, action) where |cost| is the number of instructions executed
        and |action| is the name of the seccomp action, or
        (cost, 'ERRNO', errno) for an ERRNO return.

    Raises:
        Exception: on an unknown instruction, an unknown return value, or
            when execution runs past the end of |instructions|.
    """
    zero_fill = (0, ) * (MAX_SYSCALL_ARGUMENTS - len(args))
    args = (args + zero_fill)[:MAX_SYSCALL_ARGUMENTS]
    record_format = 'IIQ' + 'Q' * MAX_SYSCALL_ARGUMENTS
    input_memory = struct.pack(record_format, syscall_number, arch, 0, *args)

    load_word = BPF_LD | BPF_W | BPF_ABS
    jump_always = BPF_JMP | BPF_JA | BPF_K
    # Conditional jumps: opcode -> predicate(accumulator, constant).
    predicates = {
        BPF_JMP | BPF_JEQ | BPF_K: lambda acc, k: acc == k,
        BPF_JMP | BPF_JGT | BPF_K: lambda acc, k: acc > k,
        BPF_JMP | BPF_JGE | BPF_K: lambda acc, k: acc >= k,
        BPF_JMP | BPF_JSET | BPF_K: lambda acc, k: acc & k != 0,
    }
    # Return values that must match exactly (no data bits allowed).
    exact_actions = {
        SECCOMP_RET_KILL_PROCESS: 'KILL_PROCESS',
        SECCOMP_RET_KILL_THREAD: 'KILL_THREAD',
        SECCOMP_RET_TRAP: 'TRAP',
        SECCOMP_RET_TRACE: 'TRACE',
        SECCOMP_RET_LOG: 'LOG',
        SECCOMP_RET_ALLOW: 'ALLOW',
    }

    register = 0
    program_counter = 0
    cost = 0
    while program_counter < len(instructions):
        ins = instructions[program_counter]
        # Jump distances are relative to the instruction after this one.
        program_counter += 1
        cost += 1

        if ins.code == load_word:
            word = input_memory[ins.k:ins.k + 4]
            (register, ) = struct.unpack('I', word)
            continue
        if ins.code == jump_always:
            program_counter += ins.k
            continue
        if ins.code in predicates:
            if predicates[ins.code](register, ins.k):
                program_counter += ins.jt
            else:
                program_counter += ins.jf
            continue
        if ins.code != BPF_RET:
            raise Exception('unknown instruction %r' % (ins, ))

        if ins.k in exact_actions:
            return (cost, exact_actions[ins.k])
        # ERRNO is the only action whose low 16 data bits are significant.
        if (ins.k & SECCOMP_RET_ACTION_FULL) == SECCOMP_RET_ERRNO:
            return (cost, 'ERRNO', ins.k & SECCOMP_RET_DATA)
        raise Exception('unknown return %#x' % ins.k)
    raise Exception('out-of-bounds')
143
144
class SockFilter(collections.namedtuple('SockFilter', 'code jt jf k')):
    """One BPF instruction, mirroring struct sock_filter.

    Fields, in order: |code| (opcode), |jt| and |jf| (jump distances taken
    when a conditional jump is true / false) and |k| (the constant operand).
    """

    __slots__ = ()

    def encode(self):
        """Pack the instruction into bytes using the 'HBBI' struct layout."""
        code, jt, jf, k = self
        return struct.pack('HBBI', code, jt, jf, k)
154
155
class AbstractBlock(abc.ABC):
    """Base class for every node that can be walked by a visitor."""

    def __init__(self):
        super().__init__()

    @abc.abstractmethod
    def accept(self, visitor):
        """Present this block (and any blocks it refers to) to |visitor|."""
165
166
class BasicBlock(AbstractBlock):
    """An AbstractBlock that already holds its compiled instructions."""

    def __init__(self, instructions):
        super().__init__()
        self._instructions = instructions

    def accept(self, visitor):
        # A basic block has no successors, so only the block itself is visited.
        visitor.visit(self)

    @property
    def instructions(self):
        """The sequence of instructions this block was built with."""
        return self._instructions

    @property
    def opcodes(self):
        """The encoded form of every instruction, concatenated into bytes."""
        encoded = [instruction.encode() for instruction in self._instructions]
        return b''.join(encoded)

    def __eq__(self, o):
        # Blocks are equal when both are BasicBlocks with equal instructions.
        return (isinstance(o, BasicBlock)
                and self._instructions == o._instructions)
189
190
class KillProcess(BasicBlock):
    """A one-instruction block that returns SECCOMP_RET_KILL_PROCESS."""

    def __init__(self):
        ret = SockFilter(
            code=BPF_RET, jt=0x00, jf=0x00, k=SECCOMP_RET_KILL_PROCESS)
        super().__init__([ret])
197
198
class KillThread(BasicBlock):
    """A one-instruction block that returns SECCOMP_RET_KILL_THREAD."""

    def __init__(self):
        ret = SockFilter(
            code=BPF_RET, jt=0x00, jf=0x00, k=SECCOMP_RET_KILL_THREAD)
        super().__init__([ret])
205
206
class Trap(BasicBlock):
    """A one-instruction block that returns SECCOMP_RET_TRAP."""

    def __init__(self):
        ret = SockFilter(code=BPF_RET, jt=0x00, jf=0x00, k=SECCOMP_RET_TRAP)
        super().__init__([ret])
212
213
class Trace(BasicBlock):
    """A one-instruction block that returns SECCOMP_RET_TRACE."""

    def __init__(self):
        ret = SockFilter(code=BPF_RET, jt=0x00, jf=0x00, k=SECCOMP_RET_TRACE)
        super().__init__([ret])
219
220
class Log(BasicBlock):
    """A one-instruction block that returns SECCOMP_RET_LOG."""

    def __init__(self):
        ret = SockFilter(code=BPF_RET, jt=0x00, jf=0x00, k=SECCOMP_RET_LOG)
        super().__init__([ret])
226
227
class ReturnErrno(BasicBlock):
    """A one-instruction block that returns SECCOMP_RET_ERRNO with an errno.

    Only the bits of |errno| covered by SECCOMP_RET_DATA are encoded in the
    instruction; the value passed in is kept unmodified in |self.errno|.
    """

    def __init__(self, errno):
        return_value = SECCOMP_RET_ERRNO | (errno & SECCOMP_RET_DATA)
        ret = SockFilter(code=BPF_RET, jt=0x00, jf=0x00, k=return_value)
        super().__init__([ret])
        self.errno = errno
237
238
class Allow(BasicBlock):
    """A one-instruction block that returns SECCOMP_RET_ALLOW."""

    def __init__(self):
        ret = SockFilter(code=BPF_RET, jt=0x00, jf=0x00, k=SECCOMP_RET_ALLOW)
        super().__init__([ret])
Luis Hector Chavez05392b82018-10-28 21:40:10 -0700244
245
class ValidateArch(AbstractBlock):
    """A block that checks the architecture before running |next_block|."""

    def __init__(self, next_block):
        super().__init__()
        self.next_block = next_block

    def accept(self, visitor):
        # Post-order walk: the successor is presented before this block.
        successor = self.next_block
        successor.accept(visitor)
        visitor.visit(self)
256
257
class SyscallEntry(AbstractBlock):
    """An abstract block that represents a syscall comparison in a DAG.

    The syscall number is compared against |syscall_number| using the BPF
    jump operation |op| (BPF_JEQ unless overridden); |jt| is the block taken
    when the comparison holds and |jf| the block taken otherwise.
    """

    def __init__(self, syscall_number, jt, jf, *, op=BPF_JEQ):
        super().__init__()
        self.op = op
        self.syscall_number = syscall_number
        self.jt = jt
        self.jf = jf

    def __lt__(self, o):
        # Defined because we want to compare tuples that contain SyscallEntries.
        return False

    def __gt__(self, o):
        # Defined because we want to compare tuples that contain SyscallEntries.
        return False

    def accept(self, visitor):
        # Post-order walk: both branch targets are presented before this block.
        self.jt.accept(visitor)
        self.jf.accept(visitor)
        visitor.visit(self)
288
289
class WideAtom(AbstractBlock):
    """An AbstractBlock that represents a 32-bit wide atom.

    It compares the 32-bit word found at byte offset |arg_offset| against
    |value| using the BPF jump operation |op|, continuing at |jt| when the
    comparison holds and at |jf| otherwise.
    """

    def __init__(self, arg_offset, op, value, jt, jf):
        super().__init__()
        self.jt = jt
        self.jf = jf
        self.op = op
        self.value = value
        self.arg_offset = arg_offset

    def accept(self, visitor):
        # Post-order walk: both branch targets are presented before this block.
        self.jt.accept(visitor)
        self.jf.accept(visitor)
        visitor.visit(self)
305
306
class Atom(AbstractBlock):
    """An AbstractBlock that represents an atom (a simple comparison).

    The textual operator given to the constructor is lowered to one of the
    BPF jump operations.  Operators with no direct BPF equivalent are
    expressed as the opposite test with |jt| and |jf| exchanged, so the
    stored |op|, |jt| and |jf| always describe the lowered form.
    """

    def __init__(self, arg_index, op, value, jt, jf):
        super().__init__()
        # (operator, BPF jump operation, whether |jt| and |jf| are exchanged)
        lowerings = (
            ('==', BPF_JEQ, False),
            ('!=', BPF_JEQ, True),
            ('>', BPF_JGT, False),
            ('<=', BPF_JGT, True),
            ('>=', BPF_JGE, False),
            ('<', BPF_JGE, True),
            ('&', BPF_JSET, False),
            ('in', BPF_JSET, True),
        )
        for symbol, jump_op, exchange in lowerings:
            if op == symbol:
                break
        else:
            raise Exception('Unknown operator %s' % op)

        if symbol == 'in':
            # The mask is negated (within 64 bits), so the JSET test is true
            # when the argument carries a flag that was not listed in the
            # original mask.  That is the failing case, which is why 'in'
            # also exchanges |jt| and |jf|.
            value = (~value) & ((1 << 64) - 1)
        if exchange:
            jt, jf = jf, jt

        self.arg_index = arg_index
        self.op = jump_op
        self.jt = jt
        self.jf = jf
        self.value = value

    def accept(self, visitor):
        # Post-order walk: both branch targets are presented before this block.
        self.jt.accept(visitor)
        self.jf.accept(visitor)
        visitor.visit(self)
350
351
class AbstractVisitor(abc.ABC):
    """Base visitor that routes each block to a type-specific method."""

    def process(self, block):
        """Walk |block| with this visitor and return |block| itself."""
        block.accept(self)
        return block

    def visit(self, block):
        """Call the visitXxx method that matches the type of |block|.

        Raises:
            Exception: if |block| is not one of the known block types.
        """
        # Order matters: the return actions derive from BasicBlock, so they
        # have to be tested before the generic BasicBlock entry.
        dispatch = (
            (KillProcess, 'visitKillProcess'),
            (KillThread, 'visitKillThread'),
            (Trap, 'visitTrap'),
            (ReturnErrno, 'visitReturnErrno'),
            (Trace, 'visitTrace'),
            (Log, 'visitLog'),
            (Allow, 'visitAllow'),
            (BasicBlock, 'visitBasicBlock'),
            (ValidateArch, 'visitValidateArch'),
            (SyscallEntry, 'visitSyscallEntry'),
            (WideAtom, 'visitWideAtom'),
            (Atom, 'visitAtom'),
        )
        for block_type, method_name in dispatch:
            if isinstance(block, block_type):
                getattr(self, method_name)(block)
                return
        raise Exception('Unknown block type: %r' % block)

    @abc.abstractmethod
    def visitKillProcess(self, block):
        """Handle a KillProcess block."""

    @abc.abstractmethod
    def visitKillThread(self, block):
        """Handle a KillThread block."""

    @abc.abstractmethod
    def visitTrap(self, block):
        """Handle a Trap block."""

    @abc.abstractmethod
    def visitReturnErrno(self, block):
        """Handle a ReturnErrno block."""

    @abc.abstractmethod
    def visitTrace(self, block):
        """Handle a Trace block."""

    @abc.abstractmethod
    def visitLog(self, block):
        """Handle a Log block."""

    @abc.abstractmethod
    def visitAllow(self, block):
        """Handle an Allow block."""

    @abc.abstractmethod
    def visitBasicBlock(self, block):
        """Handle a BasicBlock that is not one of the return actions."""

    @abc.abstractmethod
    def visitValidateArch(self, block):
        """Handle a ValidateArch block."""

    @abc.abstractmethod
    def visitSyscallEntry(self, block):
        """Handle a SyscallEntry block."""

    @abc.abstractmethod
    def visitWideAtom(self, block):
        """Handle a WideAtom block."""

    @abc.abstractmethod
    def visitAtom(self, block):
        """Handle an Atom block."""
434
435
class CopyingVisitor(AbstractVisitor):
    """A visitor that copies Blocks.

    Copies are recorded in |self._mapping|, keyed by the id() of the original
    block, so a block that is reachable through several paths is copied once
    and the single copy is shared.  Blocks that refer to other blocks look
    their copied successors up in the mapping, which relies on the blocks'
    accept() methods presenting successors before the block itself.
    """

    def __init__(self):
        self._mapping = {}

    def process(self, block):
        """Copy everything reachable from |block|; return the copy of |block|."""
        self._mapping = {}
        block.accept(self)
        return self._mapping[id(block)]

    def visitKillProcess(self, block):
        if id(block) in self._mapping:
            return
        self._mapping[id(block)] = KillProcess()

    def visitKillThread(self, block):
        if id(block) in self._mapping:
            return
        self._mapping[id(block)] = KillThread()

    def visitTrap(self, block):
        if id(block) in self._mapping:
            return
        self._mapping[id(block)] = Trap()

    def visitReturnErrno(self, block):
        if id(block) in self._mapping:
            return
        self._mapping[id(block)] = ReturnErrno(block.errno)

    def visitTrace(self, block):
        if id(block) in self._mapping:
            return
        self._mapping[id(block)] = Trace()

    def visitLog(self, block):
        if id(block) in self._mapping:
            return
        self._mapping[id(block)] = Log()

    def visitAllow(self, block):
        if id(block) in self._mapping:
            return
        self._mapping[id(block)] = Allow()

    def visitBasicBlock(self, block):
        if id(block) in self._mapping:
            return
        # The instruction sequence itself is shared with the original block.
        self._mapping[id(block)] = BasicBlock(block.instructions)

    def visitValidateArch(self, block):
        if id(block) in self._mapping:
            return
        # ValidateArch is constructed from its successor only; it neither
        # accepts nor stores an architecture.
        self._mapping[id(block)] = ValidateArch(
            self._mapping[id(block.next_block)])

    def visitSyscallEntry(self, block):
        if id(block) in self._mapping:
            return
        self._mapping[id(block)] = SyscallEntry(
            block.syscall_number,
            self._mapping[id(block.jt)],
            self._mapping[id(block.jf)],
            op=block.op)

    def visitWideAtom(self, block):
        if id(block) in self._mapping:
            return
        self._mapping[id(block)] = WideAtom(
            block.arg_offset, block.op, block.value,
            self._mapping[id(block.jt)], self._mapping[id(block.jf)])

    def visitAtom(self, block):
        if id(block) in self._mapping:
            return
        # An Atom stores the lowered BPF jump operation (with |jt|, |jf| and
        # |value| already adjusted), but its constructor only understands
        # textual operators.  Translate back to the operator that lowers to
        # the same jump operation without exchanging the branches or touching
        # the value, so the copy ends up identical to the original.
        operators = {
            BPF_JEQ: '==',
            BPF_JGT: '>',
            BPF_JGE: '>=',
            BPF_JSET: '&',
        }
        # An unexpected |op| is passed through so that Atom reports it.
        operator = operators.get(block.op, block.op)
        self._mapping[id(block)] = Atom(block.arg_index, operator,
                                        block.value,
                                        self._mapping[id(block.jt)],
                                        self._mapping[id(block.jf)])
515
516
class LoweringVisitor(CopyingVisitor):
    """A copying visitor that replaces every Atom with WideAtoms.

    An Atom compares a whole argument, whereas a WideAtom compares a single
    32-bit word.  When |arch.bits| is 32 only the low word is compared;
    otherwise the comparison is split into tests on the high and low words.
    """

    def __init__(self, *, arch):
        super().__init__()
        self._bits = arch.bits

    def visitAtom(self, block):
        key = id(block)
        if key in self._mapping:
            return

        lo = block.value & 0xFFFFFFFF
        hi = (block.value >> 32) & 0xFFFFFFFF
        lo_offset = arg_offset(block.arg_index, False)
        on_true = self._mapping[id(block.jt)]
        on_false = self._mapping[id(block.jf)]

        lo_block = WideAtom(lo_offset, block.op, lo, on_true, on_false)

        if self._bits == 32:
            # Only the low word exists on 32-bit architectures.
            lowered = lo_block
        elif block.op in (BPF_JGE, BPF_JGT):
            # hi_1,lo_1 <op> hi_2,lo_2
            #
            # hi_1 > hi_2 || hi_1 == hi_2 && lo_1 <op> lo_2
            hi_offset = arg_offset(block.arg_index, True)
            if hi == 0:
                # Special case: there is no need to test |hi_1 == hi_2|,
                # because it is true iff the JGT test fails.
                fallthrough = lo_block
            else:
                fallthrough = WideAtom(hi_offset, BPF_JEQ, hi, lo_block,
                                       on_false)
            lowered = WideAtom(hi_offset, BPF_JGT, hi, on_true, fallthrough)
        elif block.op == BPF_JSET:
            # hi_1,lo_1 & hi_2,lo_2
            #
            # hi_1 & hi_2 || lo_1 & lo_2
            if hi == 0:
                # Special case: |hi_1 & hi_2| can never be true, so go
                # straight to the |lo_1 & lo_2| test.
                lowered = lo_block
            else:
                lowered = WideAtom(
                    arg_offset(block.arg_index, True), block.op, hi, on_true,
                    lo_block)
        else:
            assert block.op == BPF_JEQ, block.op

            # hi_1,lo_1 == hi_2,lo_2
            #
            # hi_1 == hi_2 && lo_1 == lo_2
            lowered = WideAtom(
                arg_offset(block.arg_index, True), block.op, hi, lo_block,
                on_false)

        self._mapping[key] = lowered
579
580
class FlatteningVisitor:
    """A visitor that turns a DAG of blocks into one instruction list.

    The program is assembled back to front: the instructions of each newly
    visited block are placed in front of everything emitted so far.
    |self._offsets| remembers, per block id, the negated length of the
    instruction list right after that block was added, which is enough to
    compute forward jump distances to it later on.
    """

    def __init__(self, *, arch, kill_action):
        self._arch = arch
        self._kill_action = kill_action
        self._instructions = []
        self._offsets = {}

    @property
    def result(self):
        """A BasicBlock wrapping the instructions emitted so far."""
        return BasicBlock(self._instructions)

    def _distance(self, block):
        """Return how many emitted instructions precede the start of |block|."""
        distance = len(self._instructions) + self._offsets[id(block)]
        assert distance >= 0
        return distance

    def _emit_load_arg(self, offset):
        """Return the instructions that load the 32-bit word at |offset|."""
        load = SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, offset)
        return [load]

    def _emit_jmp(self, op, value, jt_distance, jf_distance):
        """Return a conditional jump, using BPF_JA for far targets.

        Distances that reach 0x100 are not placed in the conditional jump
        itself; they are reached through an extra unconditional jump, and the
        remaining near distance is adjusted for the inserted instruction.
        """
        conditional = BPF_JMP | op | BPF_K
        unconditional = BPF_JMP | BPF_JA
        limit = 0x100

        if jt_distance < limit and jf_distance < limit:
            # Both targets are near: a single instruction suffices.
            return [
                SockFilter(conditional, jt_distance, jf_distance, value),
            ]
        if jt_distance + 1 < limit:
            # Only the true target is near: fall through to a far jump.
            return [
                SockFilter(conditional, jt_distance + 1, 0, value),
                SockFilter(unconditional, 0, 0, jf_distance),
            ]
        if jf_distance + 1 < limit:
            # Only the false target is near: fall through to a far jump.
            return [
                SockFilter(conditional, 0, jf_distance + 1, value),
                SockFilter(unconditional, 0, 0, jt_distance),
            ]
        # Both targets are far: one unconditional jump per outcome.
        return [
            SockFilter(conditional, 0, 1, value),
            SockFilter(unconditional, 0, 0, jt_distance + 1),
            SockFilter(unconditional, 0, 0, jf_distance),
        ]

    def visit(self, block):
        """Emit the instructions for |block| unless it was already emitted.

        Raises:
            Exception: if |block| is not a BasicBlock, ValidateArch,
                SyscallEntry or WideAtom.
        """
        key = id(block)
        if key in self._offsets:
            return

        if isinstance(block, BasicBlock):
            emitted = block.instructions
        elif isinstance(block, ValidateArch):
            # Load the word at offset 4 and compare it with the expected
            # architecture; a mismatch falls through into the kill action.
            # The trailing load fetches the word at offset 0.
            # NOTE(review): the match target is distance(next_block) + 1,
            # which lands on the trailing load only when |next_block|
            # immediately follows and the kill action is a single
            # instruction -- confirm that this always holds.
            match_distance = self._distance(block.next_block) + 1
            emitted = [
                SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, 4),
                SockFilter(BPF_JMP | BPF_JEQ | BPF_K, match_distance, 0,
                           self._arch.arch_nr),
            ] + self._kill_action.instructions + [
                SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, 0),
            ]
        elif isinstance(block, SyscallEntry):
            true_distance = self._distance(block.jt)
            false_distance = self._distance(block.jf)
            emitted = self._emit_jmp(block.op, block.syscall_number,
                                     true_distance, false_distance)
        elif isinstance(block, WideAtom):
            load = self._emit_load_arg(block.arg_offset)
            jump = self._emit_jmp(block.op, block.value,
                                  self._distance(block.jt),
                                  self._distance(block.jf))
            emitted = load + jump
        else:
            raise Exception('Unknown block type: %r' % block)

        # Prepend, then remember where this block starts (counted from the
        # end of the program, hence the negative sign).
        self._instructions = emitted + self._instructions
        self._offsets[key] = -len(self._instructions)
Luis Hector Chaveza54812b2018-11-01 20:02:22 -0700654
655
class ArgFilterForwardingVisitor:
    """A visitor that hands every arg filter block over to another visitor."""

    def __init__(self, visitor):
        self.visitor = visitor

    def visit(self, block):
        # Arg filters are BasicBlocks, but so are the ALLOW, KILL_PROCESS,
        # TRAP, etc. actions, and those must not be forwarded just yet.
        return_actions = (KillProcess, KillThread, Trap, ReturnErrno, Trace,
                          Log, Allow)
        is_arg_filter = (isinstance(block, BasicBlock)
                         and not isinstance(block, return_actions))
        if is_arg_filter:
            block.accept(self.visitor)