Tool luks2hashcat.py code cleanup

pull/3369/head
Konrad Goławski 2 years ago
parent 2bd1861a83
commit e30be9f17c

@ -1,12 +1,15 @@
#!/usr/bin/env python3
import sys
#
# Author......: See docs/credits.txt
# License.....: MIT
#
from argparse import ArgumentParser
from collections import namedtuple
from dataclasses import dataclass
from os import SEEK_SET
from struct import Struct
from sys import stderr
from typing import List
try:
@ -82,7 +85,8 @@ class KeyVersion1:
def __init__(self, active, iterations, salt, af):
self.active = self.Active(active)
assert iterations >= 0, "key iterations cannot be less than zero"
if (self.active in [self.Active.ENABLED, self.Active.ENABLED_OLD]) and (iterations <= 0):
raise ValueError("key iterations cannot be less than zero")
self.iterations = iterations
self.salt = salt
self.af = af
@ -150,45 +154,62 @@ class HeaderVersion1:
keys: List[KeyVersion1]
def __init__(self, magic, version, cipher, mode, hash, payload, key_size, digest, salt, iterations, uuid, keys):
assert magic == self.MAGIC, "Invalid magic bytes"
if magic != self.MAGIC:
raise ValueError("invalid magic bytes")
self.magic = magic
assert version == self.VERSION, "Invalid version"
if version != self.VERSION:
raise ValueError("invalid version")
self.version = version
if isinstance(cipher, bytes):
try:
cipher = bytes_to_str(cipher)
self.cipher = self.Cipher(cipher)
except UnicodeDecodeError as e:
raise ValueError("Cannot decode cipher") from e
self.cipher = self.Cipher(cipher)
raise ValueError("cannot decode cipher") from e
except ValueError as e:
raise ValueError("invalid cipher value") from e
if isinstance(mode, bytes):
try:
mode = bytes_to_str(mode)
self.mode = self.Mode(mode)
except UnicodeDecodeError as e:
raise ValueError("Cannot decode mode") from e
self.mode = self.Mode(mode)
raise ValueError("cannot decode mode") from e
except ValueError as e:
raise ValueError("invalid mode value") from e
if isinstance(hash, bytes):
try:
hash = bytes_to_str(hash)
self.hash = self.Hash(hash)
except UnicodeDecodeError as e:
raise ValueError("Cannot decode hash") from e
self.hash = self.Hash(hash)
raise ValueError("cannot decode hash") from e
except ValueError as e:
raise ValueError("invalid hash value") from e
self.payload = payload
self.key_size = self.KeySize(key_size)
try:
self.key_size = self.KeySize(key_size)
except ValueError as e:
raise ValueError("invalid key size provided") from e
self.digest = digest
self.salt = salt
assert iterations > 0, "Iterations cannot be less or equal to zero"
try:
iterations = int(iterations)
except ValueError as e:
raise ValueError("iterations is not a number") from e
if iterations <= 0:
raise ValueError("iterations cannot be less or equal to zero")
self.iterations = iterations
if isinstance(uuid, bytes):
try:
uuid = bytes_to_str(uuid)
except UnicodeDecodeError as e:
raise ValueError("Cannot decode UUID") from e
raise ValueError("cannot decode UUID") from e
self.uuid = uuid
if all(isinstance(key, tuple) for key in keys):
keys = [KeyVersion1(*key) for key in keys]
elif all(isinstance(key, dict) for key in keys):
keys = [KeyVersion1(**key) for key in keys]
assert all(isinstance(key, KeyVersion1) for key in keys), "Not a key object provided"
if any(not isinstance(key, KeyVersion1) for key in keys):
raise ValueError("not a key object provided")
self.keys = keys
@ -201,16 +222,13 @@ def extract_version1(file):
# prepare structs
key_struct = Struct(">LL32sLL")
header_struct = Struct(
">6sH32s32s32sLL20s32sL40s"
+ str(key_struct.size * KEYS_COUNT)
+ "s"
+ str(PADDING_LENGTH)
+ "x"
">6sH32s32s32sLL20s32sL40s" + str(key_struct.size * KEYS_COUNT) + "s" + str(PADDING_LENGTH) + "x"
)
# read header
header = file.read(header_struct.size)
assert len(header) == header_struct.size, "File contains less data than needed"
if len(header) < header_struct.size:
raise ValueError("file contains less data than needed")
# convert bytes into temporary header
header = header_struct.unpack(header)
@ -224,7 +242,8 @@ def extract_version1(file):
for key in tmp_keys:
file.seek(key.material_offset * SECTOR_SIZE, SEEK_SET)
af = file.read(header.key_bytes * key.stripes)
assert len(af) == (header.key_bytes * key.stripes), "File contains less data than needed"
if len(af) < (header.key_bytes * key.stripes):
raise ValueError("file contains less data than needed")
key = KeyVersion1(key.active, key.iterations, key.salt, af)
keys.append(key)
@ -232,7 +251,8 @@ def extract_version1(file):
# read payload
file.seek(header.payload_offset * SECTOR_SIZE, SEEK_SET)
payload = file.read(PAYLOAD_SIZE)
assert len(payload) == PAYLOAD_SIZE, "File contains less data than needed"
if len(payload) < PAYLOAD_SIZE:
raise ValueError("file contains less data than needed")
# convert into header
header = HeaderVersion1(
@ -275,54 +295,52 @@ def extract_version1(file):
break
else:
# all keys are disabled
raise ValueError("All keys are disabled")
raise ValueError("all keys are disabled")
# main
def main(args):
if __name__ == "__main__":
# prepare parser and parse args
parser = ArgumentParser(description="luks2hashcat extraction tool")
parser.add_argument("path", type=str, help="path to LUKS container")
args = parser.parse_args(args)
args = parser.parse_args()
# prepare struct
header_struct = Struct(">6sH")
with open(args.path, "rb") as file:
# read pre header
header = file.read(header_struct.size)
assert len(header) == header_struct.size, "File contains less data than needed"
# convert bytes into temporary pre header
header = header_struct.unpack(header)
header = TmpHeaderPre(*header)
# check magic bytes
magic_bytes = {
HeaderVersion1.MAGIC,
}
assert header.magic in magic_bytes, "Improper magic bytes"
# back to start of the file
file.seek(0, SEEK_SET)
# extract with proper function
try:
mapping = {
HeaderVersion1.VERSION: extract_version1,
try:
with open(args.path, "rb") as file:
# read pre header
header = file.read(header_struct.size)
if len(header) < header_struct.size:
parser.error("file contains less data than needed")
# convert bytes into temporary pre header
header = header_struct.unpack(header)
header = TmpHeaderPre(*header)
# check magic bytes
magic_bytes = {
HeaderVersion1.MAGIC,
}
extract = mapping[header.version]
extract(file)
except KeyError as e:
raise ValueError("Unsupported version") from e
if header.magic not in magic_bytes:
parser.error("improper magic bytes")
# back to start of the file
file.seek(0, SEEK_SET)
if __name__ == "__main__":
try:
main(sys.argv[1:])
# extract with proper function
try:
mapping = {
HeaderVersion1.VERSION: extract_version1,
}
extract = mapping[header.version]
extract(file)
except KeyError as e:
raise ValueError("unsupported version") from e
except IOError as e:
print('Error:', e.strerror, file=stderr)
except (AssertionError, ValueError) as e:
print('Error:', e, file=stderr)
parser.error(e.strerror.lower())
except ValueError as e:
parser.error(str(e))

Loading…
Cancel
Save