import argparse
import codecs
import copy
import importlib
import io
import json
import os.path
import random
import re
import string
import struct
import sys

from dataclasses import dataclass
from difflib import SequenceMatcher
from io import StringIO
from math import ceil
from typing import Optional, Callable, Generator

# class to track sent bits and related metadata
@dataclass
class SendResult:
    """Bits a sender pushed onto the channel, plus where each message ended."""
    # One element per bit; generate_bits() fills this from Channel.got_bits,
    # which only ever accepts the values 0 and 1.
    bits: bytearray
    # message_ends[k] is len(bits) measured right after message k was sent.
    message_ends: list[int]

@dataclass
class TestCase:
    """One scenario: what to send, how to distort it, and what counts as passing."""

    # Identification
    name: str
    grading_category: Optional[str] = None

    # Messages to send: either given directly, or produced by a generator
    # that is called with sender_class=... when the test runs.
    sent_messages: Optional[tuple[bytes, ...]] = None
    sent_messages_generator: Optional[Callable[..., tuple[bytes, ...]]] = None

    # Distortion applied to the sent bits; yields either a single distorted
    # bit string or a list of (subtest name, distorted bits) pairs.
    distort_function: Optional[Callable[[SendResult], list[tuple[str, bytes]] | bytes]] = None
    distort_description: Optional[str] = None

    # Pass thresholds (acceptable_size is compared against the bit count / 8)
    acceptable_missing: int = 0
    acceptable_size: Optional[int] = None
    acceptable_corrupted: int = 0

    def display(self, out_fh, verbose=False):
        """Write a human-readable description of this test case to out_fh."""
        out_fh.write(f"Test '{self.name}'\n")

        if self.sent_messages_generator:
            out_fh.write(f"Messages generated by {self.sent_messages_generator}\n")

        if self.sent_messages:
            display_messages(
                out_fh,
                self.sent_messages,
                message_format='hex',
                label='Sent message (in hex)',
                truncate=not verbose,
            )

        # Prefer the hand-written description; fall back to the function repr.
        distortion_line = None
        if self.distort_description:
            distortion_line = f"Then distorts the bits by {self.distort_description}\n"
        elif self.distort_function:
            distortion_line = f"Then distorts the bits with {self.distort_function}\n"
        if distortion_line is not None:
            out_fh.write(distortion_line)

    def __str__(self):
        buffer = StringIO()
        self.display(buffer)
        return buffer.getvalue()

    def for_json(self):
        """Return a JSON-serializable summary of this test case."""
        summary = {}
        summary['name'] = self.name
        summary['grading_category'] = self.grading_category
        summary['message'] = self.__str__()
        return summary

@dataclass
class CompareResult:
    """Summary of how the received messages differ from the sent ones."""
    # Every message the receiver delivered, in delivery order.
    received: tuple[bytes, ...]
    # Received messages with no counterpart among the sent messages.
    extra_count: int
    # Sent messages that were replaced by a different received message.
    corrupted_count: int
    # Sent messages with no counterpart among the received messages.
    missing_count: int
    # Human-readable list of differences; the empty string when there are none.
    compare_text: str

@dataclass
class TestResult:
    """Outcome of running one test case (or one named subtest of it).

    The four ``passed_*`` flags are filled in by the runner; ``passed`` is
    true only when all of them are.
    """
    test_case: TestCase
    subtest_name: Optional[str] = None

    # What the sender produced before any distortion.
    send_result: Optional[SendResult] = None

    distorted_bits: bytes  = bytes() # of 0/1s

    compare_result: Optional[CompareResult] = None

    passed_missing: bool = False
    passed_size: bool = False
    passed_corrupted: bool = False
    passed_extra: bool = False

    @property
    def passed(self):
        """True when every individual criterion passed."""
        return self.passed_missing and self.passed_size and self.passed_corrupted and self.passed_extra

    def display_short(self, out_fh):
        """Write a one-line PASSED/FAILED summary to out_fh."""
        passed_string = ("PASSED" if self.passed else "FAILED")
        if self.subtest_name:
            out_fh.write(f"{passed_string} '{self.test_case.name}' ")
            out_fh.write(f"subtest '{self.subtest_name}'\n")
        else:
            out_fh.write(f"{passed_string} '{self.test_case.name}'\n")

    def display(self, out_fh, verbose=False):
        """Write the full report for this result to out_fh.

        With verbose=False long bit strings and message lists are truncated
        and the received messages are not listed individually.
        """
        self.display_short(out_fh)
        display_messages(out_fh, self.test_case.sent_messages,
            message_format='hex', label='Sent messages (as hexadecimal):',
            truncate=not verbose)
        out_fh.write("Sender sent (displayed as hexadecimal):\n")
        display_send_result(out_fh, self.send_result, truncate=not verbose, bits_format='hex')
        if not self.passed_size:
            out_fh.write(f"-- ERROR: size ({len(self.send_result.bits)/8} bytes) > "
                         f"{self.test_case.acceptable_size}\n")
        if self.test_case.distort_function:
            out_fh.write("Which were distorted into bits (as hexadecimal):\n")
            display_raw_bits(out_fh, self.distorted_bits, format='hex', truncate=not verbose)
        if self.compare_result.compare_text == '':
            out_fh.write("Which were received successfully as the original messages.\n")
        else:
            if verbose:
                if len(self.compare_result.received) > 0:
                    display_messages(out_fh, self.compare_result.received,
                        message_format='hex', label='Received message (in hex)')
                else:
                    # Must go to out_fh like everything else; print() would
                    # send this line to stdout and drop it from the report.
                    out_fh.write("No messages were received.\n")
                out_fh.write("Which had the following differences from the original messages:\n")
            else:
                out_fh.write("Which was received with the following differences from\n"
                             "the original messages:\n")
            out_fh.write(self.compare_result.compare_text + '\n')
        if not self.passed_missing:
            out_fh.write(f"-- ERROR: {self.compare_result.missing_count} missing messages > "
                         f"{self.test_case.acceptable_missing}\n")
        if not self.passed_corrupted:
            out_fh.write(f"-- ERROR: {self.compare_result.corrupted_count} corrupted messages > "
                         f"{self.test_case.acceptable_corrupted}\n")
        if not self.passed_extra:
            out_fh.write(f"-- ERROR: {self.compare_result.extra_count} extra messages > 0\n")

    def __str__(self):
        out_fh = StringIO()
        self.display(out_fh)
        return out_fh.getvalue()

    def for_json(self):
        """Return a JSON-serializable summary of this result."""
        return {
            'test_case': self.test_case.for_json(),
            'subtest_name': self.subtest_name,
            'passed': self.passed,
            'message': self.__str__(),
        }


# Seed handed to random.Random by get_rng(); a fixed value keeps the
# randomized corruption reproducible from run to run.
random_seed = 0

class Channel:
    """In-memory channel that records every bit a sender pushes into it."""

    def __init__(self):
        self.got_bits = bytearray()

    def send_bits(self, the_bits):
        """Append each element of the_bits, rejecting anything other than 0 or 1.

        Bits preceding an invalid element are kept; the invalid element
        raises ValueError.
        """
        for candidate in the_bits:
            if candidate not in (0, 1):
                raise ValueError('got non-bit {}'.format(candidate))
            self.got_bits.append(candidate)

def _range(i, j):
    if i + 1 == j:
        return f'message #{i}'
    else:
        return f'messages #{i} through #{j-1} (inclusive)'

def _format_message(m):
    if len(m) == 0:
        return '<empty>'
    else:
        return ''.join(list(map(lambda x: '{:02x} '.format(x), m)))

def is_ascii_printable(c):
    """Return True when code point c is a space or a non-whitespace printable ASCII character."""
    if c > 128:
        return False
    character = chr(c)
    if character == ' ':
        # The space is the only whitespace character shown literally.
        return True
    if character in string.whitespace:
        return False
    return character in string.printable

def bytes_to_bits(the_bytes):
    """Expand each byte into eight 0/1 elements, most significant bit first."""
    return bytearray(
        int(digit != '0')
        for a_byte in the_bytes
        for digit in '{:08b}'.format(a_byte)
    )

def generate_bits(sender_class, sent_messages):
    """Send every message through a fresh sender and record the resulting bits.

    sender_class is called with a new Channel; each message is then passed to
    its send_message() method.

    Returns a SendResult whose bits are everything the channel received and
    whose message_ends[k] is the bit count right after message k was sent.
    """
    channel = Channel()
    sender = sender_class(channel)
    message_end_locs = []
    for message in sent_messages:
        # Hand the sender an immutable bytes copy. The original assigned the
        # copy to a misspelled, unused variable and sent the raw object.
        sender.send_message(bytes(message))
        message_end_locs.append(len(channel.got_bits))
    return SendResult(
        bits=channel.got_bits,
        message_ends=message_end_locs
    )

def display_raw_bits(out_fh, bits, format, truncate=False):
    """Write a sequence of 0/1 values to out_fh in the requested format.

    format is one of:
      * 'binary' -- 64 bits per line;
      * 'hex' -- 64 hexadecimal digits (256 bits) per line;
      * 'ascii' / 'ascii-with-c-escapes' -- printable characters as-is,
        everything else as a \\xNN escape.

    When truncate is true only the first line is written, followed by a note
    saying how many bits were left out. When the bit count does not fill the
    last digit/byte, it is padded with zero bits for display and a note says
    how many bits were padding. bits itself is never modified.

    Raises NotImplementedError for any other format.
    """
    if len(bits) == 0:
        out_fh.write(" <empty>\n")
    elif format == 'binary':
        shown = min(len(bits), 64) if truncate else len(bits)
        for i in range(0, shown, 64):
            out_fh.write("  ")
            out_fh.write(''.join(map(str, bits[i:i+64])))
            out_fh.write("\n")
        if truncate and len(bits) > 64:
            out_fh.write(f"    [plus {len(bits)-64} more bits]\n")
    elif format == 'hex':
        bits_per_line = 64 * 4
        shown = min(len(bits), bits_per_line) if truncate else len(bits)
        missing = 0
        for i in range(0, shown, bits_per_line):
            out_fh.write("  ")
            for ii in range(i, min(i + bits_per_line, len(bits)), 4):
                the_bits = bytearray(bits[ii:ii+4])
                # Only the very last digit can be short; pad it with zeros.
                missing = 4 - len(the_bits)
                the_bits.extend(bytes(missing))
                value = the_bits[0] * 8 + the_bits[1] * 4 + the_bits[2] * 2 + the_bits[3]
                out_fh.write(f'{value:x}')
            out_fh.write("\n")
        if truncate and len(bits) > bits_per_line:
            out_fh.write(f"    [plus {len(bits)-bits_per_line} more bits]\n")
        if missing > 0:
            out_fh.write(f"    ({missing} bit{'s' if missing > 1 else ''} in last hexadecimal radit not actually sent)\n")
    elif format == 'ascii-with-c-escapes' or format == 'ascii':
        # Pad a private copy: padding the caller's buffer in place would
        # silently lengthen the recorded bits (and fail outright on bytes).
        missing = -len(bits) % 8
        padded = bytearray(bits)
        padded.extend(bytes(missing))
        truncated = False
        in_line = 0
        out_fh.write("  ")
        for i in range(0, len(padded), 8):
            if in_line > 64:
                in_line = 0
                out_fh.write("\n  ")
                if truncate:
                    out_fh.write(f"  [plus {len(bits)-i} more bits]\n")
                    truncated = True
                    break
            cur_char = 0
            for x in padded[i:i+8]:
                cur_char = cur_char * 2 + x
            if is_ascii_printable(cur_char):
                out_fh.write(chr(cur_char))
                in_line += 1
            else:
                out_fh.write(f'\\x{cur_char:02x}')
                in_line += 4
        out_fh.write("\n")
        # The padded byte is only on screen when the output was not cut short.
        if missing > 0 and not truncated:
           out_fh.write(f"    ({missing} bits in last byte not actually sent)\n")
    else:
        raise NotImplementedError

def display_send_result(out_fh, result, bits_format='binary', truncate=False):
    """Write the bits held by a SendResult to out_fh via display_raw_bits()."""
    display_raw_bits(out_fh, result.bits, bits_format, truncate=truncate)

def receive_bits(receiver_class, bits):
    """Feed bits one at a time to a fresh receiver; return the messages it delivered.

    receiver_class is called with a single callback argument; every value
    the receiver passes to that callback is stored as a bytes copy.
    """
    # FIXME: annotate received messages with bit #s of input
    delivered = []

    def _on_message(payload):
        delivered.append(bytes(payload))

    receiver = receiver_class(_on_message)
    for one_bit in bits:
        receiver.handle_bit_from_network(one_bit)
    return delivered


def receive_and_compare(receiver_class, distorted_bits, sent_messages):
    """Run a receiver over distorted_bits and compare its output with sent_messages.

    The sent and received message lists are aligned with
    difflib.SequenceMatcher. Each opcode is tallied as missing, extra or
    corrupted messages, and a textual description of the differences is
    accumulated.

    Returns a CompareResult; its compare_text is empty when the two lists
    are identical.
    """
    received_messages = receive_bits(receiver_class, distorted_bits)

    # normalize messages to ensure difflib doesn't complain about them being unhashable
    sent_messages = list(map(bytes, sent_messages))
    received_messages = list(map(bytes, received_messages))
    matches = SequenceMatcher(a=sent_messages, b=received_messages, autojunk=False)
    compare_text = []
    extra_messages = 0
    corrupted_messages = 0
    missing_messages = 0
    # Number of differences left out of compare_text because it was full.
    missing_in_results = 0
    # Once compare_text holds this many entries, further differences are
    # only counted, not described.
    maximum_compare_text = 10
    for tag, i1, i2, j1, j2 in matches.get_opcodes():
        # NOTE(review): get_opcodes() is not documented to return negative
        # indices; this normalization looks purely defensive -- confirm.
        if i2 < 0:
            i2 = len(sent_messages) + i2
        if j2 < 0:
            j2 = len(received_messages) + j2
        if tag == 'equal':
            pass
        elif tag == 'delete':
            # Sent messages i1..i2-1 have no received counterpart.
            if len(compare_text) < maximum_compare_text:
                compare_text.append(f'missing input {_range(i1, i2)}')
            else:
                missing_in_results += 1
            missing_messages += (i2 - i1)
        elif tag == 'insert':
            # Received messages j1..j2-1 have no sent counterpart.
            if len(compare_text) < maximum_compare_text:
                compare_text.append(f'extra output {_range(j1, j2)}')
            else:
                missing_in_results += 1
            extra_messages += (j2 -j1)
        elif tag == 'replace':
            input_count = (i2 - i1)
            output_count = (j2 - j1)
            # Pair the runs up one-to-one as corrupted; whatever is left
            # over on the longer side counts as missing or extra.
            corrupted_messages += min(input_count, output_count)
            if input_count > output_count:
                missing_messages += input_count - output_count
            elif output_count > input_count:
                extra_messages += output_count - input_count
            if len(compare_text) < maximum_compare_text:
                compare_text.append(f'input {_range(i1, i2)} corrupted into output {_range(j1, j2)}')
                # Show at most three messages from each side in full.
                for offset in range(min(3, input_count)):
                    compare_text.append(
                                   f'  input message #{i1 + offset} (hexadecimal bytes):\n'
                                   f'    {_format_message(sent_messages[i1+offset])}\n'
                    )
                if input_count > 3:
                    compare_text.append(f'  + {input_count - 3} more')
                for offset in range(min(3, output_count)):
                    compare_text.append(
                                   f'  output message #{j1 + offset} (hexadecimal bytes):\n'
                                   f'    {_format_message(received_messages[j1+offset])}\n'
                    )
                if output_count > 3:
                    compare_text.append(f'  + {output_count - 3} more')
            else:
                missing_in_results += 1
        else:
            raise Exception(f'internal error --- unknown difflib tag {tag}')
    if missing_in_results > 0:
        compare_text.append(f'+ {missing_in_results} more messages, not shown')

    return CompareResult(
        received=tuple(received_messages),
        extra_count=extra_messages,
        corrupted_count=corrupted_messages,
        missing_count=missing_messages,
        compare_text='\n'.join(compare_text)
    )

def run_test_case(sender_class, receiver_class, test_case, subtest_filter=None):
    """Run one test case and yield a TestResult per (sub)test.

    If the test case has a sent_messages_generator it is called with
    sender_class=... and its result is stored in test_case.sent_messages.
    The messages are sent once; each distorted version of the bits is then
    fed to a fresh receiver and scored against the test case's thresholds.

    subtest_filter, when given, is called with each subtest name (None for
    an unnamed single test); subtests for which it returns a false value
    are skipped.
    """
    if test_case.sent_messages_generator:
        test_case.sent_messages = test_case.sent_messages_generator(
            sender_class=sender_class
        )
    send_result = generate_bits(sender_class, test_case.sent_messages)
    if test_case.distort_function:
        distorted_list = test_case.distort_function(send_result)
        # A distort function may return one bare bit string instead of a list
        # of named subtests. Accept bytearray too (the type of
        # SendResult.bits); otherwise it would be iterated element by element
        # and fail to unpack.
        if isinstance(distorted_list, (bytes, bytearray)):
            distorted_list = [(None, distorted_list)]
    else:
        distorted_list = [(None, send_result.bits)]
    # acceptable_size is in bytes; the sender's output is counted in bits.
    size_ok = (
        test_case.acceptable_size is None
        or len(send_result.bits) <= test_case.acceptable_size * 8
    )
    for name, distorted_bits in distorted_list:
        if subtest_filter and not subtest_filter(name):
            continue
        compare_result = receive_and_compare(receiver_class, distorted_bits, test_case.sent_messages)
        yield TestResult(
            test_case=test_case,
            subtest_name=name,
            send_result=send_result,
            distorted_bits=distorted_bits,
            compare_result=compare_result,
            passed_missing=compare_result.missing_count <= test_case.acceptable_missing,
            passed_size=size_ok,
            passed_corrupted=compare_result.corrupted_count <= test_case.acceptable_corrupted,
            passed_extra=compare_result.extra_count == 0
        )

def run_test_cases(sender_class, receiver_class, all_tests, subtest_filter=None):
    """Yield the results of every test case in all_tests, in order."""
    for one_case in all_tests:
        yield from run_test_case(
            sender_class, receiver_class, one_case, subtest_filter=subtest_filter
        )

def get_rng():
    """Return a new random.Random seeded with the module-level random_seed."""
    return random.Random(random_seed)

def do_corrupt_random(
        flip_rate: float, add_rate: float, delete_rate: float,
        flip_count: int, add_count: int, delete_count: int,
        corrupt_limit_messages: Optional[int],
        trials: int,
        send_result: SendResult):
    """Produce `trials` randomly corrupted copies of the sent bits.

    For each kind of corruption the number of affected positions is
    ``<kind>_count + ceil(limit * <kind>_rate)``, where limit is the number
    of bits eligible for corruption: everything when corrupt_limit_messages
    is None, otherwise only the bits of the first corrupt_limit_messages
    messages. Deletions are applied first, then flips, then insertions (a
    random bit is inserted before each chosen position).

    Every trial starts from the original, uncorrupted bits and the original
    counts. All trials draw from a single RNG obtained from get_rng(), so
    the result is reproducible but the trials differ from one another.

    Returns a list of ('trial #N', bytearray) pairs.
    """
    original_bits = send_result.bits
    message_ends = send_result.message_ends
    # One generator for all trials; re-seeding per trial would make every
    # trial pick exactly the same positions.
    rng = get_rng()
    corrupted = []
    for trial_num in range(trials):
        assert len(original_bits) > 0, "sender produced no bytes?"
        bits = bytearray(original_bits)
        if corrupt_limit_messages is None:
            corrupt_limit = len(bits)
        else:
            corrupt_limit = message_ends[corrupt_limit_messages - 1]

        if delete_rate > 0 or delete_count > 0:
            total_deletes = delete_count + ceil(corrupt_limit * delete_rate)
            doomed = set(rng.sample(range(corrupt_limit), total_deletes))
            bits = bytearray(bit for i, bit in enumerate(bits) if i not in doomed)
            # Every deleted bit lay inside the corruptible region, so that
            # region is now shorter by the same amount.
            corrupt_limit -= total_deletes

        if len(bits) > 0 and (flip_rate > 0 or flip_count > 0):
            total_flips = flip_count + ceil(corrupt_limit * flip_rate)
            for i in rng.sample(range(corrupt_limit), total_flips):
                bits[i] ^= 1

        if len(bits) > 0 and (add_rate > 0 or add_count > 0):
            total_adds = add_count + ceil(corrupt_limit * add_rate)
            add_points = sorted(rng.sample(range(corrupt_limit), total_adds))
            grown = bytearray()
            copied_upto = 0
            for point in add_points:
                grown += bits[copied_upto:point]
                grown.append(rng.randrange(0, 2))
                copied_upto = point
            grown += bits[copied_upto:]
            bits = grown

        corrupted.append((f'trial #{trial_num}', bits))
    return corrupted

def make_corrupt_random(flip_rate=0, add_rate=0, delete_rate=0, flip_count=0, add_count=0, delete_count=0,
                        corrupt_limit_messages=None, trails=3):
    """Generate a function to corrupt a message. Parameters:

    *  flip_rate, flip_count: portion/number of bits to flip (at random)
    *  add_rate, add_count: portion/number of positions to add an additional bit at (at random)
    *  delete_rate, delete_count: portion/number of bits to delete (at random)
    *  corrupt_limit_messages: if not None, only the bits of this many
       leading messages are eligible for corruption
    *  trails: number of independent trials to generate (the misspelled
       name is kept so existing keyword callers continue to work)

    The returned function takes a SendResult and forwards to
    do_corrupt_random(); its __name__ describes the configured corruption.
    """
    result = lambda send_result: do_corrupt_random(
          send_result=send_result, flip_rate=flip_rate, flip_count=flip_count,
          add_rate=add_rate, delete_rate=delete_rate,
          add_count=add_count, delete_count=delete_count,
          corrupt_limit_messages=corrupt_limit_messages,
          trials=trails
    )
    # This was a plain string, so the braces were never filled in, and it
    # referred to a corrupt_limit_portion that does not exist.
    description = (
        f'corrupt ({trails} trials; change bit {flip_rate * 100}% + {flip_count}, '
        f'add bit {add_rate * 100}% + {add_count}, '
        f'delete bit {delete_rate * 100}% + {delete_count})'
    )
    if corrupt_limit_messages is not None:
        description += f' --- but only in the first {corrupt_limit_messages} messages'
    result.__name__ = description
    return result

def do_corrupt_each(send_result: SendResult, message_limit: int, maximum_indices: int = 512, first_message: int = 1):
    """
    Systematically corrupt each bit from the start of message # first_message
    up to the end of message # message_limit (both 1-based), provided that
    is no more than maximum_indices bits. If it would be more than
    maximum_indices bits, that many bits are selected randomly.

    For every selected bit four variants of the full bit string are
    produced: the bit flipped, a zero inserted after it, a one inserted
    after it, and the bit deleted. Returns a list of (name, bytes) pairs.
    """
    raw_bits = send_result.bits
    ends = send_result.message_ends

    first_bit_index = ends[first_message - 2] if first_message > 1 else 0
    limit_bit_index = ends[message_limit - 1]
    assert first_bit_index < limit_bit_index, f'{first_bit_index} to {limit_bit_index}'

    candidates = range(first_bit_index, limit_bit_index)
    if limit_bit_index - first_bit_index > maximum_indices:
        indices = get_rng().sample(candidates, k=maximum_indices)
    else:
        indices = list(candidates)

    results = []
    for i in indices:
        before_bit = raw_bits[0:i]
        through_bit = raw_bits[0:i+1]
        after_bit = raw_bits[i+1:]
        flipped = bytes([raw_bits[i]^1])
        results.extend([
            (f'flip bit #{i}', bytes(before_bit + flipped + after_bit)),
            (f'add zero after #{i}', bytes(through_bit + bytes([0]) + after_bit)),
            (f'add one after #{i}', bytes(through_bit + bytes([1]) + after_bit)),
            (f'delete bit #{i}', bytes(before_bit + after_bit)),
        ])
    return results

def make_corrupt_each(message_limit, first_message=1):
    """Build a distort function that applies do_corrupt_each() to the given message range."""
    corrupt = lambda send_result: do_corrupt_each(
        send_result,
        message_limit=message_limit,
        first_message=first_message,
    )
    corrupt.__name__ = f'flip/add/delete each bit up in messages {first_message} to {message_limit}'
    return corrupt

def _bits_to_bytes(raw_bits):
    current = 0
    result = bytearray()
    for i, x in enumerate(raw_bits):
        current *= 2
        current += x
        if i % 8 == 0:
            result += bytes([current])
            current = 0
    if len(raw_bits) % 8 != 0:
        result += bytes([current])
    return bytes(result)

def _do_quote_previous_messages(results, previous_test, locations, before, after):
    raw_bits = results[previous_test]['original_bits']
    messages = []
    if before:
        messages += before
    for location in locations:
        messages.append(_bits_to_bytes(raw_bits[location[0]:location[1]]))
    if after:
        messages += after
    return messages

def make_messages_using_sent_bytes(previous_test, locations, before=None, after=None):
    """Return a function of `results` that calls _do_quote_previous_messages() with these settings."""
    return lambda results: _do_quote_previous_messages(
        results,
        previous_test=previous_test,
        locations=locations,
        before=before,
        after=after,
    )

class ReuseBytesGenerator(object):
    """Message generator that embeds a sender's own encoded bits inside new messages.

    When called with a sender class, it encodes `messages` with that
    sender, cuts the resulting bit string at each (start, end) pair in
    `locations`, and packs every slice into a message. The `before` and
    `after` messages are placed around the generated ones.
    """

    def __init__(self, messages, locations, before=tuple(), after=tuple()):
        self._messages = messages
        self._locations = locations
        self._before = before
        self._after = after

    def __str__(self):
        # FIXME: more descriptive?
        # The original returned the template with a literal "{}" and never
        # used the hex rendering it had just computed.
        messages_as_hex = ', '.join(
            message.hex() if len(message) > 0 else '<empty>'
            for message in self._messages
        )
        return f"generate messages for {messages_as_hex} and send new messages containing parts of those"

    def __call__(self, sender_class):
        """Return the list of messages to send for the given sender class."""
        raw_bits = generate_bits(sender_class, self._messages).bits
        messages = []
        messages.extend(self._before)
        for location in self._locations:
            messages.append(_bits_to_bytes(raw_bits[location[0]:location[1]]))
        messages.extend(self._after)
        return messages

# The full list of test cases. For every entry:
#   * acceptable_size is a limit in bytes on the sender's total output
#     (it is compared against the number of bits sent divided by 8);
#   * acceptable_missing / acceptable_corrupted are the largest tolerated
#     numbers of missing / corrupted messages; extra messages are never
#     tolerated;
#   * an entry without a distort_function delivers the sender's bits to the
#     receiver unchanged.
TESTS = [
    # --- Undistorted transmission of fixed messages ---
    TestCase(
        name='empty-clean',
        sent_messages=(b'', b'', b''),
        acceptable_size=120*3,
        grading_category='simple-short',
    ),
    TestCase(
        name='tiny-clean1',
        sent_messages=(b'A', b'B', b'D'),
        acceptable_size=120*3,
        grading_category='simple-short',
    ),
    TestCase(
        name='tiny-clean2',
        sent_messages=(b'ABC', b'C', b'DE'),
        acceptable_size=120*3,
        grading_category='simple-short',
    ),
    TestCase(
        name='tiny-clean3',
        sent_messages=(b'\x00', b'\xFE', b'\x01', b'\x80', b'\xFF'),
        acceptable_size=120*3,
        grading_category='binary-short',
    ),
    TestCase(
        name='tiny-clean4',
        sent_messages=(b'\x00\x00\x00\x00', b'\xFF\xFF\xFE\xFF\xFF'),
        acceptable_size=120*3,
        grading_category='binary-short',
    ),
    TestCase(
        name='three-message-clean',
        sent_messages=tuple([b'message 1', b'message 2', b'message 3']),
        acceptable_size=120 * 3,
        grading_category='simple-short',
    ),
    TestCase(
        name='three-message-binary-clean',
        sent_messages=(
            b'\x00\x01\x02\x03message 1' + bytes(range(256)),
            b'\x0a\x0d\x00\x0d\x0amessage 2' + bytes(range(128, 256)) + bytes(range(0, 128)),
            b'\x9a\x00message 1' + bytes(range(128, 256)) + bytes(range(1, 128, 5)),
        ),
        acceptable_size=int(270 * 3 * 1.21),
        grading_category='binary-long',
    ),
    TestCase(
        name='three-message-long-clean',
        sent_messages=(
            b'message 1' + (b'X' * 1000),
            b'message 2' + (b'Y' * 999),
            b'message 3' + (b'XY' * 501),
        ),
        acceptable_size=int((1020 + 1020 + 520) * 1.21),
        grading_category='simple-long',
    ),
    TestCase(
        name='many-message-clean',
        sent_messages=tuple(
            map(lambda x: f'{x}'.encode('us-ascii'), range(1000))
        ),
        acceptable_size=120*1000,
        grading_category='simple-long',
    ),
    TestCase(
        name='many-message-clean-varlen',
        sent_messages=tuple(
            map(lambda x: f'{x}{"Z" * ((x*137)%131)}'.encode('us-ascii'), range(1000))
        ),
        acceptable_size=120*1000,
        grading_category='simple-long',
    ),
    TestCase(
        name='with-empty-message-clean',
        sent_messages=(b'before', b'', b'after'),
        acceptable_size=120*3,
        grading_category='simple-short',
    ),
    # --- Messages built from slices of the sender's own encoded output ---
    # ReuseBytesGenerator encodes `messages` with the sender under test and
    # turns the bit ranges in `locations` into the messages actually sent.
    TestCase(
        name='send-empty-message-bits-in-message1',
        sent_messages_generator=ReuseBytesGenerator(
            messages=(b'',b''),
            locations=[(0, 200)],
        ),
        grading_category='binary-short',
    ),
    TestCase(
        name='send-empty-message-bits-in-message2',
        sent_messages_generator=ReuseBytesGenerator(
            messages=[b'',b''],
            locations=[(1, 10* 8)]
        ),
        grading_category='binary-short',
    ),
    TestCase(
        name='send-empty-message-bits-in-message3',
        sent_messages_generator=ReuseBytesGenerator(
            messages=[b'A',b'B'],
            locations=[(2, 10* 8)]
        ),
        grading_category='binary-short',
    ),
    TestCase(
        name='send-empty-message-bits-in-message4',
        sent_messages_generator=ReuseBytesGenerator(
            messages=[b'',b''],
            locations=[(8, -1)]
        ),
        grading_category='binary-short',
    ),
    TestCase(
        name='send-empty-message-bits-in-message5',
        sent_messages_generator=ReuseBytesGenerator(
            messages=[b'',b''],
            locations=[(10 * 8,  20* 8), (0 * 8, 10 * 8), (0, 5 * 8), (0, -1)]
        ),
        grading_category='binary-long',
    ),
    TestCase(
        name='send-non-empty-message-bits-in-message1',
        sent_messages_generator=ReuseBytesGenerator(
            messages=(b'A',b'Bb'),
            locations=[(0, 200)],
        ),
        grading_category='binary-short',
    ),
    TestCase(
        name='send-non-empty-message-bits-in-message2',
        sent_messages_generator=ReuseBytesGenerator(
            messages=[b'A',b'Bb'],
            locations=[(1, 10* 8)]
        ),
        grading_category='binary-short',
    ),
    TestCase(
        name='send-non-empty-message-bits-in-message3',
        sent_messages_generator=ReuseBytesGenerator(
            messages=[b'A',b'Bb'],
            locations=[(2, 10* 8)]
        ),
        grading_category='binary-short',
    ),
    TestCase(
        name='send-non-empty-message-bits-in-message4',
        sent_messages_generator=ReuseBytesGenerator(
            messages=[b'A',b'Bb'],
            locations=[(8, -1)]
        ),
        grading_category='binary-short',
    ),
    TestCase(
        name='send-non-empty-message-bits-in-message5',
        sent_messages_generator=ReuseBytesGenerator(
            messages=[b'A',b'Bb'],
            locations=[(10 * 8,  20* 8), (0 * 8, 10 * 8), (0, 5 * 8), (0, -1)]
        ),
        grading_category='binary-long',
    ),
    # --- Systematic corruption: one subtest per flipped/added/deleted bit ---
    TestCase(
        name='corrupt-short1',
        sent_messages=(b'xy', b'yz'),
        distort_function=make_corrupt_each(message_limit=2),
        acceptable_missing=4,
        acceptable_corrupted=0,
        acceptable_size= 120 * 3,
        grading_category='corrupt',
    ),
    TestCase(
        name='corrupt-short2',
        sent_messages=(b'\x00', b'\xFF',b'\xFE'),
        distort_function=make_corrupt_each(message_limit=2),
        acceptable_missing=3,
        acceptable_corrupted=0,
        acceptable_size= 120 * 3,
        grading_category='corrupt-binary',
    ),
    TestCase(
        name='corrupt-first-recovery1',
        sent_messages=(b'x', b'y', b'Z' * 1000, b'Q' * 1000,
                       b'A', b'B', b'C', b'D', b'E', b'F', b'G'),
        distort_function=make_corrupt_each(message_limit=1),
        acceptable_missing=4,
        acceptable_corrupted=0, # probability of checksum not working should be negligible for such a small message
        acceptable_size=120 * 9 + 1201 * 2,
        grading_category='corrupt',
    ),
    TestCase(
        name='corrupt-first-recovery2',
        sent_messages=(b'\x00', b'\x01', b'\xFE' * 1000, b'\xFF' * 1000,
                       b'\xFF', b'\xFE', b'\xFD', b'\x00', b'\x00\x01', b'\x02\x01\x00', b'\x03\x01\x00\x02'),
        distort_function=make_corrupt_each(message_limit=1),
        acceptable_missing=4,
        acceptable_corrupted=0, # probability of checksum not working should be negligible for such a small message
        acceptable_size=120 * 9 + 1201 * 2,
        grading_category='corrupt-binary',
    ),
    TestCase(
        name='corrupt-long-recovery',
        sent_messages=(b'A' * 1000, b'X' * 1000, b'Z' * 1000, b'Q' * 150, b'R' * 150,),
        distort_function=make_corrupt_each(message_limit=2, first_message=2),
        acceptable_missing=2,
        acceptable_corrupted=0, # probability of checksum not working should be negligible
        acceptable_size=int(1024 * 1.21 * 5),
        grading_category='corrupt',
    ),
    TestCase(
        name='corrupt-first-recovery-tiny2',
        sent_messages=(b'\xFE', b'\xFE', b'\x00' * 1000, b'\x00' * 1000,
                       b'A', b'B', b'C', b'D', b'E', b'F', b'G'),
        distort_function=make_corrupt_each(message_limit=1),
        acceptable_missing=4,
        acceptable_corrupted=0, # probability of checksum not working should be negligible for such a small message
        acceptable_size=120 * 9 + 1201 * 2,
        grading_category='corrupt-binary',
    ),
    # --- Undistorted transmission covering every possible byte value ---
    TestCase(
        name='allbytes-clean1',
        sent_messages=tuple([bytes([x]) for x in range(256)]),
        acceptable_size=120 * 256,
        grading_category='binary-long',
    ),
    TestCase(
        name='allbytes-clean2',
        sent_messages=tuple([bytes([x] * 60) for x in range(256)]),
        acceptable_size=120 * 256,
        grading_category='binary-long',
    ),
    # --- Randomized corruption over many messages ---
    TestCase(
        name='many-message-corrupt',
        distort_function=make_corrupt_random(flip_count=20),
        sent_messages=tuple(map(lambda x: f'this is message {x} XXXXX XXXXX XXXXX'.encode('us-ascii'), range(10000))),
        # With this rate of corruption, we'd estimate around 20 messages corrupted.
        # Assume at most 25 extra messages are corrupted due to desync
        acceptable_missing=20 * 25,
        acceptable_corrupted=1,
        acceptable_size=120*10000,
        grading_category='corrupt',
    ),
    TestCase(
        name='many-message-corrupt-varlen',
        distort_function=make_corrupt_random(add_count=10, delete_count=10),
        sent_messages=tuple(map(lambda x: struct.pack("H", x) + (bytes([x % 255]) * (1+(x*137)%7)), range(10000))),
        # Guess around 20 messages corrupted, with each corrupted message
        # causing at most 200 messages to be missed due to desync.
        acceptable_missing=4000,
        acceptable_corrupted=1,
        acceptable_size=120*10000,
        grading_category='corrupt-binary',
    ),
    TestCase(
        name='many-message-corrupt-recover1',
        distort_function=make_corrupt_random(flip_rate=0.0001, add_rate=0.0001, delete_rate=0.0001, corrupt_limit_messages=100),
        sent_messages=tuple(map(lambda x: struct.pack("H", x) + (b"Q" * (1+(x*137)%7)), range(1000))),
        # first 100 badly corrupted, should recover within a couple kilobytes of messages, which pessimistically might be 300 messages
        acceptable_missing=100 + 300,
        acceptable_corrupted=1,
        acceptable_size=120 * 1000,
        grading_category='corrupt-binary',
    ),
    TestCase(
        name='many-message-corrupt-recover2',
        distort_function=make_corrupt_random(flip_rate=0.0001, add_rate=0.0001, delete_rate=0.0001, corrupt_limit_messages=100),
        sent_messages=tuple(map(lambda x: struct.pack("H", x) +  (bytes([x % 255]) * (1+(x*137)%7)), range(10000))),
        # first 100 badly corrupted, should recover within a couple kilobytes of messages, which pessimistically might be 300 messages
        acceptable_missing=100 + 300,
        acceptable_corrupted=1,
        acceptable_size=120 * 10000,
        grading_category='corrupt-binary',
    ),
]

def setup_argparser():
    """Build the command-line parser with the 'send', 'recv' and 'suite' subcommands.

    Options shared by every subcommand (module/class selection, verbosity)
    are declared once on a parent parser that each subcommand inherits.
    Each subcommand records its own name in ``args.command`` through
    ``set_defaults`` so the caller can dispatch on it.

    Returns:
        The configured top-level ``argparse.ArgumentParser``.
    """
    # Options accepted by all subcommands; add_help=False so that the
    # inheriting subparsers do not get a duplicate -h/--help.
    common_parser = argparse.ArgumentParser(add_help=False, argument_default=argparse.SUPPRESS)
    common_parser.add_argument('--test-module',
        help='python module to test (default: sendrecv, meaning sendrecv.py)',
        default='sendrecv')
    common_parser.add_argument('--sender-class',
        help='sender class to test (default: MySender)',
        default='MySender')
    common_parser.add_argument('--receiver-class',
        help='receiver class to test (default: MyReceiver)',
        default='MyReceiver')
    common_parser.add_argument('--verbose', action='store_true',
        help='enable verbose output',
        default=False)
    common_parser.add_argument('--quiet', action='store_true',
        help='do not output about individual passed tests',
        default=False)

    parser = argparse.ArgumentParser()

    subparsers = parser.add_subparsers(required=True)

    # 'send': encode messages given on the command line.
    send_parser = subparsers.add_parser('send',
        help='Send a set of messages and display the resulting bits.',
        parents=[common_parser])

    send_parser.add_argument('-i', '--input-format', default='ascii',
        help='Format of input messages (default: ascii)',
        choices=['ascii', 'binary', 'hex'])

    send_parser.add_argument('-o', '--output-format', default='binary',
        help='Format of output messages (default: binary)',
        choices=['binary', 'hex'])

    send_parser.add_argument('messages', nargs='+',
        help='Messages to send.')

    send_parser.set_defaults(command='send')

    # 'recv': decode data given on the command line.
    recv_parser = subparsers.add_parser('recv',
        help='Receive a set of messages and display the resulting bytes.',
        parents=[common_parser])

    # The help text states the default that is actually configured here.
    recv_parser.add_argument('-i', '--input-format', default='ascii',
        help='Format of input data (default: ascii)',
        choices=['ascii', 'binary', 'hex'])

    recv_parser.add_argument('-o', '--output-format', default='ascii',
        help='Format of output messages (default: ascii)',
        choices=['ascii', 'binary', 'hex'])

    recv_parser.add_argument('data', type=str,
        help='Data to receive.')

    recv_parser.set_defaults(command='recv')

    # 'suite': run the built-in test suite.
    test_suite_parser = subparsers.add_parser('suite',
        help='Run the test-suite specified by TESTS in test.py'
             ' (possibly with some filtering)',
        parents=[common_parser])
    test_suite_parser.set_defaults(command='suite')

    test_suite_parser.add_argument('--random-seed', default=42, type=int,
        help='random seed to use for random parts of tests')
    test_suite_parser.add_argument('--keep-going', default=False, action='store_true',
        help='keep going after first failure')
    test_suite_parser.add_argument('--only-test', default=None, type=str,
        help='only run tests matching a specified regular expression')
    test_suite_parser.add_argument('--only-subtest', default=None, type=str,
        help='only run subtests matching a specified regular expression')
    test_suite_parser.add_argument('--ignore-too-many-bits',
        default=False, action='store_true',
        help='ignore errors from too many bits')
    test_suite_parser.add_argument('--json', default=False, action='store_true',
        help='JSON-format output (for grading)')
    test_suite_parser.add_argument('--json-file', help='JSON output file (for grading)')
    test_suite_parser.add_argument('--show-subtests', default=False, action='store_true',
        help='list subtests individually in results')

    return parser

def get_sender_class_from(args):
    """Import ``args.test_module`` and return the object named ``args.sender_class``.

    The name is looked up directly in the module's namespace dictionary, so
    a missing name raises ``KeyError``.
    """
    namespace = vars(importlib.import_module(args.test_module))
    return namespace[args.sender_class]
    
def get_receiver_class_from(args):
    """Import ``args.test_module`` and return the object named ``args.receiver_class``.

    The name is looked up directly in the module's namespace dictionary, so
    a missing name raises ``KeyError``.
    """
    namespace = vars(importlib.import_module(args.test_module))
    return namespace[args.receiver_class]

def display_subtest_result(args, result):
    """Write a single subtest result to stdout, honoring the output options.

    Nothing is written in JSON mode.  Otherwise a failed result (or any
    result when ``args.verbose`` is set) gets the full display; a passed
    result gets the short display unless ``args.quiet`` suppresses it.
    """
    if args.json:
        return
    wants_full = args.verbose or not result.passed
    if not wants_full and args.quiet:
        # Quiet mode: passed results are not reported individually.
        return
    sys.stdout.write('\n')
    if wants_full:
        result.display(sys.stdout, verbose=args.verbose)
    else:
        result.display_short(sys.stdout)

def display_subtest_result_list(args, results):
    """Print a grouped summary for the subtest results of one test.

    Nothing is printed in JSON mode or for an empty list.  A single result
    is shown on its own.  For several results a one-line pass summary is
    printed, followed by each failed subtest when there are any.  The test
    name is taken from the first result.
    """
    if args.json or len(results) == 0:
        return
    if len(results) == 1:
        display_subtest_result(args, results[0])
        return

    test_name = results[0].test_case.name
    failures = [entry for entry in results if not entry.passed]
    num_passed = len(results) - len(failures)

    if failures:
        print(f"\n'{test_name}': PASSED {num_passed} subtests of {len(results)} run")
        for entry in failures:
            display_subtest_result(args, entry)
    else:
        print(f"\nPASSED '{test_name}' ({num_passed} subtests)")


def run_suite(args):
    """Run the test suite selected by ``args`` and report the results.

    Tests come from the module-level ``TESTS`` list, optionally narrowed to
    those whose name fully matches the regular expression ``args.only_test``;
    ``args.only_subtest`` is passed on as a subtest-name filter.  With
    ``args.ignore_too_many_bits`` each test is run as a shallow copy whose
    ``acceptable_size`` is cleared.

    Results are displayed one by one when ``args.show_subtests`` or
    ``args.verbose`` is set, otherwise grouped per test.  Unless
    ``args.keep_going`` is set, the run stops at the first failed result.
    With ``args.json`` the collected results are dumped as JSON to
    ``args.json_file`` or, if that is unset, to stdout.

    Side effects: assigns the module-level ``random_seed``; calls
    ``sys.exit(1)`` when the requested JSON output file already exists.
    """
    # Published as a module global; presumably read by the random parts of
    # the tests defined elsewhere in this file.
    global random_seed
    random_seed = args.random_seed
    sender_class = get_sender_class_from(args)
    receiver_class = get_receiver_class_from(args)

    if args.only_test:
        filtered_tests = [
            test for test in TESTS
            if re.fullmatch(args.only_test, test.name) is not None
        ]
    else:
        filtered_tests = TESTS
    if args.ignore_too_many_bits:
        # Work on copies so the shared TESTS entries are left untouched.
        relaxed_tests = []
        for test in filtered_tests:
            relaxed = copy.copy(test)
            relaxed.acceptable_size = None
            relaxed_tests.append(relaxed)
        filtered_tests = relaxed_tests

    if args.only_subtest:
        subtest_filter = lambda subtest_name: re.fullmatch(args.only_subtest, subtest_name) is not None
    else:
        subtest_filter = None

    all_results = []
    current_test_results = []  # results of the test currently being grouped
    stopped_early = False
    test_names = set()
    failed_test_names = set()
    failed_count = 0
    for test_result in run_test_cases(sender_class, receiver_class, filtered_tests, subtest_filter):
        all_results.append(test_result)
        test_name = test_result.test_case.name
        test_names.add(test_name)
        passed = test_result.passed
        if not passed:
            failed_count += 1
            failed_test_names.add(test_name)
        if args.show_subtests or args.verbose:
            display_subtest_result(args, test_result)
        else:
            # A change of test name means the previous group is complete.
            if len(current_test_results) > 0 and current_test_results[-1].test_case.name != test_name:
                display_subtest_result_list(args, current_test_results)
                current_test_results = []
            current_test_results.append(test_result)
        if not args.keep_going and not passed:
            stopped_early = True
            break
    display_subtest_result_list(args, current_test_results)

    # When the JSON document itself goes to stdout, keep the status line off
    # stdout so that the output remains parseable JSON.
    status_fh = sys.stderr if (args.json and not args.json_file) else sys.stdout
    result_count = len(all_results)
    if stopped_early:
        print("*** stopped because of failed test (use --keep-going to not stop)", file=status_fh)
    elif failed_count == 0:
        print(f"*** done; passed {len(test_names)} tests ({result_count} subtests)", file=status_fh)
    else:
        # Reached only with --keep-going: do not claim that everything passed.
        print(
            f"*** done; FAILED {failed_count} of {result_count} subtests"
            f" (in {len(failed_test_names)} of {len(test_names)} tests)",
            file=status_fh,
        )

    if args.json:
        if args.json_file:
            try:
                # 'x' creates the file exclusively, refusing to overwrite an
                # existing one without a separate (racy) existence check.
                output_file = open(args.json_file, 'x')
            except FileExistsError:
                print(f"ERROR: {args.json_file} already exists.", file=sys.stderr)
                sys.exit(1)
            # Close the file even if serialization fails.
            with output_file:
                json.dump([result.for_json() for result in all_results],
                          output_file, sort_keys=True, indent=2)
        else:
            json.dump([result.for_json() for result in all_results],
                      sys.stdout, sort_keys=True, indent=2)

def encode_bits_from_text(format_name, text):
    """Convert command-line text into a sequence of bits.

    Args:
        format_name: 'ascii', 'binary' or 'hex'.
        text: the text to convert.  For 'binary' it is a string of 0/1
            digits with an optional '0b'/'0B' prefix; for 'hex' a string of
            hexadecimal digits with an optional '0x'/'0X' prefix.

    Returns:
        For 'binary' and 'hex', a bytearray with one element (0 or 1) per
        bit; hex digits expand to four bits each, most significant first.
        For 'ascii', whatever ``bytes_to_bits`` returns for the US-ASCII
        encoding of ``text``.

    Raises:
        ValueError: if ``text`` contains characters that are invalid for the
            format (including non-ASCII text for 'ascii'), or if
            ``format_name`` is not one of the supported formats.
    """
    if format_name == 'ascii':
        return bytes_to_bits(text.encode('us-ascii'))
    if format_name == 'binary':
        # Accept the same optional prefix that encode_message_from_text does.
        if text.startswith(('0b', '0B')):
            text = text[2:]
        if re.search(r'[^01]', text) is not None:
            raise ValueError(f"input '{text}' contains characters not valid in binary")
        return bytearray(int(digit) for digit in text)
    if format_name == 'hex':
        if text.startswith(('0x', '0X')):
            text = text[2:]
        if re.search(r'[^0-9a-fA-F]', text) is not None:
            raise ValueError(f"input '{text}' contains characters not valid in hexadecimal")
        result = bytearray()
        for digit in text:
            value = int(digit, 16)
            # Emit the four bits of this digit, most significant first.
            result.extend(1 if value & mask else 0 for mask in (8, 4, 2, 1))
        return result
    # Fail loudly instead of silently returning None.
    raise ValueError(f"unknown input format '{format_name}'")

def encode_message_from_text(format_name, text):
    """Convert command-line text into the bytes of one message.

    Args:
        format_name: 'ascii', 'binary' or 'hex'.
        text: the text to convert.  For 'binary' it is a string of 0/1
            digits (optional '0b'/'0B' prefix) whose length is a multiple of
            8; for 'hex' a string of hexadecimal digits (optional '0x'/'0X'
            prefix) whose length is a multiple of 2.

    Returns:
        The message as ``bytes`` for every format.

    Raises:
        ValueError: if ``text`` contains invalid characters for the format
            (including non-ASCII text for 'ascii'), has a length that does
            not make up whole bytes, or if ``format_name`` is not one of the
            supported formats.
    """
    if format_name == 'ascii':
        return text.encode('us-ascii')
    if format_name == 'binary':
        if text.startswith(('0b', '0B')):
            text = text[2:]
        if re.search(r'[^01]', text) is not None:
            raise ValueError(f"input '{text}' contains characters not valid in binary")
        if len(text) % 8 != 0:
            raise ValueError(f"input '{text}' length is not a multiple of 8; cannot convert to bytes as binary")
        # One byte per group of eight digits, built in a single pass.
        return bytes(int(text[start:start + 8], 2) for start in range(0, len(text), 8))
    if format_name == 'hex':
        if text.startswith(('0x', '0X')):
            text = text[2:]
        if re.search(r'[^0-9a-fA-F]', text) is not None:
            raise ValueError(f"input '{text}' contains characters not valid in hexadecimal")
        if len(text) % 2 != 0:
            raise ValueError(f"input '{text}' length is not a multiple of 2; cannot convert to bytes as hexadecimal")
        # Return bytes (not bytearray) so all formats yield the same type.
        return bytes.fromhex(text)
    # Fail loudly instead of silently returning None.
    raise ValueError(f"unknown input format '{format_name}'")

def display_messages(out_fh, messages, message_format='binary', label="Sent message", nothing_message=None, truncate=False):
    """Print a numbered listing of ``messages`` to ``out_fh``.

    Each message gets a "<label> #<n>:" heading followed by its bits as
    rendered by ``display_raw_bits``.  With ``truncate`` only the first
    three messages are listed and the number of remaining ones is reported.
    For an empty collection, ``nothing_message`` is printed instead; when it
    is None a default is derived from ``label``.
    """
    total = len(messages)
    if total == 0:
        print(nothing_message if nothing_message is not None else 'No ' + label.lower() + 's.',
              file=out_fh)
        return

    shown = 3 if truncate else total
    for number, message in enumerate(messages[:shown], start=1):
        print(f"{label} #{number}:", file=out_fh)
        display_raw_bits(out_fh, bytes_to_bits(message), message_format, truncate=truncate)

    hidden = total - shown
    if hidden > 0:
        print(f"+ {hidden} additional messages", file=out_fh)

def run_send(args):
    """Handle the 'send' subcommand.

    Converts each command-line message from ``args.input_format`` to bytes,
    echoes the messages, runs them through the configured sender class and
    prints the resulting bits in ``args.output_format``.
    """
    messages_as_bytes = [
        encode_message_from_text(args.input_format, text)
        for text in args.messages
    ]
    display_messages(sys.stdout, messages_as_bytes, message_format=args.output_format)
    sender_class = get_sender_class_from(args)
    send_result = generate_bits(sender_class, messages_as_bytes)
    print("Sent as:")
    display_send_result(sys.stdout, send_result, args.output_format)

def run_recv(args):
    """Handle the 'recv' subcommand.

    Converts ``args.data`` from ``args.input_format`` to bits, echoes those
    bits, feeds them to the configured receiver class and prints the
    messages it produces in ``args.output_format``.
    """
    incoming_bits = encode_bits_from_text(args.input_format, args.data)
    print("Receiving bits:")
    display_raw_bits(sys.stdout, incoming_bits, args.input_format)
    receiver_class = get_receiver_class_from(args)
    received_messages = receive_bits(receiver_class, incoming_bits)
    display_messages(
        sys.stdout,
        received_messages,
        message_format=args.output_format,
        label="Received message",
    )

def main(argv):
    """Parse ``argv`` (without the program name) and run the chosen subcommand.

    Raises:
        RuntimeError: if the parsed command has no handler.
    """
    args = setup_argparser().parse_args(argv)
    handlers = {
        'suite': run_suite,
        'send': run_send,
        'recv': run_recv,
    }
    handler = handlers.get(args.command)
    if handler is None:
        raise RuntimeError(f'unimplemented command {args.command}')
    handler(args)
    

# Script entry point: when executed directly (not imported), run the
# command-line interface on the arguments that follow the program name.
if __name__ == '__main__':
    main(sys.argv[1:])
