from torch.utils.data import communication


class Protocol(object):
    __slots__ = ('request_queue', 'response_queue')

    def __init__(self, request_queue, response_queue):
        self.request_queue = request_queue
        self.response_queue = response_queue


class ProtocolClient(Protocol):
    """
        ProtocolClient takes charge of putting requests into req_queue and returning results from res_queue.
    """
    _req_sent = None

    def __init__(self, request_queue, response_queue):
        self.request_queue = request_queue
        self.response_queue = response_queue
        self._req_sent = None

    def can_take_request(self):
        return self._req_sent is None

    def waiting_for_response(self):
        return self._req_sent is not None

    def request_sent(self, request=True):
        if not self.can_take_request():
            raise Exception('Protocol only supports one request in the Queue')
        self._req_sent = request

    def request_served(self, result=None):
        if not self.waiting_for_response():
            raise Exception(
                'Expected no peding requests, but something got served', result)
        self._req_sent = None


class ProtocolServer(Protocol):
    """
        ProtocolServer takes charge of getting requests from req_queue and fetching data from source datapipe.
    """
    _req_received = None

    def __init__(self, request_queue, response_queue):
        self.request_queue = request_queue
        self.response_queue = response_queue
        self._req_received = None

    def have_pending_request(self):
        return self._req_received is not None

    def get_new_request(self, block=False):
        if self.have_pending_request():
            raise Exception(
                'Trying to get next request, while having one unserved')
        try:
            response = self.request_queue.get(block=block)
        except Exception as e:  # TODO: Catch only timeout exceptions
            raise EmptyQueue('queue is empty')
        self._req_received = response
        return response
        # TODO: Validate supported requests

    def response_terminate(self):
        if not self.have_pending_request():
            raise Exception("Attempting to reply with pending request")
        if not isinstance(self._req_received, communication.messages.TerminateRequest):
            raise Exception(
                "Replaying with terminate status to other type of message")
        self.response_queue.put(communication.messages.TerminateResponse())
        self._req_received = None


class MapDataPipeQueueProtocolServer(ProtocolServer):
    def response_item(self, key, value):
        if not self.have_pending_request():
            raise Exception("Attempting to reply with pending request")
        self.response_queue.put(communication.messages.GetItemResponse(key, value))
        self._req_received = None

    def response_len(self, size):
        if not self.have_pending_request():
            raise Exception("Attempting to reply with pending request")
        self.response_queue.put(communication.messages.LenResponse(size))
        self._req_received = None

    def response_index_out_of_bound(self):
        if not self.have_pending_request():
            raise Exception("Attempting to reply with pending request")
        self.response_queue.put(communication.messages.StopIterationResponse())
        self._req_received = None

class MapDataPipeQueueProtocolClient(ProtocolClient):
    def request_len(self):
        if not self.can_take_request():
            raise Exception('Can not request len while we are still waiting response for previous request')
        request = communication.messages.LenRequest()
        self.request_queue.put(request)
        self.request_sent(request)

    def request_item(self, index):
        if not self.can_take_request():
            raise Exception('Can not request item while we are still waiting response for previous request')
        request = communication.messages.GetItemRequest(index)
        self.request_queue.put(request)
        self.request_sent(request)

    def get_response_len(self, block=False, timeout=None):
        if not self.waiting_for_response():
            raise Exception('Can not expect any response without submitted request')
        try:
            response = self.response_queue.get(block=block, timeout=timeout)
        except TimeoutError:
            raise EmptyQueue('queue is empty')
        self.request_served(response)
        if not isinstance(response, communication.messages.LenResponse):
            raise Exception('Invalid response received')
        return response

    def get_response_item(self, block=False, timeout=None):
        if not self.waiting_for_response():
            raise Exception('Can not expect any response without submitted request')
        try:
            response = self.response_queue.get(block=block, timeout=timeout)
        except TimeoutError:
            raise EmptyQueue('queue is empty')
        self.request_served(response)
        # if not isinstance(response, communication.messages.GetItemResponse):
        #     raise Exception('Invalid response received')
        return response


class EmptyQueue(Exception):
    pass


class IterDataPipeQueueProtocolServer(ProtocolServer):
    def response_reset_iterator(self):
        if not self.have_pending_request():
            raise Exception("Attempting to reply with pending request")
        if not isinstance(self._req_received, communication.messages.ResetIteratorRequest):
            raise Exception(
                "Replaying with reset status to other type of message")
        self.response_queue.put(communication.messages.ResetIteratorResponse())
        self._req_received = None

    def response_next(self, value):
        if not self.have_pending_request():
            raise Exception("Attempting to reply with pending request")
        self.response_queue.put(communication.messages.GetNextResponse(value))
        self._req_received = None

    def response_stop_iteration(self):
        if not self.have_pending_request():
            raise Exception("Attempting to reply with pending request")
        self.response_queue.put(communication.messages.StopIterationResponse())
        self._req_received = None

    def response_invalid_state(self):
        if not self.have_pending_request():
            raise Exception("Attempting to reply with pending request")
        self.response_queue.put(communication.messages.InvalidStateResponse())
        self._req_received = None


class IterDataPipeQueueProtocolClient(ProtocolClient):
    def request_reset_iterator(self):
        if not self.can_take_request():
            raise Exception('Can not reset while we are still waiting response for previous request')
        request = communication.messages.ResetIteratorRequest()
        self.request_queue.put(request)
        self.request_sent(request)

    def request_next(self):
        if not self.can_take_request():
            raise Exception('Can not request next item while we are still waiting response for previous request')
        request = communication.messages.GetNextRequest()
        self.request_queue.put(request)
        self.request_sent(request)

    def get_response_reset_iterator(self, block=False):
        try:
            response = self.response_queue.get(block=block)
        except Exception as e:  # TODO: Catch only timeout exceptions
            raise EmptyQueue('queue is empty')
        self.request_served(response)

        if not isinstance(response, communication.messages.ResetIteratorResponse):
            raise Exception('Invalid response received')

    def get_response_next(self, block=False, timeout=None):
        if not self.waiting_for_response():
            raise Exception(
                'Can not expect any response without submitted request')
        try:
            response = self.response_queue.get(block=block, timeout=timeout)
        except Exception as e:  # TODO: Catch only timeout exceptions
            raise EmptyQueue('queue is empty')
        self.request_served(response)

        # TODO(VitalyFedyunin): Add possible response types validation here
        return response
