Tool luks2hashcat.py code cleanup

pull/3369/head
Konrad Goławski 2 years ago
parent 2bd1861a83
commit e30be9f17c

@ -1,12 +1,15 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import sys #
# Author......: See docs/credits.txt
# License.....: MIT
#
from argparse import ArgumentParser from argparse import ArgumentParser
from collections import namedtuple from collections import namedtuple
from dataclasses import dataclass from dataclasses import dataclass
from os import SEEK_SET from os import SEEK_SET
from struct import Struct from struct import Struct
from sys import stderr
from typing import List from typing import List
try: try:
@ -82,7 +85,8 @@ class KeyVersion1:
def __init__(self, active, iterations, salt, af): def __init__(self, active, iterations, salt, af):
self.active = self.Active(active) self.active = self.Active(active)
assert iterations >= 0, "key iterations cannot be less than zero" if (self.active in [self.Active.ENABLED, self.Active.ENABLED_OLD]) and (iterations <= 0):
raise ValueError("key iterations cannot be less than zero")
self.iterations = iterations self.iterations = iterations
self.salt = salt self.salt = salt
self.af = af self.af = af
@ -150,45 +154,62 @@ class HeaderVersion1:
keys: List[KeyVersion1] keys: List[KeyVersion1]
def __init__(self, magic, version, cipher, mode, hash, payload, key_size, digest, salt, iterations, uuid, keys): def __init__(self, magic, version, cipher, mode, hash, payload, key_size, digest, salt, iterations, uuid, keys):
assert magic == self.MAGIC, "Invalid magic bytes" if magic != self.MAGIC:
raise ValueError("invalid magic bytes")
self.magic = magic self.magic = magic
assert version == self.VERSION, "Invalid version" if version != self.VERSION:
raise ValueError("invalid version")
self.version = version self.version = version
if isinstance(cipher, bytes): if isinstance(cipher, bytes):
try: try:
cipher = bytes_to_str(cipher) cipher = bytes_to_str(cipher)
self.cipher = self.Cipher(cipher)
except UnicodeDecodeError as e: except UnicodeDecodeError as e:
raise ValueError("Cannot decode cipher") from e raise ValueError("cannot decode cipher") from e
self.cipher = self.Cipher(cipher) except ValueError as e:
raise ValueError("invalid cipher value") from e
if isinstance(mode, bytes): if isinstance(mode, bytes):
try: try:
mode = bytes_to_str(mode) mode = bytes_to_str(mode)
self.mode = self.Mode(mode)
except UnicodeDecodeError as e: except UnicodeDecodeError as e:
raise ValueError("Cannot decode mode") from e raise ValueError("cannot decode mode") from e
self.mode = self.Mode(mode) except ValueError as e:
raise ValueError("invalid mode value") from e
if isinstance(hash, bytes): if isinstance(hash, bytes):
try: try:
hash = bytes_to_str(hash) hash = bytes_to_str(hash)
self.hash = self.Hash(hash)
except UnicodeDecodeError as e: except UnicodeDecodeError as e:
raise ValueError("Cannot decode hash") from e raise ValueError("cannot decode hash") from e
self.hash = self.Hash(hash) except ValueError as e:
raise ValueError("invalid hash value") from e
self.payload = payload self.payload = payload
self.key_size = self.KeySize(key_size) try:
self.key_size = self.KeySize(key_size)
except ValueError as e:
raise ValueError("invalid key size provided") from e
self.digest = digest self.digest = digest
self.salt = salt self.salt = salt
assert iterations > 0, "Iterations cannot be less or equal to zero" try:
iterations = int(iterations)
except ValueError as e:
raise ValueError("iterations is not a number") from e
if iterations <= 0:
raise ValueError("iterations cannot be less or equal to zero")
self.iterations = iterations self.iterations = iterations
if isinstance(uuid, bytes): if isinstance(uuid, bytes):
try: try:
uuid = bytes_to_str(uuid) uuid = bytes_to_str(uuid)
except UnicodeDecodeError as e: except UnicodeDecodeError as e:
raise ValueError("Cannot decode UUID") from e raise ValueError("cannot decode UUID") from e
self.uuid = uuid self.uuid = uuid
if all(isinstance(key, tuple) for key in keys): if all(isinstance(key, tuple) for key in keys):
keys = [KeyVersion1(*key) for key in keys] keys = [KeyVersion1(*key) for key in keys]
elif all(isinstance(key, dict) for key in keys): elif all(isinstance(key, dict) for key in keys):
keys = [KeyVersion1(**key) for key in keys] keys = [KeyVersion1(**key) for key in keys]
assert all(isinstance(key, KeyVersion1) for key in keys), "Not a key object provided" if any(not isinstance(key, KeyVersion1) for key in keys):
raise ValueError("not a key object provided")
self.keys = keys self.keys = keys
@ -201,16 +222,13 @@ def extract_version1(file):
# prepare structs # prepare structs
key_struct = Struct(">LL32sLL") key_struct = Struct(">LL32sLL")
header_struct = Struct( header_struct = Struct(
">6sH32s32s32sLL20s32sL40s" ">6sH32s32s32sLL20s32sL40s" + str(key_struct.size * KEYS_COUNT) + "s" + str(PADDING_LENGTH) + "x"
+ str(key_struct.size * KEYS_COUNT)
+ "s"
+ str(PADDING_LENGTH)
+ "x"
) )
# read header # read header
header = file.read(header_struct.size) header = file.read(header_struct.size)
assert len(header) == header_struct.size, "File contains less data than needed" if len(header) < header_struct.size:
raise ValueError("file contains less data than needed")
# convert bytes into temporary header # convert bytes into temporary header
header = header_struct.unpack(header) header = header_struct.unpack(header)
@ -224,7 +242,8 @@ def extract_version1(file):
for key in tmp_keys: for key in tmp_keys:
file.seek(key.material_offset * SECTOR_SIZE, SEEK_SET) file.seek(key.material_offset * SECTOR_SIZE, SEEK_SET)
af = file.read(header.key_bytes * key.stripes) af = file.read(header.key_bytes * key.stripes)
assert len(af) == (header.key_bytes * key.stripes), "File contains less data than needed" if len(af) < (header.key_bytes * key.stripes):
raise ValueError("file contains less data than needed")
key = KeyVersion1(key.active, key.iterations, key.salt, af) key = KeyVersion1(key.active, key.iterations, key.salt, af)
keys.append(key) keys.append(key)
@ -232,7 +251,8 @@ def extract_version1(file):
# read payload # read payload
file.seek(header.payload_offset * SECTOR_SIZE, SEEK_SET) file.seek(header.payload_offset * SECTOR_SIZE, SEEK_SET)
payload = file.read(PAYLOAD_SIZE) payload = file.read(PAYLOAD_SIZE)
assert len(payload) == PAYLOAD_SIZE, "File contains less data than needed" if len(payload) < PAYLOAD_SIZE:
raise ValueError("file contains less data than needed")
# convert into header # convert into header
header = HeaderVersion1( header = HeaderVersion1(
@ -275,54 +295,52 @@ def extract_version1(file):
break break
else: else:
# all keys are disabled # all keys are disabled
raise ValueError("All keys are disabled") raise ValueError("all keys are disabled")
# main # main
def main(args): if __name__ == "__main__":
# prepare parser and parse args # prepare parser and parse args
parser = ArgumentParser(description="luks2hashcat extraction tool") parser = ArgumentParser(description="luks2hashcat extraction tool")
parser.add_argument("path", type=str, help="path to LUKS container") parser.add_argument("path", type=str, help="path to LUKS container")
args = parser.parse_args(args) args = parser.parse_args()
# prepare struct # prepare struct
header_struct = Struct(">6sH") header_struct = Struct(">6sH")
with open(args.path, "rb") as file: try:
# read pre header with open(args.path, "rb") as file:
header = file.read(header_struct.size) # read pre header
assert len(header) == header_struct.size, "File contains less data than needed" header = file.read(header_struct.size)
if len(header) < header_struct.size:
# convert bytes into temporary pre header parser.error("file contains less data than needed")
header = header_struct.unpack(header)
header = TmpHeaderPre(*header) # convert bytes into temporary pre header
header = header_struct.unpack(header)
# check magic bytes header = TmpHeaderPre(*header)
magic_bytes = {
HeaderVersion1.MAGIC, # check magic bytes
} magic_bytes = {
assert header.magic in magic_bytes, "Improper magic bytes" HeaderVersion1.MAGIC,
# back to start of the file
file.seek(0, SEEK_SET)
# extract with proper function
try:
mapping = {
HeaderVersion1.VERSION: extract_version1,
} }
extract = mapping[header.version] if header.magic not in magic_bytes:
extract(file) parser.error("improper magic bytes")
except KeyError as e:
raise ValueError("Unsupported version") from e
# back to start of the file
file.seek(0, SEEK_SET)
if __name__ == "__main__": # extract with proper function
try: try:
main(sys.argv[1:]) mapping = {
HeaderVersion1.VERSION: extract_version1,
}
extract = mapping[header.version]
extract(file)
except KeyError as e:
raise ValueError("unsupported version") from e
except IOError as e: except IOError as e:
print('Error:', e.strerror, file=stderr) parser.error(e.strerror.lower())
except (AssertionError, ValueError) as e: except ValueError as e:
print('Error:', e, file=stderr) parser.error(str(e))

Loading…
Cancel
Save