From 9caea6d413794135caf99591c5d97d6a02716b14 Mon Sep 17 00:00:00 2001 From: matejcik Date: Fri, 2 Nov 2018 16:21:28 +0100 Subject: [PATCH] tx_api: rework API, separate caching functionality to test support --- trezorlib/coins.py | 17 +- trezorlib/tests/support/tx_cache.py | 45 ++++++ trezorlib/tx_api.py | 231 ++++++++++++++-------------- 3 files changed, 159 insertions(+), 134 deletions(-) create mode 100644 trezorlib/tests/support/tx_cache.py diff --git a/trezorlib/coins.py b/trezorlib/coins.py index 72ff2729e..4a920e8d5 100644 --- a/trezorlib/coins.py +++ b/trezorlib/coins.py @@ -17,7 +17,7 @@ import json import os.path -from .tx_api import TxApiInsight +from .tx_api import TxApi COINS_JSON = os.path.join(os.path.dirname(__file__), "coins.json") @@ -34,19 +34,6 @@ def _load_coins_json(): return json.load(coins_json) -def _insight_for_coin(coin): - url = next(iter(coin["blockbook"] + coin["bitcore"]), None) - if not url: - return None - zcash = coin["coin_name"].lower().startswith("zcash") - bip115 = coin["bip115"] - decred = coin["decred"] - network = "insight_{}".format(coin["coin_name"].lower().replace(" ", "_")) - return TxApiInsight( - network=network, url=url, zcash=zcash, bip115=bip115, decred=decred - ) - - # exported variables __all__ = ["by_name", "slip44", "tx_api"] @@ -58,7 +45,7 @@ except Exception as e: slip44 = {name: coin["slip44"] for name, coin in by_name.items()} tx_api = { - name: _insight_for_coin(coin) + name: TxApi(coin) for name, coin in by_name.items() if coin["blockbook"] or coin["bitcore"] } diff --git a/trezorlib/tests/support/tx_cache.py b/trezorlib/tests/support/tx_cache.py new file mode 100644 index 000000000..291b739cf --- /dev/null +++ b/trezorlib/tests/support/tx_cache.py @@ -0,0 +1,45 @@ +import decimal +import json +import os.path + +from trezorlib import coins +from trezorlib.tx_api import json_to_tx + +CACHE_PATH = os.path.join(os.path.dirname(__file__), "..", "txcache") + + +def tx_cache(coin_name, allow_fetch=True): + coin_data = coins.by_name[coin_name] + fetch = coins.tx_api[coin_name].get_tx_data if allow_fetch else None + return TxCache(CACHE_PATH, coin_data, fetch) + + +class TxCache: + def __init__(self, path, coin_data, fetch=None): + self.path = path + self.coin_data = coin_data + self.fetch = fetch + + coin_slug = coin_data["coin_name"].lower().replace(" ", "_") + prefix = "insight_" + coin_slug + "_tx_" + self.file_pattern = os.path.join(self.path, prefix + "{}.json") + + def get_tx(self, txhash): + cache_file = self.file_pattern.format(txhash) + + try: + with open(cache_file) as f: + data = json.load(f, parse_float=decimal.Decimal) + return json_to_tx(self.coin_data, data) + except Exception as e: + if self.fetch is None: + raise Exception("Unhandled cache miss") from e + + # cache miss, try to use backend + data = self.fetch(txhash) + with open(cache_file, "w") as f: + json.dump(data, f) + return json_to_tx(self.coin_data, data) + + def __getitem__(self, key): + return self.get_tx(key.hex()) diff --git a/trezorlib/tx_api.py b/trezorlib/tx_api.py index 87b558998..68d4f62c6 100644 --- a/trezorlib/tx_api.py +++ b/trezorlib/tx_api.py @@ -16,140 +16,133 @@ import json from decimal import Decimal +import random +from typing import Mapping, Any import requests -from . import messages as proto +from . import messages cache_dir = None -class TxApi(object): - def __init__(self, network, url=None): - self.network = network - self.url = url - - def get_url(self, *args): - return "/".join(map(str, [self.url, "api"] + list(args))) - - def fetch_json(self, resource, resourceid): - global cache_dir - if cache_dir: - cache_file = "%s/%s_%s_%s.json" % ( - cache_dir, - self.network, - resource, - resourceid, - ) - try: # looking into cache first - j = json.load(open(cache_file), parse_float=str) - return j - except Exception: - pass - - if not self.url: - raise RuntimeError("No URL specified and tx not in cache") - - try: - url = self.get_url(resource, resourceid) - r = requests.get(url, headers={"User-agent": "Mozilla/5.0"}) - j = r.json(parse_float=str) - except Exception: - raise RuntimeError("URL error: %s" % url) - if cache_dir and cache_file: - try: # saving into cache - json.dump(j, open(cache_file, "w")) - except Exception: - pass - return j - - def get_tx(self, txhash): - raise NotImplementedError - - -class TxApiInsight(TxApi): - def __init__(self, network, url=None, zcash=None, bip115=False, decred=False): - super().__init__(network, url) - self.zcash = zcash - self.bip115 = bip115 - self.decred = decred - if url: +def is_zcash(coin): + return coin["coin_name"].lower().startswith("zcash") + + +def is_capricoin(coin): + return coin["coin_name"].lower().startswith("capricoin") + + +def _json_to_input(coin, vin): + i = messages.TxInputType() + if "coinbase" in vin: + i.prev_hash = b"\0" * 32 + i.prev_index = 0xFFFFFFFF # signed int -1 + i.script_sig = bytes.fromhex(vin["coinbase"]) + i.sequence = vin["sequence"] + + else: + i.prev_hash = bytes.fromhex(vin["txid"]) + i.prev_index = vin["vout"] + i.script_sig = bytes.fromhex(vin["scriptSig"]["hex"]) + i.sequence = vin["sequence"] + + if coin["decred"]: + i.decred_tree = vin["tree"] + # TODO: support amountIn, blockHeight, blockIndex + + return i + + +def _json_to_bin_output(coin, vout): + o = messages.TxOutputBinType() + o.amount = int(Decimal(vout["value"]) * 100000000) + o.script_pubkey = bytes.fromhex(vout["scriptPubKey"]["hex"]) + if coin["bip115"] and o.script_pubkey[-1] == 0xB4: + # Verify if coin implements replay protection bip115 and script includes + # checkblockatheight opcode. 0xb4 - is op_code (OP_CHECKBLOCKATHEIGHT) + # <32-byte block hash> <3-byte block height> + tail = o.script_pubkey[-38:] + o.block_hash = tail[1:33] # <32-byte block hash> + o.block_height = int.from_bytes(tail[34:37], "little") # <3-byte block height> + if coin["decred"]: + o.decred_script_version = vout["version"] + + return o + + +def json_to_tx(coin, data): + t = messages.TransactionType() + t.version = data["version"] + t.lock_time = data.get("locktime") + + if is_capricoin(coin): + t.timestamp = data["time"] + + if coin["decred"]: + t.expiry = data["expiry"] + + if is_zcash(coin): + t.overwintered = data.get("fOverwintered", False) + t.expiry = data.get("nExpiryHeight", None) + t.version_group_id = data.get("nVersionGroupId", None) + + t.inputs = [_json_to_input(coin, vin) for vin in data["vin"]] + t.bin_outputs = [_json_to_bin_output(coin, vout) for vout in data["vout"]] + + # zcash extra data + if is_zcash(coin) and t.version >= 2: + joinsplit_cnt = len(data["vjoinsplit"]) + if joinsplit_cnt == 0: + t.extra_data = b"\x00" + elif joinsplit_cnt >= 253: + # we assume cnt < 253, so we can treat varIntLen(cnt) as 1 + raise ValueError("Too many joinsplits") + elif "hex" not in data: + raise ValueError("Raw TX data required for Zcash joinsplit transaction") + else: + rawtx = bytes.fromhex(data["hex"]) + extra_data_len = 1 + joinsplit_cnt * 1802 + 32 + 64 + t.extra_data = rawtx[-extra_data_len:] + + return t + + +class TxApi: + def __init__(self, coin_data): + self.coin_data = coin_data + if coin_data["blockbook"]: + self.url = random.choice(coin_data["blockbook"]) + self.pushtx_url = self.url + "/sendtx" + elif coin_data["bitcore"]: + self.url = random.choice(coin_data["bitcore"]) self.pushtx_url = self.url + "/tx/send" + else: + raise ValueError("No API URL in coin data") + + def fetch_json(self, *path, **params): + url = self.url + "/api/" + "/".join(map(str, path)) + return requests.get(url, params=params).json(parse_float=Decimal) def get_block_hash(self, block_number): j = self.fetch_json("block-index", block_number) return bytes.fromhex(j["blockHash"]) def current_height(self): - r = requests.get(self.get_url("status?q=getBlockCount")) - j = r.json(parse_float=str) - block_height = j["info"]["blocks"] - return block_height + j = self.fetch_json("status", q="getBlockCount") + return j["info"]["blocks"] - def get_tx(self, txhash): + def __getitem__(self, txhash): + return self.get_tx(txhash.hex()) + def get_tx_data(self, txhash): data = self.fetch_json("tx", txhash) + if is_zcash(self.coin_data) and data.get("vjoinsplit") and "hex" not in data: + j = self.fetch_json("rawtx", txhash) + data["hex"] = j["rawtx"] + return data - t = proto.TransactionType() - t.version = data["version"] - t.lock_time = data.get("locktime") - - if self.network == "insight_capricoin": - t.timestamp = data["time"] - - if self.decred: - t.expiry = data["expiry"] - - if self.zcash: - t.overwintered = data.get("fOverwintered", False) - t.expiry = data.get("nExpiryHeight", None) - t.version_group_id = data.get("nVersionGroupId", None) - - for vin in data["vin"]: - i = t._add_inputs() - if "coinbase" in vin.keys(): - i.prev_hash = b"\0" * 32 - i.prev_index = 0xFFFFFFFF # signed int -1 - i.script_sig = bytes.fromhex(vin["coinbase"]) - i.sequence = vin["sequence"] - - else: - i.prev_hash = bytes.fromhex(vin["txid"]) - i.prev_index = vin["vout"] - i.script_sig = bytes.fromhex(vin["scriptSig"]["hex"]) - i.sequence = vin["sequence"] - - if self.decred: - i.decred_tree = vin["tree"] - # TODO: support amountIn, blockHeight, blockIndex - - for vout in data["vout"]: - o = t._add_bin_outputs() - o.amount = int(Decimal(vout["value"]) * 100000000) - o.script_pubkey = bytes.fromhex(vout["scriptPubKey"]["hex"]) - if self.bip115 and o.script_pubkey[-1] == 0xB4: - # Verify if coin implements replay protection bip115 and script includes checkblockatheight opcode. 0xb4 - is op_code (OP_CHECKBLOCKATHEIGHT) - # <32-byte block hash> <3-byte block height> - tail = o.script_pubkey[-38:] - o.block_hash = tail[1:33] # <32-byte block hash> - o.block_height = int.from_bytes( - tail[34:37], byteorder="little" - ) # <3-byte block height> - if self.decred: - o.decred_script_version = vout["version"] - - if self.zcash: - if t.version >= 2: - joinsplit_cnt = len(data["vjoinsplit"]) - if joinsplit_cnt == 0: - t.extra_data = b"\x00" - else: - if joinsplit_cnt >= 253: - # we assume cnt < 253, so we can treat varIntLen(cnt) as 1 - raise ValueError("Too many joinsplits") - extra_data_len = 1 + joinsplit_cnt * 1802 + 32 + 64 - raw = self.fetch_json("rawtx", txhash) - raw = bytes.fromhex(raw["rawtx"]) - t.extra_data = raw[-extra_data_len:] - - return t + def get_tx(self, txhash): + data = self.get_tx_data(txhash) + return json_to_tx(self.coin_data, data)