RWCTF 2022 Crypto SMT

云水遥

稀疏默克尔树。

题面

SparseMerkleTree.sol
pragma solidity >=0.8.0 <0.9.0;

// Capacity of the evaluation stack used by SMT.calcRoot.
uint256 constant SMT_STACK_SIZE = 32;
// Number of key bits (keys are uint160 addresses); upper bound for height operands.
uint256 constant DEPTH = 160;

/// @title SMT: sparse Merkle tree helpers used by TreasureHunter.
/// @notice Roots are (re)computed by running a caller-supplied proof program
///         (a bytes32[] of opcodes and operands) on a small stack machine;
///         see calcRoot.
library SMT {
    /// @dev A tree leaf. A leaf whose value is 0 hashes to 0 (see calcLeaf).
    struct Leaf {
        address key;
        uint8 value;
    }

    /// @dev Cast to uint8 in verifyByMode and used as the expected leaf value:
    ///      BlackList = 0, WhiteList = 1.
    enum Mode {
        BlackList,
        WhiteList
    }

    /// @dev Cast to uint8 in updateSingleTarget: Insert = 0 proves value 0 and
    ///      writes 1; Delete = 1 proves value 1 and writes 0.
    enum Method {
        Insert,
        Delete
    }

    /// @notice Root of a tree with no non-zero leaves.
    function init() internal pure returns (bytes32) {
        return 0;
    }
    
    /// @notice Hash of a single leaf.
    /// @return 0 when `a.value == 0`, otherwise keccak256(abi.encode(key, value)).
    function calcLeaf(Leaf memory a) internal pure returns (bytes32) {
        if (a.value == 0) {
            return 0;
        } else {
            return keccak256(abi.encode(a.key, a.value));
        }
    }

    /// @notice Combine a left and a right child hash.
    /// @dev A zero operand is passed through: the other operand is returned
    ///      unchanged instead of being hashed.
    ///      NOTE(review): together with calcRoot's 0x50 opcode accepting
    ///      arbitrary sibling hashes, this lets a value-0 leaf reproduce a root
    ///      that was built with a non-zero leaf.
    function merge(bytes32 l, bytes32 r) internal pure returns (bytes32) {
        if (l == 0) {
            return r;
        } else if (r == 0) {
            return l;
        } else {
            return keccak256(abi.encode(l, r));
        }
    }

    /// @notice Check that every target's leaf value equals uint8(_mode) under `_expectedRoot`.
    function verifyByMode(
        bytes32[] memory _proofs,
        address[] memory _targets,
        bytes32 _expectedRoot,
        Mode _mode
    ) internal pure returns (bool) {
        Leaf[] memory leaves = new Leaf[](_targets.length);
        for (uint256 i = 0; i < _targets.length; i++) {
            leaves[i] = Leaf({key: _targets[i], value: uint8(_mode)});
        }
        return verify(_proofs, leaves, _expectedRoot);
    }

    /// @notice True iff running `_proofs` over `_leaves` yields `_expectedRoot`.
    /// @dev `_expectedRoot` is also handed to calcRoot, which rejects any 0x50
    ///      sibling operand equal to it.
    function verify(
        bytes32[] memory _proofs,
        Leaf[] memory _leaves,
        bytes32 _expectedRoot
    ) internal pure returns (bool) {
        return (calcRoot(_proofs, _leaves, _expectedRoot) == _expectedRoot);
    }

    /// @notice Flip one target's leaf: prove value uint8(_method) under
    ///         `_prevRoot`, then return the root with value uint8(_method) ^ 1.
    function updateSingleTarget(
        bytes32[] memory _proofs,
        address _target,
        bytes32 _prevRoot,
        Method _method
    ) internal pure returns (bytes32) {
        Leaf[] memory nextLeaves = new Leaf[](1);
        Leaf[] memory prevLeaves = new Leaf[](1);
        nextLeaves[0] = Leaf({key: _target, value: uint8(_method) ^ 1});
        prevLeaves[0] = Leaf({key: _target, value: uint8(_method)});
        return update(_proofs, nextLeaves, prevLeaves, _prevRoot);
    }

    /// @notice Prove `_prevLeaves` against `_prevRoot` with `_proofs`, then rerun
    ///         the same proof program over `_nextLeaves` and return that root.
    function update(
        bytes32[] memory _proofs,
        Leaf[] memory _nextLeaves,
        Leaf[] memory _prevLeaves,
        bytes32 _prevRoot
    ) internal pure returns (bytes32) {
        require(verify(_proofs, _prevLeaves, _prevRoot), "update proof not valid");
        return calcRoot(_proofs, _nextLeaves, _prevRoot);
    }

    /// @notice True iff leaf keys are strictly increasing and the first key is
    ///         non-zero (the running value starts at 0). Reverts on an empty array.
    function checkGroupSorted(Leaf[] memory _leaves) internal pure returns (bool) {
        require(_leaves.length >= 1);
        uint160 temp = 0;
        for (uint256 i = 0; i < _leaves.length; i++) {
            if (temp >= uint160(_leaves[i].key)) {
                return false;
            }
            temp = uint160(_leaves[i].key);
        }
        return true;
    }

    /// @notice Bit `height` of `key` (0 or 1). Reverts if height > DEPTH.
    function getBit(uint160 key, uint256 height) internal pure returns (uint256) {
        require(height <= DEPTH);
        return (key >> height) & 1;
    }

    /// @notice `key` with bits 0..height cleared.
    /// @dev Effective bound is height + 1 <= DEPTH, enforced by copyBit.
    function parentPath(uint160 key, uint256 height) internal pure returns (uint160) {
        require(height <= DEPTH);
        return copyBit(key, height + 1);
    }

    /// @notice `key` with its lowest `height` bits cleared. Reverts if height > DEPTH.
    function copyBit(uint160 key, uint256 height) internal pure returns (uint160) {
        require(height <= DEPTH);
        return ((key >> height) << height);
    }

    /// @notice Run a proof program over `_leaves` and return the single resulting hash.
    /// @dev Opcodes (each one bytes32 word, followed by its operands):
    ///      0x4c        push the next leaf (key, calcLeaf(leaf));
    ///      0x50 h p    merge the top value with sibling hash `p` at height `h`
    ///                  (`p` goes left if bit h of the top key is 1, else right),
    ///                  then replace the top key with parentPath(key, h);
    ///      0x48 h      merge the top two entries, which must have the same
    ///                  parentPath at `h` and differ in bit `h`.
    ///      Leaves must pass checkGroupSorted, all must be consumed, and exactly
    ///      one stack entry must remain.
    ///      NOTE(review): `_root` is only used to reject a 0x50 operand equal to
    ///      it; apart from that and the DEPTH bounds in getBit/parentPath, 0x50
    ///      places no constraint on its sibling hash or height.
    function calcRoot(
        bytes32[] memory _proofs,
        Leaf[] memory _leaves,
        bytes32 _root
    ) internal pure returns (bytes32) {
        require(checkGroupSorted(_leaves));
        uint160[] memory stackKeys = new uint160[](SMT_STACK_SIZE);
        bytes32[] memory stackValues = new bytes32[](SMT_STACK_SIZE);
        uint256 proofIndex = 0;
        uint256 leaveIndex = 0;
        uint256 stackTop = 0;

        while (proofIndex < _proofs.length) {
            if (uint256(_proofs[proofIndex]) == 0x4c) {
                // 0x4c: push the next leaf's key and hash.
                proofIndex++;
                require(stackTop < SMT_STACK_SIZE);
                require(leaveIndex < _leaves.length);
                stackKeys[stackTop] = uint160(_leaves[leaveIndex].key);
                stackValues[stackTop] = calcLeaf(_leaves[leaveIndex]);
                stackTop++;
                leaveIndex++;
            } else if (uint256(_proofs[proofIndex]) == 0x50) {
                // 0x50 height proof: merge the top entry with a supplied sibling hash.
                proofIndex++;
                require(stackTop != 0);
                require(proofIndex + 2 <= _proofs.length);

                uint256 height = uint256(_proofs[proofIndex++]);
                bytes32 currentProof = _proofs[proofIndex++];
                require(currentProof != _root);
                if (getBit(stackKeys[stackTop - 1], height) == 1) {
                    // Stack value is the right child; the supplied hash goes left.
                    stackValues[stackTop - 1] = merge(currentProof, stackValues[stackTop - 1]);
                } else {
                    stackValues[stackTop - 1] = merge(stackValues[stackTop - 1], currentProof);
                }
                stackKeys[stackTop - 1] = parentPath(stackKeys[stackTop - 1], height);
            } else if (uint256(_proofs[proofIndex]) == 0x48) {
                // 0x48 height: merge the two topmost entries as siblings at `height`.
                proofIndex++;
                require(stackTop >= 2);
                require(proofIndex < _proofs.length);
                uint256 height = uint256(_proofs[proofIndex++]);
                uint256 aSet = getBit(stackKeys[stackTop - 2], height);
                uint256 bSet = getBit(stackKeys[stackTop - 1], height);
                stackKeys[stackTop - 2] = parentPath(stackKeys[stackTop - 2], height);
                stackKeys[stackTop - 1] = parentPath(stackKeys[stackTop - 1], height);
                require(stackKeys[stackTop - 2] == stackKeys[stackTop - 1] && aSet != bSet);

                if (aSet == 1) {
                    stackValues[stackTop - 2] = merge(
                        stackValues[stackTop - 1],
                        stackValues[stackTop - 2]
                    );
                } else {
                    stackValues[stackTop - 2] = merge(
                        stackValues[stackTop - 2],
                        stackValues[stackTop - 1]
                    );
                }
                stackTop -= 1;
            } else {
                // Unknown opcode.
                revert();
            }
        }
        require(leaveIndex == _leaves.length);
        require(stackTop == 1);
        return stackValues[0];
    }
}

/// @notice Scratch helpers for checking hash values against the off-chain port.
/// @dev Both functions read no state, so they are declared `pure` (previously
///      non-view, which compiled with a "can be restricted to pure" warning).
contract Test {
    /// @notice Leaf hash of (address(0x2), 1), computed the same way as SMT.calcLeaf.
    function test1() public pure returns (bytes32) {
        address key = address(0x2);
        uint8 value = 0x1;
        return keccak256(abi.encode(key, value));
    }

    /// @notice SMT.merge of two non-zero words (0x40 on the left, 0x42 on the right).
    function test2() public pure returns (bytes32) {
        bytes32 a = bytes32(uint256(0x40));
        bytes32 b = bytes32(uint256(0x42));
        return SMT.merge(a, b);
    }
}
Tr3asur3Hun7er.sol
pragma solidity >=0.8.0 <0.9.0;

import {SMT} from "./SparseMerkleTree.sol";

/// @title TreasureHunter
/// @notice Membership in a sparse Merkle tree gates two rewards: the treasure
///         chest (proved while the caller's leaf is 1, in WhiteList mode) and
///         the key (proved while the caller's leaf is 0, in BlackList mode).
///         A caller holding both can open the chest, which marks the
///         challenge solved.
contract TreasureHunter {
    /// @notice Current root of the membership tree.
    bytes32 public root;
    /// @notice Leaf value the next findKey / pickupTreasureChest proof must show;
    ///         switched after each successful proof.
    SMT.Mode public smtMode = SMT.Mode.WhiteList;
    bool public solved = false;

    mapping(address => bool) public haveKey;
    mapping(address => bool) public haveTreasureChest;

    event FindKey(address indexed _from);
    event PickupTreasureChest(address indexed _from);
    event OpenTreasureChest(address indexed _from);

    /// @dev No visibility specifier: constructor visibility is ignored (and
    ///      warned about) since Solidity 0.7, and this file requires >=0.8.
    constructor() {
        root = SMT.init();
        _init();
    }

    /// @dev Insert the eight hunter addresses (value 0 -> 1) into the empty
    ///      tree using a fixed proof program: push four leaves, merge them at
    ///      heights 0x95, 0x99, 0x9e; push four more, merge at 0x9b, 0x9c,
    ///      0x9e; then merge the two subtrees at 0x9f.
    function _init() internal {
        address[] memory hunters = new address[](8);
        hunters[0] = 0x0bc529c00C6401aEF6D220BE8C6Ea1667F6Ad93e;
        hunters[1] = 0x68b3465833fb72A70ecDF485E0e4C7bD8665Fc45;
        hunters[2] = 0x6B175474E89094C44Da98b954EedeAC495271d0F;
        hunters[3] = 0x6B3595068778DD592e39A122f4f5a5cF09C90fE2;
        hunters[4] = 0xAb5801a7D398351b8bE11C439e05C5B3259aeC9B;
        hunters[5] = 0xc00e94Cb662C3520282E6f5717214004A7f26888;
        hunters[6] = 0xD533a949740bb3306d119CC777fa900bA034cd52;
        hunters[7] = 0xdAC17F958D2ee523a2206206994597C13D831ec7;

        SMT.Leaf[] memory nextLeaves = new SMT.Leaf[](8);
        SMT.Leaf[] memory prevLeaves = new SMT.Leaf[](8);
        for (uint8 i = 0; i < hunters.length; i++) {
            nextLeaves[i] = SMT.Leaf({key: hunters[i], value: 1});
            prevLeaves[i] = SMT.Leaf({key: hunters[i], value: 0});
        }

        bytes32[] memory proof = new bytes32[](22);
        proof[0] = 0x000000000000000000000000000000000000000000000000000000000000004c;
        proof[1] = 0x000000000000000000000000000000000000000000000000000000000000004c;
        proof[2] = 0x000000000000000000000000000000000000000000000000000000000000004c;
        proof[3] = 0x000000000000000000000000000000000000000000000000000000000000004c;

        proof[4] = 0x0000000000000000000000000000000000000000000000000000000000000048;
        proof[5] = 0x0000000000000000000000000000000000000000000000000000000000000095;
        proof[6] = 0x0000000000000000000000000000000000000000000000000000000000000048;
        proof[7] = 0x0000000000000000000000000000000000000000000000000000000000000099;
        proof[8] = 0x0000000000000000000000000000000000000000000000000000000000000048;
        proof[9] = 0x000000000000000000000000000000000000000000000000000000000000009e;

        proof[10] = 0x000000000000000000000000000000000000000000000000000000000000004c;
        proof[11] = 0x000000000000000000000000000000000000000000000000000000000000004c;
        proof[12] = 0x000000000000000000000000000000000000000000000000000000000000004c;
        proof[13] = 0x000000000000000000000000000000000000000000000000000000000000004c;

        proof[14] = 0x0000000000000000000000000000000000000000000000000000000000000048;
        proof[15] = 0x000000000000000000000000000000000000000000000000000000000000009b;
        proof[16] = 0x0000000000000000000000000000000000000000000000000000000000000048;
        proof[17] = 0x000000000000000000000000000000000000000000000000000000000000009c;
        proof[18] = 0x0000000000000000000000000000000000000000000000000000000000000048;
        proof[19] = 0x000000000000000000000000000000000000000000000000000000000000009e;
        proof[20] = 0x0000000000000000000000000000000000000000000000000000000000000048;
        proof[21] = 0x000000000000000000000000000000000000000000000000000000000000009f;

        root = SMT.update(proof, nextLeaves, prevLeaves, root);
    }

    /// @notice Insert msg.sender (leaf 0 -> 1). Not allowed once the caller has the key.
    function enter(bytes32[] memory _proofs) public {
        require(haveKey[msg.sender] == false);
        root = SMT.updateSingleTarget(_proofs, msg.sender, root, SMT.Method.Insert);
    }

    /// @notice Delete msg.sender (leaf 1 -> 0). Not allowed once the caller has the chest.
    function leave(bytes32[] memory _proofs) public {
        require(haveTreasureChest[msg.sender] == false);
        root = SMT.updateSingleTarget(_proofs, msg.sender, root, SMT.Method.Delete);
    }

    /// @notice In BlackList mode, prove msg.sender's leaf is 0 to obtain the key;
    ///         switches the contract back to WhiteList mode.
    function findKey(bytes32[] memory _proofs) public {
        require(smtMode == SMT.Mode.BlackList, "not blacklist mode");
        address[] memory targets = new address[](1);
        targets[0] = msg.sender;
        require(SMT.verifyByMode(_proofs, targets, root, smtMode), "hunter has fallen into a trap");
        haveKey[msg.sender] = true;
        smtMode = SMT.Mode.WhiteList;
        emit FindKey(msg.sender);
    }

    /// @notice In WhiteList mode, prove msg.sender's leaf is 1 to obtain the chest;
    ///         switches the contract to BlackList mode.
    function pickupTreasureChest(bytes32[] memory _proofs) public {
        require(smtMode == SMT.Mode.WhiteList, "not whitelist mode");
        address[] memory targets = new address[](1);
        targets[0] = msg.sender;
        require(
            SMT.verifyByMode(_proofs, targets, root, smtMode),
            "hunter hasn't found the treasure chest"
        );
        haveTreasureChest[msg.sender] = true;
        smtMode = SMT.Mode.BlackList;
        emit PickupTreasureChest(msg.sender);
    }

    /// @notice Mark the challenge solved; requires both the key and the chest.
    function openTreasureChest() public {
        require(haveKey[msg.sender] && haveTreasureChest[msg.sender]);
        solved = true;
        emit OpenTreasureChest(msg.sender);
    }

    function isSolved() public view returns (bool) {
        return solved;
    }
}

若想获取 flag,必须让 isSolved 返回 true,也就是必须调用 openTreasureChest,而这要求调用者同时满足 haveKey 和 haveTreasureChest。获取箱子(pickupTreasureChest,白名单模式)和钥匙(findKey,黑名单模式)的条件分别是证明当前地址存储的值为 1 和 0,这两个条件看起来是互斥的。

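这个题目的问题出在验证上。一般的默克尔证明是提供一组哈希值交给合约/服务端去验证,而这个题目要求提交的是一段"证明程序",在题目定义的栈虚拟机中执行。虚拟机支持以下操作:

- 0x4c (L):取出下一个叶子,把它的 key 和哈希 calcLeaf(key, value) 压栈;
- 0x50 (P):读取参数 height 和一个任意的证明哈希 proof(唯一的限制是 proof 不能等于 root),根据栈顶 key 第 height 位的取值,把 proof 放在左侧或右侧与栈顶值 merge,并把栈顶 key 更新为父节点路径;
- 0x48 (H):读取参数 height,要求栈顶两个 key 在该高度互为兄弟节点(父路径相同且第 height 位不同),将两者 merge 为一个栈元素。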

计算必须使用所有叶子节点,最后在栈中只留下一个哈希值。验证时,要求这个值与 root 相同,否则拒绝;更新时,这个值成为新的 root。

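0x50 操作允许我们提供任意的兄弟哈希,基本上可以任意地合并元素;并且 merge 在一侧为 0 时会直接返回另一侧。计划是这样:先调用 enter 把自己的叶子值由 0 改为 1,再调用 pickupTreasureChest(白名单模式,证明叶子值为 1),这时根值由 calcLeaf(playerAddress, 1) 与其他叶子节点一起计算得出;之后调用 findKey(黑名单模式)验证 (playerAddress, 0) 时,该叶子节点的哈希值会被算作 0。只要先用一条高度为 0 的 0x50,把原来的叶子哈希 calcLeaf(playerAddress, 1) 作为证明值与它合并(merge 的结果恰好就是原叶子哈希),再沿用原来的路径,就可保持计算出的 root 不变,通过验证。合约只要求证明值不等于 root 本身,所以这条证明值不会被拒绝。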

题解

smt.py
from Crypto.Hash import keccak

STACK = 32  # mirrors SMT_STACK_SIZE: capacity of the proof VM's stack
DEPTH = 160  # mirrors DEPTH: number of key bits, upper bound for heights

def keccak256(b):
    """Return the 32-byte Keccak-256 digest of `b`.

    Uses pycryptodome's Keccak, the variant behind Solidity's keccak256.
    """
    return keccak.new(data=b, digest_bits=256).digest()

def u256(b):
    """Interpret the bytes `b` as an unsigned big-endian integer."""
    return int.from_bytes(b, byteorder="big")

def tobytes(u, l):
    """Encode the non-negative integer `u` as exactly `l` big-endian bytes.

    Raises OverflowError if `u` is negative or does not fit in `l` bytes.
    The previous hex-string round trip silently returned more than `l` bytes
    for some oversized values and failed with a ValueError for others
    (odd-length hex, negative numbers).
    """
    return u.to_bytes(l, "big")

def b32(u):
    """Encode `u` as one 32-byte big-endian word, as abi.encode pads a uint/address.

    Raises OverflowError if `u` is negative or wider than 256 bits, instead of
    silently producing a word of the wrong length.
    """
    return u.to_bytes(32, "big")

# Mirrors of the Solidity enums SMT.Mode and SMT.Method (declaration order = value).
MODE_BLACK, MODE_WHITE = 0, 1

METHOD_INSERT, METHOD_DELETE = 0, 1

def calcLeaf(leaf: (int, int)) -> int:
    """Hash a (key, value) leaf the way SMT.calcLeaf does.

    A zero value hashes to 0. Otherwise the result is keccak256 over the key
    and the value, each padded to a 32-byte word, read back as an integer.
    """
    key, value = leaf[0], leaf[1]
    if value != 0:
        return u256(keccak256(b32(key) + b32(value)))
    return 0

def merge(left: int, right: int) -> int:
    """Combine two child hashes like SMT.merge.

    A zero operand is passed through, returning the other side unchanged;
    only two non-zero children are actually hashed together.
    """
    if left == 0:
        return right
    if right == 0:
        return left
    return u256(keccak256(b32(left) + b32(right)))

def verifyMode(proof: [int], targets: [int], root: int, mode: int) -> bool:
    """Check that every target's leaf holds the value `mode` under `root`."""
    return verify(proof, [(target, mode) for target in targets], root)

def verify(proof: [int], leaves: list, root: int):
    """Return True iff running `proof` over `leaves` reproduces `root`."""
    computed = calcRoot(proof, leaves, root)
    return computed == root

def updateSingleTarget(proof: [int], target: int, prevRoot: int, method: int):
    """Flip one target's leaf: prove value `method`, then write `method ^ 1`.

    With METHOD_INSERT (0) the leaf goes 0 -> 1; with METHOD_DELETE (1) it
    goes 1 -> 0. Returns the new root.
    """
    after = [(target, method ^ 1)]
    before = [(target, method)]
    return update(proof, after, before, prevRoot)

def update(proof: [int], nextLeaves: list, prevLeaves: list, prevRoot: int):
    """Apply a leaf transition the way SMT.update does.

    The same `proof` program must first reproduce `prevRoot` from
    `prevLeaves`; it is then rerun over `nextLeaves`, and that result is
    returned as the new root.
    """
    assert verify(proof, prevLeaves, prevRoot)
    newRoot = calcRoot(proof, nextLeaves, prevRoot)
    return newRoot

def checkGroupSorted(leaves: list) -> bool:
    """Return True iff `leaves` is non-empty and its keys are strictly increasing and non-zero.

    This mirrors SMT.checkGroupSorted. The contract starts its running key at
    0 and rejects `temp >= key`, so it refuses duplicate keys and the zero
    key. It also requires at least one leaf; here that case returns False so
    the caller's assert fails, just as the contract reverts.
    """
    if not leaves:
        return False
    prev = 0
    for leaf in leaves:
        if prev >= leaf[0]:
            return False
        prev = leaf[0]
    return True

def getBit(key: int, height: int):
    """Return bit `height` of `key` (0 or 1), like SMT.getBit.

    The contract requires height <= DEPTH. The same bound is enforced here so
    the simulator rejects proofs the chain would revert on.
    """
    assert height <= DEPTH, "height out of range"
    return (key >> height) & 1

def copyBit(key: int, height: int):
    """Return `key` with its lowest `height` bits cleared, like SMT.copyBit.

    Enforces the contract's requirement height <= DEPTH.
    """
    assert height <= DEPTH, "height out of range"
    return (key >> height) << height

def parentPath(key, height):
    """Return `key` with bits 0..`height` cleared, like SMT.parentPath.

    On chain, parentPath requires height <= DEPTH and then copyBit requires
    height + 1 <= DEPTH, so the effective bound is height + 1 <= DEPTH.
    """
    assert height + 1 <= DEPTH, "height out of range"
    return copyBit(key, height + 1)

# Proof-program opcodes; their values are the ASCII codes of "L", "P" and "H".
PROOF_LEAF = ord("L")  # 0x4c: push the next leaf
PROOF_HASH = ord("P")  # 0x50: merge the stack top with a supplied sibling hash
PROOF_TREE = ord("H")  # 0x48: merge the two topmost stack entries
def calcRoot(proof: [int], leaves: list, root: int) -> int:
    """Run a proof program over `leaves` and return the resulting hash.

    This is a Python port of SMT.calcRoot. `proof` is a list of integer words
    holding opcodes and their operands:

    * PROOF_LEAF (0x4c): push the next leaf's key and its calcLeaf() hash.
    * PROOF_HASH (0x50) height proofHash: merge the top stack value with the
      supplied sibling hash at `height`. The sibling goes on the left when bit
      `height` of the top key is 1, otherwise on the right. The top key then
      moves to parentPath(key, height).
    * PROOF_TREE (0x48) height: merge the two topmost entries, which must have
      the same parentPath at `height` and differ in bit `height`.

    `leaves` are (key, value) pairs that must pass checkGroupSorted. Every
    leaf must be consumed, and exactly one stack entry must remain; its value
    is returned. `root` is only used to reject a PROOF_HASH sibling equal to it.

    Raises AssertionError when a structural check fails, and RuntimeError on
    an unknown opcode.
    """
    assert checkGroupSorted(leaves)
    keys = [0] * STACK
    values = [0] * STACK
    pi, li, st = 0, 0, 0  # proof index, leaf index, stack height

    while pi < len(proof):
        if proof[pi] == PROOF_LEAF:
            # push leaf to stack
            pi += 1
            assert st < STACK, "stack overrun"
            assert li < len(leaves), "leaves overrun"
            keys[st] = leaves[li][0]
            values[st] = calcLeaf(leaves[li])
            st += 1
            li += 1
        elif proof[pi] == PROOF_HASH:
            # supply proof (height, hashProof)
            pi += 1
            assert st > 0, "stack underrun"
            assert pi + 2 <= len(proof), "code overrun"
            height = proof[pi]
            pi += 1
            curProof = proof[pi]
            pi += 1
            assert curProof != root, "invalid proof equal to root"
            if getBit(keys[st-1], height) == 1:
                # stack value is on the right side; supplied hash goes left
                values[st-1] = merge(curProof, values[st-1])
            else:
                # stack value is on the left
                values[st-1] = merge(values[st-1], curProof)
            keys[st-1] = parentPath(keys[st-1], height)
        elif proof[pi] == PROOF_TREE:
            # merge the two topmost subtrees, one value each (operand: height)
            pi += 1
            assert st >= 2 and pi < len(proof)
            height = proof[pi]
            pi += 1
            aset = getBit(keys[st-2], height)
            bset = getBit(keys[st-1], height)
            keys[st-2] = parentPath(keys[st-2], height)
            keys[st-1] = parentPath(keys[st-1], height)
            assert keys[st-2] == keys[st-1], "must have same parent"
            assert aset != bset, "must be different sibling"

            if aset == 1:
                values[st-2] = merge(values[st-1], values[st-2])
            else:
                values[st-2] = merge(values[st-2], values[st-1])
            st -= 1
        else:
            raise RuntimeError("Invalid proof")
    assert li == len(leaves) and st == 1
    return values[0]
exp.py
from smt import *
from web3 import Web3
from time import sleep

# The eight hunter addresses that TreasureHunter._init() inserts (ascending order).
hunters = [
    0x0bc529c00C6401aEF6D220BE8C6Ea1667F6Ad93e,
    0x68b3465833fb72A70ecDF485E0e4C7bD8665Fc45,
    0x6B175474E89094C44Da98b954EedeAC495271d0F,
    0x6B3595068778DD592e39A122f4f5a5cF09C90fE2,
    0xAb5801a7D398351b8bE11C439e05C5B3259aeC9B,
    0xc00e94Cb662C3520282E6f5717214004A7f26888,
    0xD533a949740bb3306d119CC777fa900bA034cd52,
    0xdAC17F958D2ee523a2206206994597C13D831ec7,
]
hunterZeroes = [(hun, 0) for hun in hunters]
hunterLeaves = [(hun, 1) for hun in hunters]
# Same 22-word proof program that _init() runs to insert the hunters.
proof = (
    [0x4c] * 4
    + [0x48, 0x95, 0x48, 0x99, 0x48, 0x9e]
    + [0x4c] * 4
    + [0x48, 0x9b, 0x48, 0x9c, 0x48, 0x9e, 0x48, 0x9f]
)

# Attacker account: `player` is the address as an int (tree key), `players` the
# same address as a hex string (used for RPC calls), `playerKey` its private key.
# NOTE(review): hard-coded private key of a throwaway CTF account; never reuse it.
player = 0xf41f3c114e2Ca5B3760f168365a7b8040d795339
players = "0xf41f3c114e2Ca5B3760f168365a7b8040d795339"
playerKey = 0xd4a57343ad4077fa3e2144e2c004f45a22edcb751f06fa6df2190783cfcd3651
# The player's leaf must sort after every hunter (hunters are ascending).
assert hunters[-1] < player

# Simulated contract state: start from the empty tree (SMT.init() == 0) and
# insert the eight hunters the same way TreasureHunter._init() does.
root = 0
root = update(proof, hunterLeaves, hunterZeroes, prevRoot=root)
#print(hex(root))

haveKey = haveTreasure = False
mode = MODE_WHITE

def enter(proof: [int]):
    """Simulate TreasureHunter.enter: flip the player's leaf from 0 to 1."""
    global root
    assert not haveKey
    root = updateSingleTarget(proof, player, root, method=METHOD_INSERT)

def leave(proof: [int]):
    """Simulate TreasureHunter.leave: flip the player's leaf from 1 to 0."""
    global root
    assert not haveTreasure
    root = updateSingleTarget(proof, player, root, method=METHOD_DELETE)

def foundKey(proof: [int]):
    """Simulate TreasureHunter.findKey.

    In blacklist mode, prove the player's leaf is 0, grant the key, and
    switch back to whitelist mode.
    """
    global mode, haveKey
    assert mode == MODE_BLACK
    assert verifyMode(proof, [player], root, mode)
    haveKey = True
    mode = MODE_WHITE
    print("found key")

def foundTreasure(proof: [int]):
    """Simulate TreasureHunter.pickupTreasureChest.

    In whitelist mode, prove the player's leaf is 1, grant the chest, and
    switch to blacklist mode.
    """
    global mode, haveTreasure
    assert mode == MODE_WHITE
    assert verifyMode(proof, [player], root, mode)
    haveTreasure = True
    mode = MODE_BLACK
    print("found treasure")

def check():
    """Simulate openTreasureChest: both the key and the chest are required."""
    assert haveKey
    assert haveTreasure
    print("win")

def genexpr(leaves: list, lower: int = 0, upper: int = 2 ** 256) -> (str, int):
    """Build the merge expression and hash for the leaves keyed in [lower, upper).

    The key range is bisected recursively:
    - a range with no leaf contributes ("", 0);
    - a range with exactly one leaf yields ("#<index>", calcLeaf(leaf));
    - otherwise the two halves are combined as "(left,right)" with `merge`.

    Raises ValueError if two leaves share a key. The bisection can never
    separate them, and previously it recursed until RecursionError.
    """
    inRange = [i for i, leaf in enumerate(leaves) if lower <= leaf[0] < upper]
    if not inRange:
        return "", 0
    if len(inRange) == 1:
        last = inRange[0]
        return f"#{last}", calcLeaf(leaves[last])
    if upper - lower <= 1:
        # Several leaves in a width-1 range means they have identical keys.
        raise ValueError("duplicate leaf keys cannot be placed in the tree")

    mid = (lower + upper) // 2
    l, left = genexpr(leaves, lower, mid)
    r, right = genexpr(leaves, mid, upper)

    if l and r:
        return f"({l},{r})", merge(left, right)
    if l:
        return l, left
    if r:
        return r, right
    return "", 0

#expr, hash = genexpr( sorted(hunterLeaves + [(player, 1)], key=lambda u: u[0]) )
#print(expr)
#print(bin(player))

def _printProof(words):
    """Print each proof word as a bracketed list of its 32 big-endian byte values."""
    print("[")
    for word in words:
        print("[" + ",".join(str(octet) for octet in b32(word)) + "],")
    print("]")

# Leaf hashes of the eight hunters (all stored with value 1).
l = list(map(calcLeaf, hunterLeaves))

# Treasure proof: push the player's leaf, then attach the hunters' subtrees at
# heights 3, 4 and 5. Those bits are 1 in the player's address (low byte
# 0x39), so each supplied hash is placed on the left.
code = [
    PROOF_LEAF,
    PROOF_HASH, 3, merge(l[5], merge(l[6], l[7])),
    PROOF_HASH, 4, l[4],
    PROOF_HASH, 5, merge(l[0], merge(l[1], merge(l[2], l[3]))),
]
_printProof(code)

enter(code)
print(hex(root))
foundTreasure(code)

# Key proof: a value-0 leaf hashes to 0, so first merge in the old leaf hash at
# height 0 (merge(x, 0) == x), then reuse the rest of the treasure proof.
code1 = [PROOF_LEAF, PROOF_HASH, 0, calcLeaf((player, 1))] + code[1:]
_printProof(code1)

foundKey(code1)
check()

# ABI-encoded bytes32[] arguments for the real transactions. Each starts with
# the offset word 0x20, then the element count (0x0a = 10, matching len(code);
# 0x0d = 13, matching len(code1)), then the elements.
# NOTE(review): presumably hex dumps of `code` / `code1` above; regenerate them
# if `player` changes.
arg0 = "0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000a000000000000000000000000000000000000000000000000000000000000004c00000000000000000000000000000000000000000000000000000000000000500000000000000000000000000000000000000000000000000000000000000003f4d48cacb338d80223fa2a9769ddfc803cc33d764ba4e5a0f5c304f2eb7cf5bc0000000000000000000000000000000000000000000000000000000000000050000000000000000000000000000000000000000000000000000000000000000406888f968192a70674eacf045568b8ea9498309e832d1afd30932de111b5de8100000000000000000000000000000000000000000000000000000000000000500000000000000000000000000000000000000000000000000000000000000005e9f810898db8dc62342eaa122fd26525362f2b70bd462edef6e4e34093d66c17"
arg1 = "0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000d000000000000000000000000000000000000000000000000000000000000004c00000000000000000000000000000000000000000000000000000000000000500000000000000000000000000000000000000000000000000000000000000000f557e52d06e9fd0fd25ca5551a32e28e950350615dac1085cd6b41903521322b00000000000000000000000000000000000000000000000000000000000000500000000000000000000000000000000000000000000000000000000000000003f4d48cacb338d80223fa2a9769ddfc803cc33d764ba4e5a0f5c304f2eb7cf5bc0000000000000000000000000000000000000000000000000000000000000050000000000000000000000000000000000000000000000000000000000000000406888f968192a70674eacf045568b8ea9498309e832d1afd30932de111b5de8100000000000000000000000000000000000000000000000000000000000000500000000000000000000000000000000000000000000000000000000000000005e9f810898db8dc62342eaa122fd26525362f2b70bd462edef6e4e34093d66c17"

def bytes4(s):
    """Return the 4-byte function selector for the signature string `s`.

    This is the first four bytes of keccak256 over the UTF-8 text of `s`.
    """
    digest = Web3.keccak(text=s)
    return digest[:4]

# JSON-RPC endpoint of the challenge chain.
w3 = Web3(Web3.HTTPProvider(endpoint_uri="http://47.243.235.111:8545"))
# Placeholder used as the transaction `to` address; presumably the deployed
# TreasureHunter instance. Replace before running pwn().
contract = "your contract instance"

def sendtx(data):
    """Sign `data` as a call to `contract` from the player account and broadcast it.

    Prints the transaction hash and also returns it, so callers may wait for
    the receipt instead of sleeping.
    """
    signed = w3.eth.account.sign_transaction(dict(
        nonce=w3.eth.get_transaction_count(players),
        gasPrice=2100000,
        gas=4000000,
        to=contract,
        value=0,
        data=data,
        chainId=w3.eth.chain_id,
    ), playerKey)
    # eth-account >= 0.13 (web3.py v7) renamed `rawTransaction` to
    # `raw_transaction`; support both spellings.
    raw = getattr(signed, "raw_transaction", None)
    if raw is None:
        raw = signed.rawTransaction
    txHash = w3.eth.send_raw_transaction(raw)
    print(txHash)
    return txHash

def pwn():
    """Send the exploit transactions in order, pausing 20 s after each.

    The order is enter, pickupTreasureChest, findKey, openTreasureChest.
    """
    treasureProof = bytes.fromhex(arg0)
    keyProof = bytes.fromhex(arg1)
    steps = [
        ("enter(bytes32[])", treasureProof),
        ("pickupTreasureChest(bytes32[])", treasureProof),
        ("findKey(bytes32[])", keyProof),
        ("openTreasureChest()", None),
    ]
    for signature, payload in steps:
        calldata = bytes4(signature)
        if payload is not None:
            calldata = calldata + payload
        sendtx(calldata)
        sleep(20)
    print("check is solved")

# Broadcast the real transactions only when run as a script. The module-level
# simulation above runs on every import or run.
if __name__ == "__main__":
    pwn()

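为了省事,解题固定使用那个数值很大的地址 0xf41f3c114e2Ca5B3760f168365a7b8040d795339(exp.py 中用 assert player > hunters[7] 确认它大于所有 hunter 地址,排在最后)。这样树的合并结构是固定的:((#0,(#1,(#2,#3))),(#4,((#5,(#6,#7)),#8))),其中 #8 是自己的地址。高度参数 3、4、5 对应 0xf41f3c114e2Ca5B3760f168365a7b8040d795339 中为 1 的二进制位(最低字节 0x39 = 0b00111001),因此在这三个高度上 getBit 返回 1,证明值被放在左侧、栈顶值放在右侧。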


Newer: Shadowsocks Transparent Proxy

Older: Services

Back to: Listing G2R