from tornado.concurrent import Future
from tornado import gen
from tornado import netutil
from tornado.iostream import (
    IOStream,
    SSLIOStream,
    PipeIOStream,
    StreamClosedError,
    _StreamBuffer,
)
from tornado.httputil import HTTPHeaders
from tornado.locks import Condition, Event
from tornado.log import gen_log
from tornado.netutil import ssl_wrap_socket
from tornado.tcpserver import TCPServer
from tornado.testing import (
    AsyncHTTPTestCase,
    AsyncHTTPSTestCase,
    AsyncTestCase,
    bind_unused_port,
    ExpectLog,
    gen_test,
)
from tornado.test.util import skipIfNonUnix, refusing_port, skipPypy3V58
from tornado.web import RequestHandler, Application
import errno
import hashlib
import os
import platform
import random
import socket
import ssl
import sys
from unittest import mock
import unittest


def _server_ssl_options():
    return dict(
        certfile=os.path.join(os.path.dirname(__file__), "test.crt"),
        keyfile=os.path.join(os.path.dirname(__file__), "test.key"),
    )


class HelloHandler(RequestHandler):
    def get(self):
        self.write("Hello")


class TestIOStreamWebMixin(object):
    def _make_client_iostream(self):
        raise NotImplementedError()

    def get_app(self):
        return Application([("/", HelloHandler)])

    def test_connection_closed(self):
        # When a server sends a response and then closes the connection,
        # the client must be allowed to read the data before the IOStream
        # closes itself.  Epoll reports closed connections with a separate
        # EPOLLRDHUP event delivered at the same time as the read event,
        # while kqueue reports them as a second read/write event with an EOF
        # flag.
        response = self.fetch("/", headers={"Connection": "close"})
        response.rethrow()

    @gen_test
    def test_read_until_close(self):
        stream = self._make_client_iostream()
        yield stream.connect(("127.0.0.1", self.get_http_port()))
        stream.write(b"GET / HTTP/1.0\r\n\r\n")

        data = yield stream.read_until_close()
        self.assertTrue(data.startswith(b"HTTP/1.1 200"))
        self.assertTrue(data.endswith(b"Hello"))

    @gen_test
    def test_read_zero_bytes(self):
        self.stream = self._make_client_iostream()
        yield self.stream.connect(("127.0.0.1", self.get_http_port()))
        self.stream.write(b"GET / HTTP/1.0\r\n\r\n")

        # normal read
        data = yield self.stream.read_bytes(9)
        self.assertEqual(data, b"HTTP/1.1 ")

        # zero bytes
        data = yield self.stream.read_bytes(0)
        self.assertEqual(data, b"")

        # another normal read
        data = yield self.stream.read_bytes(3)
        self.assertEqual(data, b"200")

        self.stream.close()

    @gen_test
    def test_write_while_connecting(self):
        stream = self._make_client_iostream()
        connect_fut = stream.connect(("127.0.0.1", self.get_http_port()))
        # unlike the previous tests, try to write before the connection
        # is complete.
        write_fut = stream.write(b"GET / HTTP/1.0\r\nConnection: close\r\n\r\n")
        self.assertFalse(connect_fut.done())

        # connect will always complete before write.
        it = gen.WaitIterator(connect_fut, write_fut)
        resolved_order = []
        while not it.done():
            yield it.next()
            resolved_order.append(it.current_future)
        self.assertEqual(resolved_order, [connect_fut, write_fut])

        data = yield stream.read_until_close()
        self.assertTrue(data.endswith(b"Hello"))

        stream.close()

    @gen_test
    def test_future_interface(self):
        """Basic test of IOStream's ability to return Futures."""
        stream = self._make_client_iostream()
        connect_result = yield stream.connect(("127.0.0.1", self.get_http_port()))
        self.assertIs(connect_result, stream)
        yield stream.write(b"GET / HTTP/1.0\r\n\r\n")
        first_line = yield stream.read_until(b"\r\n")
        self.assertEqual(first_line, b"HTTP/1.1 200 OK\r\n")
        # callback=None is equivalent to no callback.
        header_data = yield stream.read_until(b"\r\n\r\n")
        headers = HTTPHeaders.parse(header_data.decode("latin1"))
        content_length = int(headers["Content-Length"])
        body = yield stream.read_bytes(content_length)
        self.assertEqual(body, b"Hello")
        stream.close()

    @gen_test
    def test_future_close_while_reading(self):
        stream = self._make_client_iostream()
        yield stream.connect(("127.0.0.1", self.get_http_port()))
        yield stream.write(b"GET / HTTP/1.0\r\n\r\n")
        with self.assertRaises(StreamClosedError):
            yield stream.read_bytes(1024 * 1024)
        stream.close()

    @gen_test
    def test_future_read_until_close(self):
        # Ensure that the data comes through before the StreamClosedError.
        stream = self._make_client_iostream()
        yield stream.connect(("127.0.0.1", self.get_http_port()))
        yield stream.write(b"GET / HTTP/1.0\r\nConnection: close\r\n\r\n")
        yield stream.read_until(b"\r\n\r\n")
        body = yield stream.read_until_close()
        self.assertEqual(body, b"Hello")

        # Nothing else to read; the error comes immediately without waiting
        # for yield.
        with self.assertRaises(StreamClosedError):
            stream.read_bytes(1)


class TestReadWriteMixin(object):
    # Tests where one stream reads and the other writes.
    # These should work for BaseIOStream implementations.

    def make_iostream_pair(self, **kwargs):
        raise NotImplementedError

    @gen_test
    def test_write_zero_bytes(self):
        # Attempting to write zero bytes should run the callback without
        # going into an infinite loop.
        rs, ws = yield self.make_iostream_pair()
        yield ws.write(b"")
        ws.close()
        rs.close()

    @gen_test
    def test_future_delayed_close_callback(self):
        # Same as test_delayed_close_callback, but with the future interface.
        rs, ws = yield self.make_iostream_pair()

        try:
            ws.write(b"12")
            chunks = []
            chunks.append((yield rs.read_bytes(1)))
            ws.close()
            chunks.append((yield rs.read_bytes(1)))
            self.assertEqual(chunks, [b"1", b"2"])
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_close_buffered_data(self):
        # Similar to the previous test, but with data stored in the OS's
        # socket buffers instead of the IOStream's read buffer.  Out-of-band
        # close notifications must be delayed until all data has been
        # drained into the IOStream buffer. (epoll used to use out-of-band
        # close events with EPOLLRDHUP, but no longer)
        #
        # This depends on the read_chunk_size being smaller than the
        # OS socket buffer, so make it small.
        rs, ws = yield self.make_iostream_pair(read_chunk_size=256)
        try:
            ws.write(b"A" * 512)
            data = yield rs.read_bytes(256)
            self.assertEqual(b"A" * 256, data)
            ws.close()
            # Allow the close to propagate to the `rs` side of the
            # connection.  Using add_callback instead of add_timeout
            # doesn't seem to work, even with multiple iterations
            yield gen.sleep(0.01)
            data = yield rs.read_bytes(256)
            self.assertEqual(b"A" * 256, data)
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_read_until_close_after_close(self):
        # Similar to test_delayed_close_callback, but read_until_close takes
        # a separate code path so test it separately.
        rs, ws = yield self.make_iostream_pair()
        try:
            ws.write(b"1234")
            ws.close()
            # Read one byte to make sure the client has received the data.
            # It won't run the close callback as long as there is more buffered
            # data that could satisfy a later read.
            data = yield rs.read_bytes(1)
            self.assertEqual(data, b"1")
            data = yield rs.read_until_close()
            self.assertEqual(data, b"234")
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_large_read_until(self):
        # Performance test: read_until used to have a quadratic component
        # so a read_until of 4MB would take 8 seconds; now it takes 0.25
        # seconds.
        rs, ws = yield self.make_iostream_pair()
        try:
            # This test fails on pypy with ssl.  I think it's because
            # pypy's gc defeats moves objects, breaking the
            # "frozen write buffer" assumption.
            if (
                isinstance(rs, SSLIOStream)
                and platform.python_implementation() == "PyPy"
            ):
                raise unittest.SkipTest("pypy gc causes problems with openssl")
            NUM_KB = 4096
            for i in range(NUM_KB):
                ws.write(b"A" * 1024)
            ws.write(b"\r\n")
            data = yield rs.read_until(b"\r\n")
            self.assertEqual(len(data), NUM_KB * 1024 + 2)
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_close_callback_with_pending_read(self):
        # Regression test for a bug that was introduced in 2.3
        # where the IOStream._close_callback would never be called
        # if there were pending reads.
        OK = b"OK\r\n"
        rs, ws = yield self.make_iostream_pair()
        event = Event()
        rs.set_close_callback(event.set)
        try:
            ws.write(OK)
            res = yield rs.read_until(b"\r\n")
            self.assertEqual(res, OK)

            ws.close()
            rs.read_until(b"\r\n")
            # If _close_callback (self.stop) is not called,
            # an AssertionError: Async operation timed out after 5 seconds
            # will be raised.
            yield event.wait()
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_future_close_callback(self):
        # Regression test for interaction between the Future read interfaces
        # and IOStream._maybe_add_error_listener.
        rs, ws = yield self.make_iostream_pair()
        closed = [False]
        cond = Condition()

        def close_callback():
            closed[0] = True
            cond.notify()

        rs.set_close_callback(close_callback)
        try:
            ws.write(b"a")
            res = yield rs.read_bytes(1)
            self.assertEqual(res, b"a")
            self.assertFalse(closed[0])
            ws.close()
            yield cond.wait()
            self.assertTrue(closed[0])
        finally:
            rs.close()
            ws.close()

    @gen_test
    def test_write_memoryview(self):
        rs, ws = yield self.make_iostream_pair()
        try:
            fut = rs.read_bytes(4)
            ws.write(memoryview(b"hello"))
            data = yield fut
            self.assertEqual(data, b"hell")
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_read_bytes_partial(self):
        rs, ws = yield self.make_iostream_pair()
        try:
            # Ask for more than is available with partial=True
            fut = rs.read_bytes(50, partial=True)
            ws.write(b"hello")
            data = yield fut
            self.assertEqual(data, b"hello")

            # Ask for less than what is available; num_bytes is still
            # respected.
            fut = rs.read_bytes(3, partial=True)
            ws.write(b"world")
            data = yield fut
            self.assertEqual(data, b"wor")

            # Partial reads won't return an empty string, but read_bytes(0)
            # will.
            data = yield rs.read_bytes(0, partial=True)
            self.assertEqual(data, b"")
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_read_until_max_bytes(self):
        rs, ws = yield self.make_iostream_pair()
        closed = Event()
        rs.set_close_callback(closed.set)
        try:
            # Extra room under the limit
            fut = rs.read_until(b"def", max_bytes=50)
            ws.write(b"abcdef")
            data = yield fut
            self.assertEqual(data, b"abcdef")

            # Just enough space
            fut = rs.read_until(b"def", max_bytes=6)
            ws.write(b"abcdef")
            data = yield fut
            self.assertEqual(data, b"abcdef")

            # Not enough space, but we don't know it until all we can do is
            # log a warning and close the connection.
            with ExpectLog(gen_log, "Unsatisfiable read"):
                fut = rs.read_until(b"def", max_bytes=5)
                ws.write(b"123456")
                yield closed.wait()
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_read_until_max_bytes_inline(self):
        rs, ws = yield self.make_iostream_pair()
        closed = Event()
        rs.set_close_callback(closed.set)
        try:
            # Similar to the error case in the previous test, but the
            # ws writes first so rs reads are satisfied
            # inline.  For consistency with the out-of-line case, we
            # do not raise the error synchronously.
            ws.write(b"123456")
            with ExpectLog(gen_log, "Unsatisfiable read"):
                with self.assertRaises(StreamClosedError):
                    yield rs.read_until(b"def", max_bytes=5)
            yield closed.wait()
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_read_until_max_bytes_ignores_extra(self):
        rs, ws = yield self.make_iostream_pair()
        closed = Event()
        rs.set_close_callback(closed.set)
        try:
            # Even though data that matches arrives the same packet that
            # puts us over the limit, we fail the request because it was not
            # found within the limit.
            ws.write(b"abcdef")
            with ExpectLog(gen_log, "Unsatisfiable read"):
                rs.read_until(b"def", max_bytes=5)
                yield closed.wait()
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_read_until_regex_max_bytes(self):
        rs, ws = yield self.make_iostream_pair()
        closed = Event()
        rs.set_close_callback(closed.set)
        try:
            # Extra room under the limit
            fut = rs.read_until_regex(b"def", max_bytes=50)
            ws.write(b"abcdef")
            data = yield fut
            self.assertEqual(data, b"abcdef")

            # Just enough space
            fut = rs.read_until_regex(b"def", max_bytes=6)
            ws.write(b"abcdef")
            data = yield fut
            self.assertEqual(data, b"abcdef")

            # Not enough space, but we don't know it until all we can do is
            # log a warning and close the connection.
            with ExpectLog(gen_log, "Unsatisfiable read"):
                rs.read_until_regex(b"def", max_bytes=5)
                ws.write(b"123456")
                yield closed.wait()
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_read_until_regex_max_bytes_inline(self):
        rs, ws = yield self.make_iostream_pair()
        closed = Event()
        rs.set_close_callback(closed.set)
        try:
            # Similar to the error case in the previous test, but the
            # ws writes first so rs reads are satisfied
            # inline.  For consistency with the out-of-line case, we
            # do not raise the error synchronously.
            ws.write(b"123456")
            with ExpectLog(gen_log, "Unsatisfiable read"):
                rs.read_until_regex(b"def", max_bytes=5)
                yield closed.wait()
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_read_until_regex_max_bytes_ignores_extra(self):
        rs, ws = yield self.make_iostream_pair()
        closed = Event()
        rs.set_close_callback(closed.set)
        try:
            # Even though data that matches arrives the same packet that
            # puts us over the limit, we fail the request because it was not
            # found within the limit.
            ws.write(b"abcdef")
            with ExpectLog(gen_log, "Unsatisfiable read"):
                rs.read_until_regex(b"def", max_bytes=5)
                yield closed.wait()
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_small_reads_from_large_buffer(self):
        # 10KB buffer size, 100KB available to read.
        # Read 1KB at a time and make sure that the buffer is not eagerly
        # filled.
        rs, ws = yield self.make_iostream_pair(max_buffer_size=10 * 1024)
        try:
            ws.write(b"a" * 1024 * 100)
            for i in range(100):
                data = yield rs.read_bytes(1024)
                self.assertEqual(data, b"a" * 1024)
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_small_read_untils_from_large_buffer(self):
        # 10KB buffer size, 100KB available to read.
        # Read 1KB at a time and make sure that the buffer is not eagerly
        # filled.
        rs, ws = yield self.make_iostream_pair(max_buffer_size=10 * 1024)
        try:
            ws.write((b"a" * 1023 + b"\n") * 100)
            for i in range(100):
                data = yield rs.read_until(b"\n", max_bytes=4096)
                self.assertEqual(data, b"a" * 1023 + b"\n")
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_flow_control(self):
        MB = 1024 * 1024
        rs, ws = yield self.make_iostream_pair(max_buffer_size=5 * MB)
        try:
            # Client writes more than the rs will accept.
            ws.write(b"a" * 10 * MB)
            # The rs pauses while reading.
            yield rs.read_bytes(MB)
            yield gen.sleep(0.1)
            # The ws's writes have been blocked; the rs can
            # continue to read gradually.
            for i in range(9):
                yield rs.read_bytes(MB)
        finally:
            rs.close()
            ws.close()

    @gen_test
    def test_read_into(self):
        rs, ws = yield self.make_iostream_pair()

        def sleep_some():
            self.io_loop.run_sync(lambda: gen.sleep(0.05))

        try:
            buf = bytearray(10)
            fut = rs.read_into(buf)
            ws.write(b"hello")
            yield gen.sleep(0.05)
            self.assertTrue(rs.reading())
            ws.write(b"world!!")
            data = yield fut
            self.assertFalse(rs.reading())
            self.assertEqual(data, 10)
            self.assertEqual(bytes(buf), b"helloworld")

            # Existing buffer is fed into user buffer
            fut = rs.read_into(buf)
            yield gen.sleep(0.05)
            self.assertTrue(rs.reading())
            ws.write(b"1234567890")
            data = yield fut
            self.assertFalse(rs.reading())
            self.assertEqual(data, 10)
            self.assertEqual(bytes(buf), b"!!12345678")

            # Existing buffer can satisfy read immediately
            buf = bytearray(4)
            ws.write(b"abcdefghi")
            data = yield rs.read_into(buf)
            self.assertEqual(data, 4)
            self.assertEqual(bytes(buf), b"90ab")

            data = yield rs.read_bytes(7)
            self.assertEqual(data, b"cdefghi")
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_read_into_partial(self):
        rs, ws = yield self.make_iostream_pair()

        try:
            # Partial read
            buf = bytearray(10)
            fut = rs.read_into(buf, partial=True)
            ws.write(b"hello")
            data = yield fut
            self.assertFalse(rs.reading())
            self.assertEqual(data, 5)
            self.assertEqual(bytes(buf), b"hello\0\0\0\0\0")

            # Full read despite partial=True
            ws.write(b"world!1234567890")
            data = yield rs.read_into(buf, partial=True)
            self.assertEqual(data, 10)
            self.assertEqual(bytes(buf), b"world!1234")

            # Existing buffer can satisfy read immediately
            data = yield rs.read_into(buf, partial=True)
            self.assertEqual(data, 6)
            self.assertEqual(bytes(buf), b"5678901234")

        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_read_into_zero_bytes(self):
        rs, ws = yield self.make_iostream_pair()
        try:
            buf = bytearray()
            fut = rs.read_into(buf)
            self.assertEqual(fut.result(), 0)
        finally:
            ws.close()
            rs.close()

    @gen_test
    def test_many_mixed_reads(self):
        # Stress buffer handling when going back and forth between
        # read_bytes() (using an internal buffer) and read_into()
        # (using a user-allocated buffer).
        r = random.Random(42)
        nbytes = 1000000
        rs, ws = yield self.make_iostream_pair()

        produce_hash = hashlib.sha1()
        consume_hash = hashlib.sha1()

        @gen.coroutine
        def produce():
            remaining = nbytes
            while remaining > 0:
                size = r.randint(1, min(1000, remaining))
                data = os.urandom(size)
                produce_hash.update(data)
                yield ws.write(data)
                remaining -= size
            assert remaining == 0

        @gen.coroutine
        def consume():
            remaining = nbytes
            while remaining > 0:
                if r.random() > 0.5:
                    # read_bytes()
                    size = r.randint(1, min(1000, remaining))
                    data = yield rs.read_bytes(size)
                    consume_hash.update(data)
                    remaining -= size
                else:
                    # read_into()
                    size = r.randint(1, min(1000, remaining))
                    buf = bytearray(size)
                    n = yield rs.read_into(buf)
                    assert n == size
                    consume_hash.update(buf)
                    remaining -= size
            assert remaining == 0

        try:
            yield [produce(), consume()]
            assert produce_hash.hexdigest() == consume_hash.hexdigest()
        finally:
            ws.close()
            rs.close()


class TestIOStreamMixin(TestReadWriteMixin):
    def _make_server_iostream(self, connection, **kwargs):
        raise NotImplementedError()

    def _make_client_iostream(self, connection, **kwargs):
        raise NotImplementedError()

    @gen.coroutine
    def make_iostream_pair(self, **kwargs):
        listener, port = bind_unused_port()
        server_stream_fut = Future()  # type: Future[IOStream]

        def accept_callback(connection, address):
            server_stream_fut.set_result(
                self._make_server_iostream(connection, **kwargs)
            )

        netutil.add_accept_handler(listener, accept_callback)
        client_stream = self._make_client_iostream(socket.socket(), **kwargs)
        connect_fut = client_stream.connect(("127.0.0.1", port))
        server_stream, client_stream = yield [server_stream_fut, connect_fut]
        self.io_loop.remove_handler(listener.fileno())
        listener.close()
        raise gen.Return((server_stream, client_stream))

    @gen_test
    def test_connection_refused(self):
        # When a connection is refused, the connect callback should not
        # be run.  (The kqueue IOLoop used to behave differently from the
        # epoll IOLoop in this respect)
        cleanup_func, port = refusing_port()
        self.addCleanup(cleanup_func)
        stream = IOStream(socket.socket())

        stream.set_close_callback(self.stop)
        # log messages vary by platform and ioloop implementation
        with ExpectLog(gen_log, ".*", required=False):
            with self.assertRaises(StreamClosedError):
                yield stream.connect(("127.0.0.1", port))

        self.assertTrue(isinstance(stream.error, socket.error), stream.error)
        if sys.platform != "cygwin":
            _ERRNO_CONNREFUSED = [errno.ECONNREFUSED]
            if hasattr(errno, "WSAECONNREFUSED"):
                _ERRNO_CONNREFUSED.append(errno.WSAECONNREFUSED)  # type: ignore
            # cygwin's errnos don't match those used on native windows python
            self.assertTrue(stream.error.args[0] in _ERRNO_CONNREFUSED)  # type: ignore

    @gen_test
    def test_gaierror(self):
        # Test that IOStream sets its exc_info on getaddrinfo error.
        # It's difficult to reliably trigger a getaddrinfo error;
        # some resolvers own't even return errors for malformed names,
        # so we mock it instead. If IOStream changes to call a Resolver
        # before sock.connect, the mock target will need to change too.
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        stream = IOStream(s)
        stream.set_close_callback(self.stop)
        with mock.patch(
            "socket.socket.connect", side_effect=socket.gaierror(errno.EIO, "boom")
        ):
            with self.assertRaises(StreamClosedError):
                yield stream.connect(("localhost", 80))
            self.assertTrue(isinstance(stream.error, socket.gaierror))

    @gen_test
    def test_read_until_close_with_error(self):
        server, client = yield self.make_iostream_pair()
        try:
            with mock.patch(
                "tornado.iostream.BaseIOStream._try_inline_read",
                side_effect=IOError("boom"),
            ):
                with self.assertRaisesRegexp(IOError, "boom"):
                    client.read_until_close()
        finally:
            server.close()
            client.close()

    @skipIfNonUnix
    @skipPypy3V58
    @gen_test
    def test_inline_read_error(self):
        # An error on an inline read is raised without logging (on the
        # assumption that it will eventually be noticed or logged further
        # up the stack).
        #
        # This test is posix-only because windows os.close() doesn't work
        # on socket FDs, but we can't close the socket object normally
        # because we won't get the error we want if the socket knows
        # it's closed.
        server, client = yield self.make_iostream_pair()
        try:
            os.close(server.socket.fileno())
            with self.assertRaises(socket.error):
                server.read_bytes(1)
        finally:
            server.close()
            client.close()

    @skipPypy3V58
    @gen_test
    def test_async_read_error_logging(self):
        # Socket errors on asynchronous reads should be logged (but only
        # once).
        server, client = yield self.make_iostream_pair()
        closed = Event()
        server.set_close_callback(closed.set)
        try:
            # Start a read that will be fulfilled asynchronously.
            server.read_bytes(1)
            client.write(b"a")
            # Stub out read_from_fd to make it fail.

            def fake_read_from_fd():
                os.close(server.socket.fileno())
                server.__class__.read_from_fd(server)

            server.read_from_fd = fake_read_from_fd
            # This log message is from _handle_read (not read_from_fd).
            with ExpectLog(gen_log, "error on read"):
                yield closed.wait()
        finally:
            server.close()
            client.close()

    @gen_test
    def test_future_write(self):
        """
        Test that write() Futures are never orphaned.
        """
        # Run concurrent writers that will write enough bytes so as to
        # clog the socket buffer and accumulate bytes in our write buffer.
        m, n = 5000, 1000
        nproducers = 10
        total_bytes = m * n * nproducers
        server, client = yield self.make_iostream_pair(max_buffer_size=total_bytes)

        @gen.coroutine
        def produce():
            data = b"x" * m
            for i in range(n):
                yield server.write(data)

        @gen.coroutine
        def consume():
            nread = 0
            while nread < total_bytes:
                res = yield client.read_bytes(m)
                nread += len(res)

        try:
            yield [produce() for i in range(nproducers)] + [consume()]
        finally:
            server.close()
            client.close()


class TestIOStreamWebHTTP(TestIOStreamWebMixin, AsyncHTTPTestCase):
    def _make_client_iostream(self):
        return IOStream(socket.socket())


class TestIOStreamWebHTTPS(TestIOStreamWebMixin, AsyncHTTPSTestCase):
    def _make_client_iostream(self):
        return SSLIOStream(socket.socket(), ssl_options=dict(cert_reqs=ssl.CERT_NONE))


class TestIOStream(TestIOStreamMixin, AsyncTestCase):
    def _make_server_iostream(self, connection, **kwargs):
        return IOStream(connection, **kwargs)

    def _make_client_iostream(self, connection, **kwargs):
        return IOStream(connection, **kwargs)


class TestIOStreamSSL(TestIOStreamMixin, AsyncTestCase):
    def _make_server_iostream(self, connection, **kwargs):
        connection = ssl.wrap_socket(
            connection,
            server_side=True,
            do_handshake_on_connect=False,
            **_server_ssl_options()
        )
        return SSLIOStream(connection, **kwargs)

    def _make_client_iostream(self, connection, **kwargs):
        return SSLIOStream(
            connection, ssl_options=dict(cert_reqs=ssl.CERT_NONE), **kwargs
        )


# This will run some tests that are basically redundant but it's the
# simplest way to make sure that it works to pass an SSLContext
# instead of an ssl_options dict to the SSLIOStream constructor.
class TestIOStreamSSLContext(TestIOStreamMixin, AsyncTestCase):
    def _make_server_iostream(self, connection, **kwargs):
        context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        context.load_cert_chain(
            os.path.join(os.path.dirname(__file__), "test.crt"),
            os.path.join(os.path.dirname(__file__), "test.key"),
        )
        connection = ssl_wrap_socket(
            connection, context, server_side=True, do_handshake_on_connect=False
        )
        return SSLIOStream(connection, **kwargs)

    def _make_client_iostream(self, connection, **kwargs):
        context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        return SSLIOStream(connection, ssl_options=context, **kwargs)


class TestIOStreamStartTLS(AsyncTestCase):
    def setUp(self):
        try:
            super(TestIOStreamStartTLS, self).setUp()
            self.listener, self.port = bind_unused_port()
            self.server_stream = None
            self.server_accepted = Future()  # type: Future[None]
            netutil.add_accept_handler(self.listener, self.accept)
            self.client_stream = IOStream(socket.socket())
            self.io_loop.add_future(
                self.client_stream.connect(("127.0.0.1", self.port)), self.stop
            )
            self.wait()
            self.io_loop.add_future(self.server_accepted, self.stop)
            self.wait()
        except Exception as e:
            print(e)
            raise

    def tearDown(self):
        if self.server_stream is not None:
            self.server_stream.close()
        if self.client_stream is not None:
            self.client_stream.close()
        self.listener.close()
        super(TestIOStreamStartTLS, self).tearDown()

    def accept(self, connection, address):
        if self.server_stream is not None:
            self.fail("should only get one connection")
        self.server_stream = IOStream(connection)
        self.server_accepted.set_result(None)

    @gen.coroutine
    def client_send_line(self, line):
        self.client_stream.write(line)
        recv_line = yield self.server_stream.read_until(b"\r\n")
        self.assertEqual(line, recv_line)

    @gen.coroutine
    def server_send_line(self, line):
        self.server_stream.write(line)
        recv_line = yield self.client_stream.read_until(b"\r\n")
        self.assertEqual(line, recv_line)

    def client_start_tls(self, ssl_options=None, server_hostname=None):
        client_stream = self.client_stream
        self.client_stream = None
        return client_stream.start_tls(False, ssl_options, server_hostname)

    def server_start_tls(self, ssl_options=None):
        server_stream = self.server_stream
        self.server_stream = None
        return server_stream.start_tls(True, ssl_options)

    @gen_test
    def test_start_tls_smtp(self):
        # This flow is simplified from RFC 3207 section 5.
        # We don't really need all of this, but it helps to make sure
        # that after realistic back-and-forth traffic the buffers end up
        # in a sane state.
        yield self.server_send_line(b"220 mail.example.com ready\r\n")
        yield self.client_send_line(b"EHLO mail.example.com\r\n")
        yield self.server_send_line(b"250-mail.example.com welcome\r\n")
        yield self.server_send_line(b"250 STARTTLS\r\n")
        yield self.client_send_line(b"STARTTLS\r\n")
        yield self.server_send_line(b"220 Go ahead\r\n")
        client_future = self.client_start_tls(dict(cert_reqs=ssl.CERT_NONE))
        server_future = self.server_start_tls(_server_ssl_options())
        self.client_stream = yield client_future
        self.server_stream = yield server_future
        self.assertTrue(isinstance(self.client_stream, SSLIOStream))
        self.assertTrue(isinstance(self.server_stream, SSLIOStream))
        yield self.client_send_line(b"EHLO mail.example.com\r\n")
        yield self.server_send_line(b"250 mail.example.com welcome\r\n")

    @gen_test
    def test_handshake_fail(self):
        server_future = self.server_start_tls(_server_ssl_options())
        # Certificates are verified with the default configuration.
        with ExpectLog(gen_log, "SSL Error"):
            client_future = self.client_start_tls(server_hostname="localhost")
            with self.assertRaises(ssl.SSLError):
                yield client_future
            with self.assertRaises((ssl.SSLError, socket.error)):
                yield server_future

    @gen_test
    def test_check_hostname(self):
        # Test that server_hostname parameter to start_tls is being used.
        # The check_hostname functionality is only available in python 2.7 and
        # up and in python 3.4 and up.
        server_future = self.server_start_tls(_server_ssl_options())
        with ExpectLog(gen_log, "SSL Error"):
            client_future = self.client_start_tls(
                ssl.create_default_context(), server_hostname="127.0.0.1"
            )
            with self.assertRaises(ssl.SSLError):
                # The client fails to connect with an SSL error.
                yield client_future
            with self.assertRaises(Exception):
                # The server fails to connect, but the exact error is unspecified.
                yield server_future


class WaitForHandshakeTest(AsyncTestCase):
    @gen.coroutine
    def connect_to_server(self, server_cls):
        server = client = None
        try:
            sock, port = bind_unused_port()
            server = server_cls(ssl_options=_server_ssl_options())
            server.add_socket(sock)

            client = SSLIOStream(
                socket.socket(), ssl_options=dict(cert_reqs=ssl.CERT_NONE)
            )
            yield client.connect(("127.0.0.1", port))
            self.assertIsNotNone(client.socket.cipher())
        finally:
            if server is not None:
                server.stop()
            if client is not None:
                client.close()

    @gen_test
    def test_wait_for_handshake_future(self):
        test = self
        handshake_future = Future()  # type: Future[None]

        class TestServer(TCPServer):
            def handle_stream(self, stream, address):
                test.assertIsNone(stream.socket.cipher())
                test.io_loop.spawn_callback(self.handle_connection, stream)

            @gen.coroutine
            def handle_connection(self, stream):
                yield stream.wait_for_handshake()
                handshake_future.set_result(None)

        yield self.connect_to_server(TestServer)
        yield handshake_future

    @gen_test
    def test_wait_for_handshake_already_waiting_error(self):
        test = self
        handshake_future = Future()  # type: Future[None]

        class TestServer(TCPServer):
            @gen.coroutine
            def handle_stream(self, stream, address):
                fut = stream.wait_for_handshake()
                test.assertRaises(RuntimeError, stream.wait_for_handshake)
                yield fut

                handshake_future.set_result(None)

        yield self.connect_to_server(TestServer)
        yield handshake_future

    @gen_test
    def test_wait_for_handshake_already_connected(self):
        handshake_future = Future()  # type: Future[None]

        class TestServer(TCPServer):
            @gen.coroutine
            def handle_stream(self, stream, address):
                yield stream.wait_for_handshake()
                yield stream.wait_for_handshake()
                handshake_future.set_result(None)

        yield self.connect_to_server(TestServer)
        yield handshake_future


@skipIfNonUnix
class TestPipeIOStream(TestReadWriteMixin, AsyncTestCase):
    @gen.coroutine
    def make_iostream_pair(self, **kwargs):
        r, w = os.pipe()

        return PipeIOStream(r, **kwargs), PipeIOStream(w, **kwargs)

    @gen_test
    def test_pipe_iostream(self):
        rs, ws = yield self.make_iostream_pair()

        ws.write(b"hel")
        ws.write(b"lo world")

        data = yield rs.read_until(b" ")
        self.assertEqual(data, b"hello ")

        data = yield rs.read_bytes(3)
        self.assertEqual(data, b"wor")

        ws.close()

        data = yield rs.read_until_close()
        self.assertEqual(data, b"ld")

        rs.close()

    @gen_test
    def test_pipe_iostream_big_write(self):
        rs, ws = yield self.make_iostream_pair()

        NUM_BYTES = 1048576

        # Write 1MB of data, which should fill the buffer
        ws.write(b"1" * NUM_BYTES)

        data = yield rs.read_bytes(NUM_BYTES)
        self.assertEqual(data, b"1" * NUM_BYTES)

        ws.close()
        rs.close()


class TestStreamBuffer(unittest.TestCase):
    """
    Unit tests for the private _StreamBuffer class.
    """

    def setUp(self):
        self.random = random.Random(42)

    def to_bytes(self, b):
        if isinstance(b, (bytes, bytearray)):
            return bytes(b)
        elif isinstance(b, memoryview):
            return b.tobytes()  # For py2
        else:
            raise TypeError(b)

    def make_streambuffer(self, large_buf_threshold=10):
        buf = _StreamBuffer()
        assert buf._large_buf_threshold
        buf._large_buf_threshold = large_buf_threshold
        return buf

    def check_peek(self, buf, expected):
        size = 1
        while size < 2 * len(expected):
            got = self.to_bytes(buf.peek(size))
            self.assertTrue(got)  # Not empty
            self.assertLessEqual(len(got), size)
            self.assertTrue(expected.startswith(got), (expected, got))
            size = (size * 3 + 1) // 2

    def check_append_all_then_skip_all(self, buf, objs, input_type):
        self.assertEqual(len(buf), 0)

        expected = b""

        for o in objs:
            expected += o
            buf.append(input_type(o))
            self.assertEqual(len(buf), len(expected))
            self.check_peek(buf, expected)

        while expected:
            n = self.random.randrange(1, len(expected) + 1)
            expected = expected[n:]
            buf.advance(n)
            self.assertEqual(len(buf), len(expected))
            self.check_peek(buf, expected)

        self.assertEqual(len(buf), 0)

    def test_small(self):
        objs = [b"12", b"345", b"67", b"89a", b"bcde", b"fgh", b"ijklmn"]

        buf = self.make_streambuffer()
        self.check_append_all_then_skip_all(buf, objs, bytes)

        buf = self.make_streambuffer()
        self.check_append_all_then_skip_all(buf, objs, bytearray)

        buf = self.make_streambuffer()
        self.check_append_all_then_skip_all(buf, objs, memoryview)

        # Test internal algorithm
        buf = self.make_streambuffer(10)
        for i in range(9):
            buf.append(b"x")
        self.assertEqual(len(buf._buffers), 1)
        for i in range(9):
            buf.append(b"x")
        self.assertEqual(len(buf._buffers), 2)
        buf.advance(10)
        self.assertEqual(len(buf._buffers), 1)
        buf.advance(8)
        self.assertEqual(len(buf._buffers), 0)
        self.assertEqual(len(buf), 0)

    def test_large(self):
        objs = [
            b"12" * 5,
            b"345" * 2,
            b"67" * 20,
            b"89a" * 12,
            b"bcde" * 1,
            b"fgh" * 7,
            b"ijklmn" * 2,
        ]

        buf = self.make_streambuffer()
        self.check_append_all_then_skip_all(buf, objs, bytes)

        buf = self.make_streambuffer()
        self.check_append_all_then_skip_all(buf, objs, bytearray)

        buf = self.make_streambuffer()
        self.check_append_all_then_skip_all(buf, objs, memoryview)

        # Test internal algorithm
        buf = self.make_streambuffer(10)
        for i in range(3):
            buf.append(b"x" * 11)
        self.assertEqual(len(buf._buffers), 3)
        buf.append(b"y")
        self.assertEqual(len(buf._buffers), 4)
        buf.append(b"z")
        self.assertEqual(len(buf._buffers), 4)
        buf.advance(33)
        self.assertEqual(len(buf._buffers), 1)
        buf.advance(2)
        self.assertEqual(len(buf._buffers), 0)
        self.assertEqual(len(buf), 0)
