diff --git a/tools/luks2hashcat.py b/tools/luks2hashcat.py index 5eab81cb1..2cda461b5 100755 --- a/tools/luks2hashcat.py +++ b/tools/luks2hashcat.py @@ -1,12 +1,15 @@ #!/usr/bin/env python3 -import sys +# +# Author......: See docs/credits.txt +# License.....: MIT +# + from argparse import ArgumentParser from collections import namedtuple from dataclasses import dataclass from os import SEEK_SET from struct import Struct -from sys import stderr from typing import List try: @@ -82,7 +85,8 @@ class KeyVersion1: def __init__(self, active, iterations, salt, af): self.active = self.Active(active) - assert iterations >= 0, "key iterations cannot be less than zero" + if (self.active in [self.Active.ENABLED, self.Active.ENABLED_OLD]) and (iterations <= 0): + raise ValueError("key iterations cannot be less than zero") self.iterations = iterations self.salt = salt self.af = af @@ -150,45 +154,62 @@ class HeaderVersion1: keys: List[KeyVersion1] def __init__(self, magic, version, cipher, mode, hash, payload, key_size, digest, salt, iterations, uuid, keys): - assert magic == self.MAGIC, "Invalid magic bytes" + if magic != self.MAGIC: + raise ValueError("invalid magic bytes") self.magic = magic - assert version == self.VERSION, "Invalid version" + if version != self.VERSION: + raise ValueError("invalid version") self.version = version if isinstance(cipher, bytes): try: cipher = bytes_to_str(cipher) + self.cipher = self.Cipher(cipher) except UnicodeDecodeError as e: - raise ValueError("Cannot decode cipher") from e - self.cipher = self.Cipher(cipher) + raise ValueError("cannot decode cipher") from e + except ValueError as e: + raise ValueError("invalid cipher value") from e if isinstance(mode, bytes): try: mode = bytes_to_str(mode) + self.mode = self.Mode(mode) except UnicodeDecodeError as e: - raise ValueError("Cannot decode mode") from e - self.mode = self.Mode(mode) + raise ValueError("cannot decode mode") from e + except ValueError as e: + raise ValueError("invalid mode value") from e if isinstance(hash, bytes): try: hash = bytes_to_str(hash) + self.hash = self.Hash(hash) except UnicodeDecodeError as e: - raise ValueError("Cannot decode hash") from e - self.hash = self.Hash(hash) + raise ValueError("cannot decode hash") from e + except ValueError as e: + raise ValueError("invalid hash value") from e self.payload = payload - self.key_size = self.KeySize(key_size) + try: + self.key_size = self.KeySize(key_size) + except ValueError as e: + raise ValueError("invalid key size provided") from e self.digest = digest self.salt = salt - assert iterations > 0, "Iterations cannot be less or equal to zero" + try: + iterations = int(iterations) + except ValueError as e: + raise ValueError("iterations is not a number") from e + if iterations <= 0: + raise ValueError("iterations cannot be less or equal to zero") self.iterations = iterations if isinstance(uuid, bytes): try: uuid = bytes_to_str(uuid) except UnicodeDecodeError as e: - raise ValueError("Cannot decode UUID") from e + raise ValueError("cannot decode UUID") from e self.uuid = uuid if all(isinstance(key, tuple) for key in keys): keys = [KeyVersion1(*key) for key in keys] elif all(isinstance(key, dict) for key in keys): keys = [KeyVersion1(**key) for key in keys] - assert all(isinstance(key, KeyVersion1) for key in keys), "Not a key object provided" + if any(not isinstance(key, KeyVersion1) for key in keys): + raise ValueError("not a key object provided") self.keys = keys @@ -201,16 +222,13 @@ def extract_version1(file): # prepare structs key_struct = Struct(">LL32sLL") header_struct = Struct( - ">6sH32s32s32sLL20s32sL40s" - + str(key_struct.size * KEYS_COUNT) - + "s" - + str(PADDING_LENGTH) - + "x" + ">6sH32s32s32sLL20s32sL40s" + str(key_struct.size * KEYS_COUNT) + "s" + str(PADDING_LENGTH) + "x" ) # read header header = file.read(header_struct.size) - assert len(header) == header_struct.size, "File contains less data than needed" + if len(header) < header_struct.size: + raise ValueError("file contains less data than needed") # convert bytes into temporary header header = header_struct.unpack(header) @@ -224,7 +242,8 @@ def extract_version1(file): for key in tmp_keys: file.seek(key.material_offset * SECTOR_SIZE, SEEK_SET) af = file.read(header.key_bytes * key.stripes) - assert len(af) == (header.key_bytes * key.stripes), "File contains less data than needed" + if len(af) < (header.key_bytes * key.stripes): + raise ValueError("file contains less data than needed") key = KeyVersion1(key.active, key.iterations, key.salt, af) keys.append(key) @@ -232,7 +251,8 @@ def extract_version1(file): # read payload file.seek(header.payload_offset * SECTOR_SIZE, SEEK_SET) payload = file.read(PAYLOAD_SIZE) - assert len(payload) == PAYLOAD_SIZE, "File contains less data than needed" + if len(payload) < PAYLOAD_SIZE: + raise ValueError("file contains less data than needed") # convert into header header = HeaderVersion1( @@ -275,54 +295,52 @@ def extract_version1(file): break else: # all keys are disabled - raise ValueError("All keys are disabled") + raise ValueError("all keys are disabled") # main -def main(args): +if __name__ == "__main__": # prepare parser and parse args parser = ArgumentParser(description="luks2hashcat extraction tool") parser.add_argument("path", type=str, help="path to LUKS container") - args = parser.parse_args(args) + args = parser.parse_args() # prepare struct header_struct = Struct(">6sH") - with open(args.path, "rb") as file: - # read pre header - header = file.read(header_struct.size) - assert len(header) == header_struct.size, "File contains less data than needed" - - # convert bytes into temporary pre header - header = header_struct.unpack(header) - header = TmpHeaderPre(*header) - - # check magic bytes - magic_bytes = { - HeaderVersion1.MAGIC, - } - assert header.magic in magic_bytes, "Improper magic bytes" - - # back to start of the file - file.seek(0, SEEK_SET) - - # extract with proper function - try: - mapping = { - HeaderVersion1.VERSION: extract_version1, + try: + with open(args.path, "rb") as file: + # read pre header + header = file.read(header_struct.size) + if len(header) < header_struct.size: + parser.error("file contains less data than needed") + + # convert bytes into temporary pre header + header = header_struct.unpack(header) + header = TmpHeaderPre(*header) + + # check magic bytes + magic_bytes = { + HeaderVersion1.MAGIC, } - extract = mapping[header.version] - extract(file) - except KeyError as e: - raise ValueError("Unsupported version") from e + if header.magic not in magic_bytes: + parser.error("improper magic bytes") + # back to start of the file + file.seek(0, SEEK_SET) -if __name__ == "__main__": - try: - main(sys.argv[1:]) + # extract with proper function + try: + mapping = { + HeaderVersion1.VERSION: extract_version1, + } + extract = mapping[header.version] + extract(file) + except KeyError as e: + raise ValueError("unsupported version") from e except IOError as e: - print('Error:', e.strerror, file=stderr) - except (AssertionError, ValueError) as e: - print('Error:', e, file=stderr) + parser.error(e.strerror.lower()) + except ValueError as e: + parser.error(str(e))