# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import ast
import base64
import itertools
import os
import pathlib
import signal
import struct
import tempfile
import threading
import time
import traceback
import json

try:
    import numpy as np
except ImportError:
    np = None
import pytest
import pyarrow as pa

from pyarrow.lib import IpcReadOptions, tobytes
from pyarrow.util import find_free_port
from pyarrow.tests import util

try:
    from pyarrow import flight
    from pyarrow.flight import (
        FlightClient, FlightServerBase,
        ServerAuthHandler, ClientAuthHandler,
        ServerMiddleware, ServerMiddlewareFactory,
        ClientMiddleware, ClientMiddlewareFactory,
    )
except ImportError:
    flight = None
    FlightClient, FlightServerBase = object, object
    ServerAuthHandler, ClientAuthHandler = object, object
    ServerMiddleware, ServerMiddlewareFactory = object, object
    ClientMiddleware, ClientMiddlewareFactory = object, object

# Apply the "flight" marker to every test in this module.
# Deselect them with: pytest ... -m 'not flight'
pytestmark = pytest.mark.flight


def test_import():
    # The module-level import above swallows ImportError; import again
    # here, unguarded, so that a failure shows up in at least one test.
    import pyarrow.flight  # noqa


def resource_root():
    """Return the "flight" directory under ARROW_TEST_DATA.

    Raises
    ------
    RuntimeError
        If the ARROW_TEST_DATA environment variable is unset or empty.
    """
    data_dir = os.environ.get("ARROW_TEST_DATA")
    if data_dir:
        return pathlib.Path(data_dir) / "flight"
    raise RuntimeError("Test resources not found; set "
                       "ARROW_TEST_DATA to <repo root>/testing/data")


def read_flight_resource(path):
    """Return the bytes of the test resource at ``path``.

    ``path`` is resolved relative to ``resource_root()``. A missing file
    is reported as a RuntimeError whose message embeds the original
    traceback.
    """
    root = resource_root()
    if not root:
        return None
    resource = root / path
    try:
        return resource.read_bytes()
    except FileNotFoundError:
        message = ("Test resource {} not found; did you initialize the "
                   "test resource submodule?\n{}")
        raise RuntimeError(message.format(resource, traceback.format_exc()))


def example_tls_certs():
    """Load the test TLS material.

    Returns a dict with the contents of the root certificate under
    "root_cert" and two ``flight.CertKeyPair`` objects (cert0, cert1)
    under "certificates". Values are file contents, not paths.
    """
    root_cert = read_flight_resource("root-ca.pem")
    key_pairs = [
        flight.CertKeyPair(
            cert=read_flight_resource("cert{}.pem".format(index)),
            key=read_flight_resource("cert{}.key".format(index)),
        )
        for index in range(2)
    ]
    return {"root_cert": root_cert, "certificates": key_pairs}


def simple_ints_table():
    """Build a single-column table ('some_ints') of five integers."""
    ints = pa.array([-10, -5, 0, 5, 10])
    return pa.Table.from_arrays([ints], names=['some_ints'])


def simple_dicts_table():
    """Build a single-column table ('some_dicts') of dictionary arrays.

    The column is a chunked array with two chunks that share the same
    dictionary values; the first chunk contains a null index.
    """
    dictionary = pa.array(["foo", "baz", "quux"], type=pa.utf8())
    first_chunk = pa.DictionaryArray.from_arrays([1, 0, None], dictionary)
    second_chunk = pa.DictionaryArray.from_arrays([2, 1], dictionary)
    column = pa.chunked_array([first_chunk, second_chunk])
    return pa.Table.from_arrays([column], names=['some_dicts'])


def multiple_column_table():
    """Build a two-column table: strings in 'a', integers in 'b'."""
    strings = pa.array(['foo', 'bar', 'baz', 'qux'])
    numbers = pa.array([1, 2, 3, 4])
    return pa.Table.from_arrays([strings, numbers], names=['a', 'b'])


class ConstantFlightServer(FlightServerBase):
    """A Flight server that always returns the same data.

    See ARROW-4796: this server implementation will segfault if Flight
    does not properly hold a reference to the Table object.
    """

    # list_flights only yields a result for exactly this criteria value.
    CRITERIA = b"the expected criteria"

    def __init__(self, location=None, options=None, **kwargs):
        super().__init__(location, **kwargs)
        # Passed to RecordBatchStream in do_get.
        self.options = options
        # Ticket contents -> zero-argument callable building the table.
        self.table_factories = {
            b'ints': simple_ints_table,
            b'dicts': simple_dicts_table,
            b'multi': multiple_column_table,
        }

    def list_flights(self, context, criteria):
        """Yield one empty FlightInfo when ``criteria`` matches CRITERIA."""
        if not criteria == self.CRITERIA:
            return
        descriptor = flight.FlightDescriptor.for_path('/foo')
        yield flight.FlightInfo(pa.schema([]), descriptor, [], -1, -1)

    def do_get(self, context, ticket):
        """Stream a freshly built table selected by the ticket contents."""
        # Build a new table on every call so that Flight holds the only
        # reference to it.
        build_table = self.table_factories[ticket.ticket]
        return flight.RecordBatchStream(build_table(), options=self.options)


class MetadataFlightServer(FlightServerBase):
    """A Flight server that numbers incoming/outgoing data."""

    def __init__(self, options=None, **kwargs):
        super().__init__(**kwargs)
        # Passed to GeneratorStream in do_get.
        self.options = options

    def do_get(self, context, ticket):
        """Stream a five-row table, tagging each batch with its index."""
        values = pa.array([-10, -5, 0, 5, 10])
        table = pa.Table.from_arrays([values], names=['a'])
        numbered = self.number_batches(table)
        return flight.GeneratorStream(
            table.schema, numbered, options=self.options)

    def do_put(self, context, descriptor, reader, writer):
        """Check each uploaded one-row batch and echo back its index.

        Every chunk must carry metadata holding the batch index packed
        as a little-endian int32; the same packed index is written back.
        """
        expected_values = [-10, -5, 0, 5, 10]
        for position, (batch, metadata) in enumerate(reader):
            expected_batch = pa.RecordBatch.from_arrays(
                [pa.array([expected_values[position]])],
                ['a']
            )
            assert batch.equals(expected_batch)
            assert metadata is not None
            (client_position,) = struct.unpack('<i', metadata.to_pybytes())
            assert position == client_position
            writer.write(struct.pack('<i', position))

    @staticmethod
    def number_batches(table):
        """Yield (batch, packed little-endian int32 index) pairs."""
        for index, batch in enumerate(table.to_batches()):
            yield batch, struct.pack('<i', index)


class EchoFlightServer(FlightServerBase):
    """A Flight server that returns the last data uploaded."""

    def __init__(self, location=None, expected_schema=None, **kwargs):
        super().__init__(location, **kwargs)
        # When set, do_put asserts that uploads use this schema.
        self.expected_schema = expected_schema
        # Result of the most recent do_put, returned by do_get.
        self.last_message = None

    def do_get(self, context, ticket):
        """Stream back whatever the last do_put stored."""
        return flight.RecordBatchStream(self.last_message)

    def do_put(self, context, descriptor, reader, writer):
        """Read the whole upload and remember it."""
        if self.expected_schema:
            assert self.expected_schema == reader.schema
        self.last_message = reader.read_all()

    def do_exchange(self, context, descriptor, reader, writer):
        """Consume every incoming chunk without replying."""
        for _ in reader:
            pass


class EchoStreamFlightServer(EchoFlightServer):
    """An echo server that streams individual record batches."""

    def do_get(self, context, ticket):
        """Stream the last upload in batches of at most 1024 rows."""
        stored = self.last_message
        batches = stored.to_batches(max_chunksize=1024)
        return flight.GeneratorStream(stored.schema, batches)

    def list_actions(self, context):
        """Advertise no actions."""
        return []

    def do_action(self, context, action):
        """Answer "who-am-i" with the peer identity and peer address."""
        if action.type != "who-am-i":
            raise NotImplementedError
        identity = context.peer_identity()
        peer = context.peer().encode("utf-8")
        return [identity, peer]


class GetInfoFlightServer(FlightServerBase):
    """A Flight server that tests GetFlightInfo."""

    def get_flight_info(self, context, descriptor):
        """Return a fixed FlightInfo that echoes the given descriptor."""
        schema = pa.schema([('a', pa.int32())])
        endpoints = [
            # Location given as a plain URI string.
            flight.FlightEndpoint(b'', ['grpc://test']),
            # Location given as a Location object.
            flight.FlightEndpoint(
                b'',
                [flight.Location.for_grpc_tcp('localhost', 5005)],
            ),
        ]
        return flight.FlightInfo(schema, descriptor, endpoints, -1, -1)

    def get_schema(self, context, descriptor):
        """Return the schema of the FlightInfo built for ``descriptor``."""
        schema = self.get_flight_info(context, descriptor).schema
        return flight.SchemaResult(schema)


class ListActionsFlightServer(FlightServerBase):
    """A Flight server that tests ListActions."""

    @classmethod
    def expected_actions(cls):
        """Return the actions this server lists.

        The list mixes (type, description) tuples with an ActionType.
        """
        as_tuples = [
            ("action-1", "description"),
            ("action-2", ""),
        ]
        as_action_type = flight.ActionType("action-3", "more detail")
        return as_tuples + [as_action_type]

    def list_actions(self, context):
        """Yield each expected action in order."""
        for action in self.expected_actions():
            yield action


class ListActionsErrorFlightServer(FlightServerBase):
    """A Flight server whose ListActions yields a bare string entry.

    The first entry is a (type, description) tuple; the second is a
    plain string, presumably there to trigger an error on the server
    side (verify against the test that uses this class).
    """

    def list_actions(self, context):
        entries = (("action-1", ""), "foo")
        yield from entries


class CheckTicketFlightServer(FlightServerBase):
    """A Flight server that compares the given ticket to an expected value.

    Parameters
    ----------
    expected_ticket
        Value that ``ticket.ticket`` must equal in ``do_get``.
    location : optional
        Forwarded to ``FlightServerBase``.
    **kwargs
        Forwarded to ``FlightServerBase``.
    """

    def __init__(self, expected_ticket, location=None, **kwargs):
        super().__init__(location, **kwargs)
        self.expected_ticket = expected_ticket
        # Filled in by do_put; defined here so reading it before any
        # upload yields None instead of raising AttributeError.
        self.last_message = None

    def do_get(self, context, ticket):
        """Assert the ticket matches, then stream a fixed int32 table."""
        assert self.expected_ticket == ticket.ticket
        data1 = [pa.array([-10, -5, 0, 5, 10], type=pa.int32())]
        table = pa.Table.from_arrays(data1, names=['a'])
        return flight.RecordBatchStream(table)

    def do_put(self, context, descriptor, reader, writer=None):
        """Read the whole upload and store it in ``last_message``.

        ``writer`` is unused. It is accepted so this method has the same
        four-argument shape as every other ``do_put`` in this file; it
        defaults to None so existing three-argument calls keep working.
        """
        self.last_message = reader.read_all()


class InvalidStreamFlightServer(FlightServerBase):
    """A Flight server that tries to return messages with differing schemas."""

    # Schema advertised for the stream; only the first table matches it.
    schema = pa.schema([('a', pa.int32())])

    def do_get(self, context, ticket):
        """Return a stream whose second table does not match ``schema``."""
        data1 = [pa.array([-10, -5, 0, 5, 10], type=pa.int32())]
        data2 = [pa.array([-10.0, -5.0, 0.0, 5.0, 10.0], type=pa.float64())]
        # data1/data2 are lists of arrays, so compare the arrays' types.
        # Reading ``.type`` on the lists themselves raised AttributeError
        # before the mismatched stream was ever built.
        assert data1[0].type != data2[0].type
        table1 = pa.Table.from_arrays(data1, names=['a'])
        table2 = pa.Table.from_arrays(data2, names=['a'])
        assert table1.schema == self.schema

        return flight.GeneratorStream(self.schema, [table1, table2])


class NeverSendsDataFlightServer(FlightServerBase):
    """A Flight server that never actually yields data."""

    schema = pa.schema([('a', pa.int32())])

    def do_get(self, context, ticket):
        """Stream empty tables forever, unless the ticket asks for data.

        With the ticket ``b'yield_data'`` the stream is two empty tables
        followed by one five-row record batch.
        """
        if ticket.ticket != b'yield_data':
            endless_empties = itertools.repeat(self.schema.empty_table())
            return flight.GeneratorStream(self.schema, endless_empties)

        # Check that the server handler will ignore empty tables
        # up to a certain extent
        leading_empties = [self.schema.empty_table(),
                           self.schema.empty_table()]
        payload = pa.RecordBatch.from_arrays([range(5)], schema=self.schema)
        return flight.GeneratorStream(self.schema, leading_empties + [payload])


class SlowFlightServer(FlightServerBase):
    """A Flight server that delays its responses to test timeouts."""

    def do_get(self, context, ticket):
        """Return a stream that stalls for 10 seconds after one table."""
        schema = pa.schema([('a', pa.int32())])
        return flight.GeneratorStream(schema, self.slow_stream())

    def do_action(self, context, action):
        """Wait half a second, then return no results."""
        time.sleep(0.5)
        return []

    @staticmethod
    def slow_stream():
        """Yield one table, sleep 10 seconds, then yield another."""
        def make_table():
            column = pa.array([-10, -5, 0, 5, 10], type=pa.int32())
            return pa.Table.from_arrays([column], names=['a'])

        yield make_table()
        # The second message should never get sent; the client should
        # cancel before we send this
        time.sleep(10)
        yield make_table()


class ErrorFlightServer(FlightServerBase):
    """A Flight server that uses all the Flight-specific errors."""

    @staticmethod
    def error_cases():
        """Map action type names to the exception class do_action raises."""
        cases = {
            "internal": flight.FlightInternalError,
            "timedout": flight.FlightTimedOutError,
            "cancel": flight.FlightCancelledError,
            "unauthenticated": flight.FlightUnauthenticatedError,
            "unauthorized": flight.FlightUnauthorizedError,
            "notimplemented": NotImplementedError,
            "invalid": pa.ArrowInvalid,
            "key": KeyError,
        }
        return cases

    def do_action(self, context, action):
        """Raise the error selected by the action type.

        "protobuf" raises FlightUnauthorizedError with an extra bytes
        payload; unknown types raise a bare NotImplementedError.
        """
        if action.type == "protobuf":
            extra = b'this is an error message'
            raise flight.FlightUnauthorizedError("foo", extra)
        exc_type = ErrorFlightServer.error_cases().get(action.type)
        if exc_type is None:
            raise NotImplementedError
        raise exc_type("foo")

    def list_flights(self, context, criteria):
        """Yield one FlightInfo, then fail with FlightInternalError."""
        descriptor = flight.FlightDescriptor.for_path('/foo')
        yield flight.FlightInfo(pa.schema([]), descriptor, [], -1, -1)
        raise flight.FlightInternalError("foo")

    def do_put(self, context, descriptor, reader, writer):
        """Raise the Flight error selected by the descriptor command.

        Commands that match nothing return normally.
        """
        command = descriptor.command
        if command == b"protobuf":
            extra = b'this is an error message'
            raise flight.FlightUnauthorizedError("foo", extra)
        errors_by_command = {
            b"internal": flight.FlightInternalError,
            b"timedout": flight.FlightTimedOutError,
            b"cancel": flight.FlightCancelledError,
            b"unauthenticated": flight.FlightUnauthenticatedError,
            b"unauthorized": flight.FlightUnauthorizedError,
        }
        exc_type = errors_by_command.get(command)
        if exc_type is not None:
            raise exc_type("foo")


class ExchangeFlightServer(FlightServerBase):
    """A server for testing DoExchange."""

    def __init__(self, options=None, **kwargs):
        super().__init__(**kwargs)
        # Passed to writer.begin() by exchange_echo.
        self.options = options

    def do_exchange(self, context, descriptor, reader, writer):
        """Dispatch to an exchange_* handler based on the command.

        Raises pa.ArrowInvalid for non-command descriptors and for
        commands without a handler.
        """
        if descriptor.descriptor_type != flight.DescriptorType.CMD:
            raise pa.ArrowInvalid("Must provide a command descriptor")

        handlers = {
            b"echo": self.exchange_echo,
            b"get": self.exchange_do_get,
            b"put": self.exchange_do_put,
            b"transform": self.exchange_transform,
        }
        handler = handlers.get(descriptor.command)
        if handler is None:
            raise pa.ArrowInvalid(
                "Unknown command: {}".format(descriptor.command))
        return handler(context, reader, writer)

    def exchange_do_get(self, context, reader, writer):
        """Emulate DoGet with DoExchange."""
        values = pa.array(range(0, 10 * 1024))
        table = pa.Table.from_arrays([values], names=["a"])
        writer.begin(table.schema)
        writer.write_table(table)

    def exchange_do_put(self, context, reader, writer):
        """Emulate DoPut with DoExchange."""
        received = 0
        for chunk in reader:
            if not chunk.data:
                raise pa.ArrowInvalid("All chunks must have data.")
            received += 1
        # Report the number of chunks received as UTF-8 text metadata.
        writer.write_metadata(str(received).encode("utf-8"))

    def exchange_echo(self, context, reader, writer):
        """Run a simple echo server."""
        stream_open = False
        for chunk in reader:
            batch = chunk.data
            metadata = chunk.app_metadata
            # Open the outgoing stream on the first chunk carrying data.
            if batch and not stream_open:
                writer.begin(batch.schema, options=self.options)
                stream_open = True
            if metadata and batch:
                writer.write_with_metadata(batch, metadata)
            elif metadata:
                writer.write_metadata(metadata)
            elif batch:
                writer.write_batch(batch)
            else:
                assert False, "Should not happen"

    def exchange_transform(self, context, reader, writer):
        """Sum rows in an uploaded table."""
        # Reject any non-integer column before reading the data.
        for field in reader.schema:
            if not pa.types.is_integer(field.type):
                raise pa.ArrowInvalid("Invalid field: " + repr(field))
        uploaded = reader.read_all()
        row_totals = [0] * uploaded.num_rows
        for column in uploaded:
            for index, scalar in enumerate(column):
                row_totals[index] += scalar.as_py()
        summed = pa.Table.from_arrays([pa.array(row_totals)], names=["sum"])
        writer.begin(summed.schema)
        writer.write_table(summed)


class HttpBasicServerAuthHandler(ServerAuthHandler):
    """An example implementation of HTTP basic authentication."""

    def __init__(self, creds):
        super().__init__()
        # Mapping of username -> password.
        self.creds = creds

    def authenticate(self, outgoing, incoming):
        """Check the BasicAuth payload and reply with the username.

        Raises FlightUnauthenticatedError for an unknown user or a
        wrong password.
        """
        auth = flight.BasicAuth.deserialize(incoming.read())
        username = auth.username
        if username not in self.creds:
            raise flight.FlightUnauthenticatedError("unknown user")
        if auth.password != self.creds[username]:
            raise flight.FlightUnauthenticatedError("wrong password")
        # The username itself is sent back to the client as the token.
        outgoing.write(tobytes(username))

    def is_valid(self, token):
        """Return the token if it names a known user, else raise."""
        if not token:
            raise flight.FlightUnauthenticatedError("token not provided")
        if token in self.creds:
            return token
        raise flight.FlightUnauthenticatedError("unknown user")


class HttpBasicClientAuthHandler(ClientAuthHandler):
    """An example implementation of HTTP basic authentication."""

    def __init__(self, username, password):
        super().__init__()
        self.basic_auth = flight.BasicAuth(username, password)
        # Set by authenticate() from the server's reply.
        self.token = None

    def authenticate(self, outgoing, incoming):
        """Send the serialized BasicAuth and store the reply as the token."""
        auth = self.basic_auth.serialize()
        outgoing.write(auth)
        self.token = incoming.read()

    def get_token(self):
        """Return the token stored by authenticate (None before that)."""
        return self.token


class TokenServerAuthHandler(ServerAuthHandler):
    """An example implementation of authentication via handshake."""

    def __init__(self, creds):
        super().__init__()
        # Mapping of username -> password.
        self.creds = creds

    def authenticate(self, outgoing, incoming):
        """Read username then password; reply with a base64 token.

        The token is base64 of ``b'secret:' + username``. Raises
        FlightUnauthenticatedError when the credentials do not match.
        """
        username = incoming.read()
        password = incoming.read()
        if username not in self.creds or self.creds[username] != password:
            raise flight.FlightUnauthenticatedError(
                "invalid username/password")
        outgoing.write(base64.b64encode(b'secret:' + username))

    def is_valid(self, token):
        """Decode the token and return the username that follows the prefix."""
        prefix = b'secret:'
        decoded = base64.b64decode(token)
        if not decoded.startswith(prefix):
            raise flight.FlightUnauthenticatedError("invalid token")
        return decoded[len(prefix):]


class TokenClientAuthHandler(ClientAuthHandler):
    """An example implementation of authentication via handshake."""

    def __init__(self, username, password):
        super().__init__()
        self.username = username
        self.password = password
        # Replaced by the server's reply in authenticate().
        self.token = b''

    def authenticate(self, outgoing, incoming):
        """Send username then password; store the reply as the token."""
        outgoing.write(self.username)
        outgoing.write(self.password)
        self.token = incoming.read()

    def get_token(self):
        """Return the stored token (empty bytes before authenticating)."""
        return self.token


class NoopAuthHandler(ServerAuthHandler):
    """A no-op auth handler."""

    def authenticate(self, outgoing, incoming):
        """Do nothing: read nothing from the client and send nothing back."""

    def is_valid(self, token):
        """
        Accept every token and return an empty string.
        Returning None causes a TypeError.
        """
        return ""


def case_insensitive_header_lookup(headers, lookup_key):
    """Return the value of the first header whose name matches the key.

    Names are compared after lowercasing both sides. Returns None when
    no header matches.
    """
    matches = (headers.get(name) for name in headers
               if name.lower() == lookup_key.lower())
    return next(matches, None)


class ClientHeaderAuthMiddlewareFactory(ClientMiddlewareFactory):
    """ClientMiddlewareFactory that creates ClientAuthHeaderMiddleware."""

    def __init__(self):
        # Replaced via set_call_credential(); empty until then.
        self.call_credential = []

    def start_call(self, info):
        """Create a middleware instance that reports back to this factory."""
        return ClientHeaderAuthMiddleware(self)

    def set_call_credential(self, call_credential):
        """Store the given credential in ``call_credential``."""
        self.call_credential = call_credential


class ClientHeaderAuthMiddleware(ClientMiddleware):
    """
    ClientMiddleware that extracts the authorization header
    from the server.

    This is an example of a ClientMiddleware that can extract
    the bearer token authorization header from a HTTP header
    authentication enabled server.

    Parameters
    ----------
    factory : ClientHeaderAuthMiddlewareFactory
        This factory is used to set call credentials if an
        authorization header is found in the headers from the server.
    """

    def __init__(self, factory):
        self.factory = factory

    def received_headers(self, headers):
        """Copy the server's authorization header to the factory.

        Does nothing when the headers carry no (or an empty)
        authorization entry, leaving the factory's credential untouched.
        """
        auth_header = case_insensitive_header_lookup(headers, 'Authorization')
        if not auth_header:
            # Indexing a missing header would raise TypeError on None.
            return
        self.factory.set_call_credential([
            b'authorization',
            auth_header[0].encode("utf-8")])


class HeaderAuthServerMiddlewareFactory(ServerMiddlewareFactory):
    """Validates incoming username and password."""

    def start_call(self, info, headers):
        """Validate the Authorization header and create the middleware.

        Accepts ``Basic`` credentials for user ``test`` with password
        ``password``, or ``Bearer token1234``.

        Raises
        ------
        flight.FlightUnauthenticatedError
            If the header is missing, malformed, uses another scheme,
            or carries wrong credentials.
        """
        error_message = 'Invalid credentials'
        auth_header = case_insensitive_header_lookup(
            headers,
            'Authorization'
        )
        if not auth_header:
            # Missing header: reject cleanly rather than fail with a
            # TypeError when indexing None.
            raise flight.FlightUnauthenticatedError(error_message)

        values = auth_header[0].split(' ')
        if len(values) < 2:
            # A scheme without credentials is malformed.
            raise flight.FlightUnauthenticatedError(error_message)
        scheme, credentials = values[0], values[1]

        if scheme == 'Basic':
            try:
                decoded = base64.b64decode(credentials).decode("utf-8")
            except ValueError:
                # Covers bad base64 (binascii.Error) and bad UTF-8
                # (UnicodeDecodeError); both subclass ValueError.
                raise flight.FlightUnauthenticatedError(error_message)
            # Split on the first colon only, so that a password
            # containing ':' is compared whole.
            username, _, password = decoded.partition(':')
            if not (username == 'test' and password == 'password'):
                raise flight.FlightUnauthenticatedError(error_message)
            token = 'token1234'
        elif scheme == 'Bearer':
            token = credentials
            if token != 'token1234':
                raise flight.FlightUnauthenticatedError(error_message)
        else:
            raise flight.FlightUnauthenticatedError(error_message)

        return HeaderAuthServerMiddleware(token)


class HeaderAuthServerMiddleware(ServerMiddleware):
    """A ServerMiddleware that sends its token back as a Bearer header."""

    def __init__(self, token):
        self.token = token

    def sending_headers(self):
        # Always emits a single 'authorization' header for the token.
        return {'authorization': 'Bearer ' + self.token}


class HeaderAuthFlightServer(FlightServerBase):
    """A Flight server that tests with basic token authentication. """

    def do_action(self, context, action):
        """Return the token from the "auth" middleware's header.

        Raises FlightUnauthenticatedError when no "auth" middleware is
        attached to the call.
        """
        middleware = context.get_middleware("auth")
        if not middleware:
            raise flight.FlightUnauthenticatedError(
                'No token auth middleware found.')
        outgoing = middleware.sending_headers()
        auth_header = case_insensitive_header_lookup(outgoing, 'Authorization')
        # The header reads "<scheme> <token>"; return just the token.
        token = auth_header.split(' ')[1]
        return [token.encode("utf-8")]


class ArbitraryHeadersServerMiddlewareFactory(ServerMiddlewareFactory):
    """A ServerMiddlewareFactory that transports arbitrary headers."""

    def start_call(self, info, headers):
        # Hand the incoming headers, unmodified, to the middleware.
        return ArbitraryHeadersServerMiddleware(headers)


class ArbitraryHeadersServerMiddleware(ServerMiddleware):
    """A ServerMiddleware that transports arbitrary headers."""

    def __init__(self, incoming):
        # Headers received at the start of the call.
        self.incoming = incoming

    def sending_headers(self):
        # Echo the incoming headers back unchanged.
        return self.incoming


class ArbitraryHeadersFlightServer(FlightServerBase):
    """A Flight server that tests multiple arbitrary headers."""

    def do_action(self, context, action):
        """Return the first values of test-header-1 and test-header-2.

        Raises FlightServerError when no "arbitrary-headers" middleware
        is attached to the call.
        """
        middleware = context.get_middleware("arbitrary-headers")
        if not middleware:
            raise flight.FlightServerError("No headers middleware found")
        headers = middleware.sending_headers()
        found = [
            case_insensitive_header_lookup(headers, name)
            for name in ('test-header-1', 'test-header-2')
        ]
        return [values[0].encode("utf-8") for values in found]


class HeaderServerMiddleware(ServerMiddleware):
    """Expose a per-call value to the RPC method body."""

    def __init__(self, special_value):
        # Read back by RPC handlers through context.get_middleware(...).
        self.special_value = special_value


class HeaderServerMiddlewareFactory(ServerMiddlewareFactory):
    """Expose a per-call hard-coded value to the RPC method body."""

    def start_call(self, info, headers):
        # The call info and headers are ignored; the value is fixed.
        return HeaderServerMiddleware("right value")


class HeaderFlightServer(FlightServerBase):
    """Echo back the per-call hard-coded value."""

    def do_action(self, context, action):
        """Return the "test" middleware's value, or b"" without one."""
        middleware = context.get_middleware("test")
        if not middleware:
            return [b""]
        return [middleware.special_value.encode()]


class MultiHeaderFlightServer(FlightServerBase):
    """Test sending/receiving multiple (binary-valued) headers."""

    def do_action(self, context, action):
        """Return the repr of the client headers, UTF-8 encoded."""
        client_headers = context.get_middleware("test").client_headers
        return [repr(client_headers).encode("utf-8")]


class SelectiveAuthServerMiddlewareFactory(ServerMiddlewareFactory):
    """Deny access to certain methods based on a header."""

    def start_call(self, info, headers):
        """Require "x-auth-token: password" for all but ListActions.

        Returns None for ListActions and a HeaderServerMiddleware
        holding the token otherwise. Raises FlightUnauthenticatedError
        when the token is absent or wrong.
        """
        if info.method == flight.FlightMethod.LIST_ACTIONS:
            # No auth needed
            return None

        token_values = headers.get("x-auth-token")
        if not token_values:
            raise flight.FlightUnauthenticatedError("No token")

        # Only the first value of the header is considered.
        first_token = token_values[0]
        if first_token != "password":
            raise flight.FlightUnauthenticatedError("Invalid token")

        return HeaderServerMiddleware(first_token)


class SelectiveAuthClientMiddlewareFactory(ClientMiddlewareFactory):
    """Create a SelectiveAuthClientMiddleware for every call."""

    def start_call(self, info):
        return SelectiveAuthClientMiddleware()


class SelectiveAuthClientMiddleware(ClientMiddleware):
    """Send a fixed "x-auth-token: password" header."""

    def sending_headers(self):
        return {
            "x-auth-token": "password",
        }


class RecordingServerMiddlewareFactory(ServerMiddlewareFactory):
    """Record what methods were called."""

    def __init__(self):
        super().__init__()
        # info.method of each call, in the order start_call saw them.
        self.methods = []

    def start_call(self, info, headers):
        self.methods.append(info.method)
        # No middleware instance is created for the call.
        return None


class RecordingClientMiddlewareFactory(ClientMiddlewareFactory):
    """Record what methods were called."""

    def __init__(self):
        super().__init__()
        # info.method of each call, in the order start_call saw them.
        self.methods = []

    def start_call(self, info):
        self.methods.append(info.method)
        # No middleware instance is created for the call.
        return None


class MultiHeaderClientMiddlewareFactory(ClientMiddlewareFactory):
    """Test sending/receiving multiple (binary-valued) headers."""

    def __init__(self):
        # Read in test_middleware_multi_header below.
        # The middleware instance will update this value.
        self.last_headers = {}

    def start_call(self, info):
        # Give the middleware a reference back to this factory so that
        # it can update last_headers.
        return MultiHeaderClientMiddleware(self)


class MultiHeaderClientMiddleware(ClientMiddleware):
    """Test sending/receiving multiple (binary-valued) headers."""

    # Headers sent on every call: text values, binary values, and both
    # str and bytes header names.
    EXPECTED = {
        "x-text": ["foo", "bar"],
        "x-binary-bin": [b"\x00", b"\x01"],
        # ARROW-16606: ensure mixed-case headers are accepted
        "x-MIXED-case": ["baz"],
        b"x-other-MIXED-case": ["baz"],
    }

    def __init__(self, factory):
        self.factory = factory

    def sending_headers(self):
        return self.EXPECTED

    def received_headers(self, headers):
        # Let the test code know what the last set of headers we
        # received were.
        self.factory.last_headers.update(headers)


class MultiHeaderServerMiddlewareFactory(ServerMiddlewareFactory):
    """Test sending/receiving multiple (binary-valued) headers."""

    def start_call(self, info, headers):
        # Keep the client's headers so the RPC handler can inspect them.
        return MultiHeaderServerMiddleware(headers)


class MultiHeaderServerMiddleware(ServerMiddleware):
    """Test sending/receiving multiple (binary-valued) headers."""

    def __init__(self, client_headers):
        # Headers received from the client at the start of the call.
        self.client_headers = client_headers

    def sending_headers(self):
        # Reply with the same header set the client middleware sends.
        return MultiHeaderClientMiddleware.EXPECTED


class LargeMetadataFlightServer(FlightServerBase):
    """Regression test for ARROW-13253."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # One byte more than 2**31 bytes of spaces.
        self._metadata = b' ' * (2 ** 31 + 1)

    def do_get(self, context, ticket):
        """Stream one single-row batch paired with the large metadata."""
        schema = pa.schema([('a', pa.int64())])
        batch = pa.record_batch([[1]], schema=schema)
        return flight.GeneratorStream(schema, [(batch, self._metadata)])

    def do_exchange(self, context, descriptor, reader, writer):
        """Write the large metadata without any data."""
        writer.write_metadata(self._metadata)


def test_repr():
    """Check repr() of the Flight value types against fixed strings."""
    empty_info = flight.FlightInfo(
        pa.schema([]), flight.FlightDescriptor.for_path(), [], -1, -1)

    # (object, expected repr) pairs.
    cases = [
        (flight.Action("foo", b""),
         "<pyarrow.flight.Action type='foo' body=(0 bytes)>"),
        (flight.ActionType("foo", "bar"),
         "ActionType(type='foo', description='bar')"),
        (flight.BasicAuth("user", "pass"),
         "<pyarrow.flight.BasicAuth username=b'user' "
         "password=(redacted)>"),
        (flight.FlightDescriptor.for_command("foo"),
         "<pyarrow.flight.FlightDescriptor cmd=b'foo'>"),
        (flight.FlightEndpoint(b"foo", []),
         "<pyarrow.flight.FlightEndpoint "
         "ticket=<pyarrow.flight.Ticket ticket=b'foo'> "
         "locations=[]>"),
        (empty_info,
         "<pyarrow.flight.FlightInfo "
         "schema= "
         "descriptor=<pyarrow.flight.FlightDescriptor path=[]> "
         "endpoints=[] "
         "total_records=-1 "
         "total_bytes=-1>"),
        (flight.Location("grpc+tcp://localhost:1234"),
         "<pyarrow.flight.Location b'grpc+tcp://localhost:1234'>"),
        (flight.Result(b"foo"),
         "<pyarrow.flight.Result body=(3 bytes)>"),
        (flight.SchemaResult(pa.schema([])),
         "<pyarrow.flight.SchemaResult schema=()>"),
        (flight.SchemaResult(pa.schema([("int", "int64")])),
         "<pyarrow.flight.SchemaResult schema=(int: int64)>"),
        (flight.Ticket(b"foo"),
         "<pyarrow.flight.Ticket ticket=b'foo'>"),
    ]
    for value, expected in cases:
        assert repr(value) == expected

    # An Action body of None is rejected.
    with pytest.raises(TypeError):
        flight.Action("foo", None)


def test_eq():
    """Check == and != on the Flight value types."""
    # Each factory returns a (left, right) pair of unequal objects.
    # Calling it twice gives independently built but equal objects.
    pair_factories = [
        lambda: (flight.Action("foo", b""), flight.Action("foo", b"bar")),
        lambda: (flight.ActionType("foo", "bar"),
                 flight.ActionType("foo", "baz")),
        lambda: (flight.BasicAuth("user", "pass"),
                 flight.BasicAuth("user2", "pass")),
        lambda: (flight.FlightDescriptor.for_command("foo"),
                 flight.FlightDescriptor.for_path("foo")),
        lambda: (flight.FlightEndpoint(b"foo", []),
                 flight.FlightEndpoint(b"", [])),
        lambda: (
            flight.FlightInfo(
                pa.schema([]),
                flight.FlightDescriptor.for_path(), [], -1, -1),
            flight.FlightInfo(
                pa.schema([]),
                flight.FlightDescriptor.for_command(b"foo"), [], -1, 42)),
        lambda: (flight.Location("grpc+tcp://localhost:1234"),
                 flight.Location("grpc+tls://localhost:1234")),
        lambda: (flight.Result(b"foo"), flight.Result(b"bar")),
        lambda: (flight.SchemaResult(pa.schema([])),
                 flight.SchemaResult(pa.schema([("ints", pa.int64())]))),
        lambda: (flight.Ticket(b""), flight.Ticket(b"foo")),
    ]

    for make_pair in pair_factories:
        left_a, right_a = make_pair()
        left_b, right_b = make_pair()
        assert left_a == left_b
        assert right_a == right_b
        assert left_a != right_a


def test_flight_server_location_argument():
    """Start a server with each accepted form of the location argument."""
    host_and_port = ('localhost', find_free_port())
    for candidate in (None, 'grpc://localhost:0', host_and_port):
        with FlightServerBase(candidate) as started:
            assert isinstance(started, FlightServerBase)


def test_server_exit_reraises_exception():
    """An exception raised inside the server's ``with`` body propagates."""
    with pytest.raises(ValueError):
        server = FlightServerBase()
        with server:
            raise ValueError()


@pytest.mark.threading
@pytest.mark.slow
def test_client_wait_for_available():
    """Check that ``wait_for_available`` blocks until the server is up.

    A background thread starts the server after a 0.5 second delay, so the
    client call must take at least that long to return.
    """
    location = ('localhost', find_free_port())
    server = None

    def serve():
        # ``nonlocal`` (rather than ``global``) binds the variable of the
        # enclosing test, so the test can shut the server down afterwards.
        nonlocal server
        time.sleep(0.5)
        server = FlightServerBase(location)
        server.serve()

    thread = threading.Thread(target=serve, daemon=True)
    try:
        with FlightClient(location) as client:
            thread.start()

            started = time.time()
            client.wait_for_available(timeout=5)
            elapsed = time.time() - started
            assert elapsed >= 0.5
    finally:
        # Stop the background server (if it got created) so that serve()
        # returns and the thread can finish.
        if server is not None:
            server.shutdown()
        if thread.is_alive():
            thread.join(timeout=5)


def test_flight_list_flights():
    """Call list_flights without criteria and with the server's CRITERIA."""
    with ConstantFlightServer() as server:
        with flight.connect(('localhost', server.port)) as client:
            unfiltered = list(client.list_flights())
            assert unfiltered == []

            criteria = ConstantFlightServer.CRITERIA
            filtered = list(client.list_flights(criteria))
            assert len(filtered) == 1


def test_flight_client_close():
    """Closing a client is idempotent and makes further calls fail."""
    with ConstantFlightServer() as server:
        with flight.connect(('localhost', server.port)) as client:
            assert list(client.list_flights()) == []

            # Closing twice must not raise.
            for _ in range(2):
                client.close()

            with pytest.raises(pa.ArrowInvalid):
                list(client.list_flights())


def test_flight_do_get_ints():
    """Fetch the 'ints' ticket with default, V4 and invalid write options."""
    expected = simple_ints_table()
    ints_ticket = flight.Ticket(b'ints')

    # Default IPC write options.
    with ConstantFlightServer() as server:
        with flight.connect(('localhost', server.port)) as client:
            received = client.do_get(ints_ticket).read_all()
            assert received.equals(expected)

    # Explicit V4 metadata version.
    v4_options = pa.ipc.IpcWriteOptions(
        metadata_version=pa.ipc.MetadataVersion.V4)
    with ConstantFlightServer(options=v4_options) as server:
        with flight.connect(('localhost', server.port)) as client:
            received = client.do_get(ints_ticket).read_all()
            assert received.equals(expected)

            # Also test via RecordBatchReader interface
            batch_reader = client.do_get(ints_ticket).to_reader()
            received = batch_reader.read_all()
            assert received.equals(expected)

    # Options of the wrong type surface as a server error on the client.
    with pytest.raises(flight.FlightServerError,
                       match="expected IpcWriteOptions, got <class 'int'>"):
        with ConstantFlightServer(options=42) as server:
            with flight.connect(('localhost', server.port)) as client:
                client.do_get(ints_ticket).read_all()


@pytest.mark.pandas
def test_do_get_ints_pandas():
    """Fetch the 'ints' ticket and read the stream with read_pandas()."""
    expected_values = simple_ints_table().column(0).to_pylist()

    with ConstantFlightServer() as server:
        with flight.connect(('localhost', server.port)) as client:
            frame = client.do_get(flight.Ticket(b'ints')).read_pandas()
            assert list(frame['some_ints']) == expected_values


def test_flight_do_get_dicts():
    """Fetch the 'dicts' ticket and compare with simple_dicts_table()."""
    expected = simple_dicts_table()

    with ConstantFlightServer() as server:
        with flight.connect(('localhost', server.port)) as client:
            reader = client.do_get(flight.Ticket(b'dicts'))
            received = reader.read_all()
            assert received.equals(expected)


def test_flight_do_get_ticket():
    """Make sure Tickets get passed to the server."""
    ticket_bytes = b'the-ticket'
    expected = pa.Table.from_arrays(
        [pa.array([-10, -5, 0, 5, 10], type=pa.int32())], names=['a'])

    with CheckTicketFlightServer(expected_ticket=ticket_bytes) as server:
        with flight.connect(('localhost', server.port)) as client:
            reader = client.do_get(flight.Ticket(ticket_bytes))
            assert reader.read_all().equals(expected)


def test_flight_get_info():
    """Make sure FlightEndpoint accepts string and object URIs.

    Checks the record/byte counts, schema and endpoint locations of the
    FlightInfo returned by GetInfoFlightServer.
    """
    # Use the client as a context manager so the connection is closed even
    # when one of the assertions below fails.
    with GetInfoFlightServer() as server, \
            FlightClient(('localhost', server.port)) as client:
        info = client.get_flight_info(flight.FlightDescriptor.for_command(b''))
        assert info.total_records == -1
        assert info.total_bytes == -1
        assert info.schema == pa.schema([('a', pa.int32())])
        assert len(info.endpoints) == 2
        assert len(info.endpoints[0].locations) == 1
        assert info.endpoints[0].locations[0] == flight.Location('grpc://test')
        assert info.endpoints[1].locations[0] == \
            flight.Location.for_grpc_tcp('localhost', 5005)


def test_flight_get_schema():
    """Make sure GetSchema returns correct schema."""
    expected_schema = pa.schema([('a', pa.int32())])
    with GetInfoFlightServer() as server:
        with FlightClient(('localhost', server.port)) as client:
            descriptor = flight.FlightDescriptor.for_command(b'')
            schema_result = client.get_schema(descriptor)
            assert schema_result.schema == expected_schema


def test_list_actions():
    """Make sure the return type of ListActions is validated."""
    # ARROW-6392
    error_pattern = "Results of list_actions must be ActionType or tuple"
    with ListActionsErrorFlightServer() as server:
        with FlightClient(('localhost', server.port)) as client:
            with pytest.raises(flight.FlightServerError,
                               match=error_pattern):
                list(client.list_actions())

    with ListActionsFlightServer() as server:
        with FlightClient(('localhost', server.port)) as client:
            listed = list(client.list_actions())
            assert listed == ListActionsFlightServer.expected_actions()


class ConvenienceServer(FlightServerBase):
    """
    Server for testing various implementation conveniences (auto-boxing, etc.)
    """

    @property
    def simple_action_results(self):
        """Bodies returned for the 'simple-action' action."""
        return [b'foo', b'bar', b'baz']

    def do_action(self, context, action):
        """Dispatch on the action type; unknown types return None."""
        kind = action.type
        if kind == 'simple-action':
            return self.simple_action_results
        if kind == 'echo':
            return [action.body]
        if kind == 'bad-action':
            # A str instead of bytes, to provoke a server-side error.
            return ['foo']
        if kind == 'arrow-exception':
            raise pa.ArrowMemoryError()
        if kind == 'forever':
            def stream_until_cancelled():
                # Keep producing results until the call is cancelled.
                while not context.is_cancelled():
                    yield b'foo'
            return stream_until_cancelled()
        return None


def test_do_action_result_convenience():
    """do_action accepts a bare action type or a (type, body) tuple."""
    with ConvenienceServer() as server:
        with FlightClient(('localhost', server.port)) as client:
            # do_action as action type without body
            bodies = [res.body for res in client.do_action('simple-action')]
            assert bodies == server.simple_action_results

            # do_action with tuple of type and body
            payload = b'the-body'
            echoed = [res.body for res in client.do_action(('echo', payload))]
            assert echoed == [payload]


def test_nicer_server_exceptions():
    """Server-side Python errors reach the client as FlightServerError."""
    with ConvenienceServer() as server:
        with FlightClient(('localhost', server.port)) as client:
            with pytest.raises(flight.FlightServerError,
                               match="a bytes-like object is required"):
                list(client.do_action('bad-action'))

            # While Flight/C++ sends across the original status code, it
            # doesn't get mapped to the equivalent code here, since we
            # want to be able to distinguish between client- and server-
            # side errors.
            with pytest.raises(flight.FlightServerError,
                               match="ArrowMemoryError"):
                list(client.do_action('arrow-exception'))


def test_get_port():
    """A server bound to port 0 reports a positive ``port``."""
    server = GetInfoFlightServer("grpc://localhost:0")
    try:
        bound_port = server.port
        assert bound_port > 0
    finally:
        server.shutdown()


@pytest.mark.skipif(os.name == 'nt',
                    reason="Unix sockets can't be tested on Windows")
def test_flight_domain_socket():
    """Try simple do_get calls over a Unix domain socket."""
    # Each ticket is paired with the helper that builds the expected table.
    cases = (
        (b'ints', simple_ints_table),
        (b'dicts', simple_dicts_table),
    )
    with tempfile.NamedTemporaryFile() as sock:
        # Only the file name is needed for the socket location.
        sock.close()
        location = flight.Location.for_grpc_unix(sock.name)
        with ConstantFlightServer(location=location):
            with FlightClient(location) as client:
                for ticket_bytes, make_table in cases:
                    reader = client.do_get(flight.Ticket(ticket_bytes))
                    expected = make_table()
                    assert reader.schema.equals(expected.schema)
                    received = reader.read_all()
                    assert received.equals(expected)


@pytest.mark.slow
def test_flight_large_message():
    """Try sending/receiving a large message via Flight.

    See ARROW-4421: by default, gRPC won't allow us to send messages >
    4MiB in size.
    """
    num_rows = 10 * 1024 * 1024
    data = pa.Table.from_arrays([pa.array(range(num_rows))], names=['a'])

    with EchoFlightServer(expected_schema=data.schema) as server:
        with FlightClient(('localhost', server.port)) as client:
            descriptor = flight.FlightDescriptor.for_path('test')
            writer, _ = client.do_put(descriptor, data.schema)
            # Write a single giant chunk
            writer.write_table(data, num_rows)
            writer.close()

            echoed = client.do_get(flight.Ticket(b'')).read_all()
            assert echoed.equals(data)


def test_flight_generator_stream():
    """Try downloading a flight of RecordBatches in a GeneratorStream."""
    data = pa.Table.from_arrays([pa.array(range(10 * 1024))], names=['a'])

    with EchoStreamFlightServer() as server:
        with FlightClient(('localhost', server.port)) as client:
            descriptor = flight.FlightDescriptor.for_path('test')
            writer, _ = client.do_put(descriptor, data.schema)
            writer.write_table(data)
            writer.close()

            echoed = client.do_get(flight.Ticket(b'')).read_all()
            assert echoed.equals(data)


def test_flight_invalid_generator_stream():
    """Try streaming data with mismatched schemas."""
    with InvalidStreamFlightServer() as server:
        with FlightClient(('localhost', server.port)) as client:
            with pytest.raises(pa.ArrowException):
                reader = client.do_get(flight.Ticket(b''))
                reader.read_all()


def test_timeout_fires():
    """Make sure timeouts fire on slow requests."""
    short_timeout = flight.FlightCallOptions(timeout=0.2)
    with SlowFlightServer() as server:
        with FlightClient(('localhost', server.port)) as client:
            empty_action = flight.Action("", b"")
            # gRPC error messages change based on version, so don't look
            # for a particular error
            with pytest.raises(flight.FlightTimedOutError):
                list(client.do_action(empty_action, options=short_timeout))


def test_timeout_passes():
    """Make sure timeouts do not fire on fast requests."""
    with ConstantFlightServer() as server:
        with FlightClient(('localhost', server.port)) as client:
            generous_timeout = flight.FlightCallOptions(timeout=5.0)
            reader = client.do_get(flight.Ticket(b'ints'),
                                   options=generous_timeout)
            reader.read_all()


def test_read_options():
    """Make sure ReadOptions can be used."""
    single_column = pa.Table.from_arrays(
        [pa.array([1, 2, 3, 4])], names=["b"])

    with ConstantFlightServer() as server:
        with FlightClient(('localhost', server.port)) as client:
            # Only request the field at index 1.
            read_options = IpcReadOptions(included_fields=[1])
            call_options = flight.FlightCallOptions(
                read_options=read_options)
            filtered = client.do_get(
                flight.Ticket(b'multi'), options=call_options).read_all()
            unfiltered = client.do_get(flight.Ticket(b'multi')).read_all()

            assert unfiltered.num_columns == 2
            assert filtered.num_columns == 1
            assert filtered == single_column
            assert unfiltered == multiple_column_table()


# Module-level server auth handlers shared by the authentication tests
# below.  Both are created with the same ``creds`` mapping
# (b"test" -> b"p4ssw0rd").
basic_auth_handler = HttpBasicServerAuthHandler(creds={
    b"test": b"p4ssw0rd",
})

token_auth_handler = TokenServerAuthHandler(creds={
    b"test": b"p4ssw0rd",
})


@pytest.mark.slow
def test_http_basic_unauth():
    """Test that auth fails when not authenticated."""
    with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server:
        with FlightClient(('localhost', server.port)) as client:
            who_am_i = flight.Action("who-am-i", b"")
            # No client.authenticate() call has been made at this point.
            with pytest.raises(flight.FlightUnauthenticatedError,
                               match=".*unauthenticated.*"):
                list(client.do_action(who_am_i))


@pytest.mark.skipif(os.name == 'nt',
                    reason="ARROW-10013: gRPC on Windows corrupts peer()")
def test_http_basic_auth():
    """Test a Python implementation of HTTP basic authentication."""
    with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server:
        with FlightClient(('localhost', server.port)) as client:
            who_am_i = flight.Action("who-am-i", b"")
            client.authenticate(
                HttpBasicClientAuthHandler('test', 'p4ssw0rd'))
            results = client.do_action(who_am_i)

            # First result: the authenticated identity.
            identity = next(results)
            assert identity.body.to_pybytes() == b'test'
            # Second result: the peer address, which must not be empty.
            peer_address = next(results)
            assert peer_address.body.to_pybytes() != b''


def test_http_basic_auth_invalid_password():
    """Test that auth fails with the wrong password."""
    with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server:
        with FlightClient(('localhost', server.port)) as client:
            who_am_i = flight.Action("who-am-i", b"")
            bad_credentials = HttpBasicClientAuthHandler('test', 'wrong')
            with pytest.raises(flight.FlightUnauthenticatedError,
                               match=".*wrong password.*"):
                client.authenticate(bad_credentials)
                next(client.do_action(who_am_i))


def test_token_auth():
    """Test an auth mechanism that uses a handshake."""
    with EchoStreamFlightServer(auth_handler=token_auth_handler) as server:
        with FlightClient(('localhost', server.port)) as client:
            who_am_i = flight.Action("who-am-i", b"")
            client.authenticate(TokenClientAuthHandler('test', 'p4ssw0rd'))
            first_result = next(client.do_action(who_am_i))
            assert first_result.body.to_pybytes() == b'test'


def test_token_auth_invalid():
    """Test that the token handshake fails with a wrong password."""
    with EchoStreamFlightServer(auth_handler=token_auth_handler) as server:
        with FlightClient(('localhost', server.port)) as client:
            bad_credentials = TokenClientAuthHandler('test', 'wrong')
            with pytest.raises(flight.FlightUnauthenticatedError):
                client.authenticate(bad_credentials)


# Module-level middleware factory and no-op auth handler instances.  The
# header-auth tests below pass ``no_op_auth_handler`` as ``auth_handler``.
header_auth_server_middleware_factory = HeaderAuthServerMiddlewareFactory()
no_op_auth_handler = NoopAuthHandler()


def test_authenticate_basic_token():
    """Test authenticate_basic_token with bearer token and auth headers."""
    middleware = {"auth": HeaderAuthServerMiddlewareFactory()}
    with HeaderAuthFlightServer(auth_handler=no_op_auth_handler,
                                middleware=middleware) as server:
        with FlightClient(('localhost', server.port)) as client:
            token_pair = client.authenticate_basic_token(
                b'test', b'password')
            assert token_pair[0] == b'authorization'
            assert token_pair[1] == b'Bearer token1234'


def test_authenticate_basic_token_invalid_password():
    """Test authenticate_basic_token with an invalid password."""
    middleware = {"auth": HeaderAuthServerMiddlewareFactory()}
    with HeaderAuthFlightServer(auth_handler=no_op_auth_handler,
                                middleware=middleware) as server:
        with FlightClient(('localhost', server.port)) as client:
            with pytest.raises(flight.FlightUnauthenticatedError):
                client.authenticate_basic_token(b'test', b'badpassword')


def test_authenticate_basic_token_and_action():
    """Test authenticate_basic_token and doAction after authentication."""
    middleware = {"auth": HeaderAuthServerMiddlewareFactory()}
    with HeaderAuthFlightServer(auth_handler=no_op_auth_handler,
                                middleware=middleware) as server:
        with FlightClient(('localhost', server.port)) as client:
            token_pair = client.authenticate_basic_token(
                b'test', b'password')
            assert token_pair[0] == b'authorization'
            assert token_pair[1] == b'Bearer token1234'

            # Send the returned header pair along with the next call.
            call_options = flight.FlightCallOptions(headers=[token_pair])
            test_action = flight.Action('test-action', b'')
            results = list(client.do_action(
                action=test_action, options=call_options))
            assert results[0].body.to_pybytes() == b'token1234'


def test_authenticate_basic_token_with_client_middleware():
    """Test authenticate_basic_token with client middleware
       to intercept authorization header returned by the
       HTTP header auth enabled server.
    """
    with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={
        "auth": HeaderAuthServerMiddlewareFactory()
    }) as server:
        client_auth_middleware = ClientHeaderAuthMiddlewareFactory()
        # Use the client as a context manager so it is closed even when one
        # of the assertions below fails.
        with FlightClient(
            ('localhost', server.port),
            middleware=[client_auth_middleware]
        ) as client:
            encoded_credentials = base64.b64encode(b'test:password')
            options = flight.FlightCallOptions(headers=[
                (b'authorization', b'Basic ' + encoded_credentials)
            ])
            # Issue the same call twice; both calls must return the token
            # and leave the bearer credential recorded on the middleware
            # factory.
            for _ in range(2):
                result = list(client.do_action(
                    action=flight.Action('test-action', b''),
                    options=options))
                assert result[0].body.to_pybytes() == b'token1234'
                assert client_auth_middleware.call_credential[0] == \
                    b'authorization'
                assert client_auth_middleware.call_credential[1] == \
                    b'Bearer ' + b'token1234'


def test_arbitrary_headers_in_flight_call_options():
    """Test passing multiple arbitrary headers to the middleware."""
    middleware = {
        "auth": HeaderAuthServerMiddlewareFactory(),
        "arbitrary-headers": ArbitraryHeadersServerMiddlewareFactory(),
    }
    with ArbitraryHeadersFlightServer(auth_handler=no_op_auth_handler,
                                      middleware=middleware) as server:
        with FlightClient(('localhost', server.port)) as client:
            token_pair = client.authenticate_basic_token(
                b'test', b'password')
            assert token_pair[0] == b'authorization'
            assert token_pair[1] == b'Bearer token1234'

            # Auth header first, followed by two custom headers.
            headers = [
                token_pair,
                (b'test-header-1', b'value1'),
                (b'test-header-2', b'value2'),
            ]
            call_options = flight.FlightCallOptions(headers=headers)
            test_action = flight.Action("test-action", b"")
            results = list(client.do_action(test_action,
                                            options=call_options))
            assert results[0].body.to_pybytes() == b'value1'
            assert results[1].body.to_pybytes() == b'value2'


def test_location_invalid():
    """Test constructing invalid URIs."""
    parse_error = ".*Cannot parse URI:.*"

    # Both the client and the server entry points reject the URI.
    with pytest.raises(pa.ArrowInvalid, match=parse_error):
        flight.connect("%")

    with pytest.raises(pa.ArrowInvalid, match=parse_error):
        ConstantFlightServer("%")


def test_location_unknown_scheme():
    """Test creating locations for unknown schemes."""
    s3_location = flight.Location("s3://foo")
    assert s3_location.uri == b"s3://foo"

    https_location = flight.Location("https://example.com/bar.parquet")
    assert https_location.uri == b"https://example.com/bar.parquet"


@pytest.mark.slow
@pytest.mark.requires_testing_data
def test_tls_fails():
    """Make sure clients cannot connect when cert verification fails."""
    tls_certificates = example_tls_certs()["certificates"]

    # Ensure client doesn't connect when certificate verification
    # fails (this is a slow test since gRPC does retry a few times)
    with ConstantFlightServer(tls_certificates=tls_certificates) as server:
        uri = "grpc+tls://localhost:" + str(server.port)
        with FlightClient(uri) as client:
            # gRPC error messages change based on version, so don't look
            # for a particular error
            with pytest.raises(flight.FlightUnavailableError):
                client.do_get(flight.Ticket(b'ints')).read_all()


@pytest.mark.requires_testing_data
def test_tls_do_get():
    """Try a simple do_get call over TLS."""
    expected = simple_ints_table()
    certs = example_tls_certs()

    with ConstantFlightServer(
            tls_certificates=certs["certificates"]) as server:
        with FlightClient(('localhost', server.port),
                          tls_root_certs=certs["root_cert"]) as client:
            received = client.do_get(flight.Ticket(b'ints')).read_all()
            assert received.equals(expected)


@pytest.mark.requires_testing_data
def test_tls_disable_server_verification():
    """Try a simple do_get call over TLS with server verification disabled.

    The test is skipped when constructing the client raises
    NotImplementedError.
    """
    table = simple_ints_table()
    certs = example_tls_certs()

    with ConstantFlightServer(tls_certificates=certs["certificates"]) as s:
        try:
            client = FlightClient(('localhost', s.port),
                                  disable_server_verification=True)
        except NotImplementedError:
            pytest.skip('disable_server_verification feature is not available')
        # Close the client even if the request or the assertion fails.
        with client:
            data = client.do_get(flight.Ticket(b'ints')).read_all()
            assert data.equals(table)


@pytest.mark.requires_testing_data
def test_tls_override_hostname():
    """Check that incorrectly overriding the hostname fails."""
    certs = example_tls_certs()

    with ConstantFlightServer(
            tls_certificates=certs["certificates"]) as server:
        with flight.connect(('localhost', server.port),
                            tls_root_certs=certs["root_cert"],
                            override_hostname="fakehostname") as client:
            with pytest.raises(flight.FlightUnavailableError):
                client.do_get(flight.Ticket(b'ints'))


def test_flight_do_get_metadata():
    """Try a simple do_get call with metadata."""
    expected = pa.Table.from_arrays(
        [pa.array([-10, -5, 0, 5, 10])], names=['a'])

    received_batches = []
    with MetadataFlightServer() as server:
        with FlightClient(('localhost', server.port)) as client:
            reader = client.do_get(flight.Ticket(b''))
            for idx, (batch, metadata) in enumerate(reader):
                received_batches.append(batch)
                # The metadata is unpacked as a little-endian int32 and
                # must match the chunk's position in the stream.
                server_idx, = struct.unpack('<i', metadata.to_pybytes())
                assert idx == server_idx
            received = pa.Table.from_batches(received_batches)
            assert received.equals(expected)


def test_flight_metadata_record_batch_reader_iterator():
    """Verify the iterator interface works as expected."""
    via_read_chunk = []
    via_iteration = []

    # First pass: pull chunks explicitly until StopIteration is raised.
    with MetadataFlightServer() as server:
        with FlightClient(('localhost', server.port)) as client:
            reader = client.do_get(flight.Ticket(b''))
            for idx in itertools.count():
                try:
                    batch, metadata = reader.read_chunk()
                    via_read_chunk.append(batch)
                    server_idx, = struct.unpack(
                        '<i', metadata.to_pybytes())
                    assert idx == server_idx
                except StopIteration:
                    break

    # Second pass: the same stream consumed with a plain for loop.
    with MetadataFlightServer() as server:
        with FlightClient(('localhost', server.port)) as client:
            reader = client.do_get(flight.Ticket(b''))
            for idx, (batch, metadata) in enumerate(reader):
                via_iteration.append(batch)
                server_idx, = struct.unpack('<i', metadata.to_pybytes())
                assert idx == server_idx

    assert via_read_chunk == via_iteration


def test_flight_do_get_metadata_v4():
    """Try a simple do_get call with V4 metadata version."""
    expected = pa.Table.from_arrays(
        [pa.array([-10, -5, 0, 5, 10])], names=['a'])
    v4_options = pa.ipc.IpcWriteOptions(
        metadata_version=pa.ipc.MetadataVersion.V4)

    with MetadataFlightServer(options=v4_options) as server:
        with FlightClient(('localhost', server.port)) as client:
            received = client.do_get(flight.Ticket(b'')).read_all()
            assert received.equals(expected)


def test_flight_do_put_metadata():
    """Try a simple do_put call with metadata."""
    table = pa.Table.from_arrays(
        [pa.array([-10, -5, 0, 5, 10])], names=['a'])
    single_row_batches = table.to_batches(max_chunksize=1)

    with MetadataFlightServer() as server:
        with FlightClient(('localhost', server.port)) as client:
            descriptor = flight.FlightDescriptor.for_path('')
            writer, metadata_reader = client.do_put(descriptor,
                                                    table.schema)
            with writer:
                for idx, batch in enumerate(single_row_batches):
                    # Send the batch index as little-endian int32
                    # metadata and expect the same value to be read back.
                    writer.write_with_metadata(batch,
                                               struct.pack('<i', idx))
                    reply = metadata_reader.read()
                    assert reply is not None
                    server_idx, = struct.unpack('<i', reply.to_pybytes())
                    assert idx == server_idx


@pytest.mark.numpy
def test_flight_do_put_limit():
    """Try a simple do_put call with a size limit.

    Writing a batch of 768 int64 values with a client created with
    ``write_size_limit_bytes=4096`` must raise
    FlightWriteSizeExceededError; the two halves of the batch are then
    written and the full data is read back.
    """
    # Pass the dtype as the type ``np.int64`` rather than a scalar
    # instance ``np.int64()``.
    large_batch = pa.RecordBatch.from_arrays([
        pa.array(np.ones(768, dtype=np.int64)),
    ], names=['a'])

    with EchoFlightServer() as server, \
        FlightClient(('localhost', server.port),
                     write_size_limit_bytes=4096) as client:
        writer, metadata_reader = client.do_put(
            flight.FlightDescriptor.for_path(''),
            large_batch.schema)
        with writer:
            with pytest.raises(flight.FlightWriteSizeExceededError,
                               match="exceeded soft limit") as excinfo:
                writer.write_batch(large_batch)
            assert excinfo.value.limit == 4096
            # The writer is still used after the error: send the same data
            # as two smaller batches.
            smaller_batches = [
                large_batch.slice(0, 384),
                large_batch.slice(384),
            ]
            for batch in smaller_batches:
                writer.write_batch(batch)
        expected = pa.Table.from_batches([large_batch])
        actual = client.do_get(flight.Ticket(b'')).read_all()
        assert expected == actual


@pytest.mark.slow
def test_cancel_do_get():
    """Test canceling a DoGet operation on the client side."""
    with ConstantFlightServer() as server:
        with FlightClient(('localhost', server.port)) as client:
            reader = client.do_get(flight.Ticket(b'ints'))
            reader.cancel()
            # Reading after cancel() must report the cancellation.
            with pytest.raises(flight.FlightCancelledError,
                               match="(?i).*cancel.*"):
                reader.read_chunk()


@pytest.mark.threading
@pytest.mark.slow
def test_cancel_do_get_threaded():
    """Test canceling a DoGet operation from another thread."""
    with SlowFlightServer() as server:
        with FlightClient(('localhost', server.port)) as client:
            reader = client.do_get(flight.Ticket(b'ints'))

            first_chunk_read = threading.Event()
            cancel_issued = threading.Event()
            outcome_lock = threading.Lock()
            got_cancelled_error = threading.Event()

            def read_in_background():
                # Read one chunk, tell the main thread, then wait for it
                # to cancel the stream before reading again.
                reader.read_chunk()
                first_chunk_read.set()
                cancel_issued.wait(timeout=5)
                try:
                    reader.read_chunk()
                except flight.FlightCancelledError:
                    with outcome_lock:
                        got_cancelled_error.set()

            worker = threading.Thread(target=read_in_background,
                                      daemon=True)
            worker.start()
            first_chunk_read.wait(timeout=5)
            reader.cancel()
            cancel_issued.set()
            worker.join(timeout=1)

            with outcome_lock:
                assert got_cancelled_error.is_set()


def test_streaming_do_action():
    """Read one result of the 'forever' action, then drop the stream."""
    with ConvenienceServer() as server:
        with FlightClient(('localhost', server.port)) as client:
            forever = flight.Action('forever', b'')
            results = client.do_action(forever)
            assert next(results).body == b'foo'
            # Implicit cancel when destructed
            del results


def test_roundtrip_types():
    """Make sure serializable types round-trip."""
    def check_roundtrip(cls, value):
        # serialize() followed by deserialize() must give an equal object.
        assert value == cls.deserialize(value.serialize())

    check_roundtrip(flight.Action,
                    flight.Action("action1", b"action1-body"))

    ticket = flight.Ticket("foo")
    check_roundtrip(flight.Ticket, ticket)

    check_roundtrip(flight.Result, flight.Result(b"result1"))

    check_roundtrip(flight.BasicAuth,
                    flight.BasicAuth("username1", "password1"))

    check_roundtrip(flight.SchemaResult,
                    flight.SchemaResult(pa.schema([('a', pa.int32())])))

    check_roundtrip(flight.FlightDescriptor,
                    flight.FlightDescriptor.for_command("test"))

    path_descriptor = flight.FlightDescriptor.for_path(
        "a", "b", "test.arrow")
    check_roundtrip(flight.FlightDescriptor, path_descriptor)

    # FlightInfo is compared attribute by attribute.
    tcp_location = flight.Location.for_grpc_tcp('localhost', 5005)
    endpoints = [
        flight.FlightEndpoint(b'', ['grpc://test']),
        flight.FlightEndpoint(b'', [tcp_location]),
    ]
    info = flight.FlightInfo(
        pa.schema([('a', pa.int32())]),
        path_descriptor,
        endpoints,
        -1,
        -1,
    )
    restored = flight.FlightInfo.deserialize(info.serialize())
    assert info.schema == restored.schema
    assert info.descriptor == restored.descriptor
    assert info.total_bytes == restored.total_bytes
    assert info.total_records == restored.total_records
    assert info.endpoints == restored.endpoints

    # An endpoint mixing a string URI and a Location object.
    mixed_endpoint = flight.FlightEndpoint(
        ticket,
        ['grpc://test', flight.Location.for_grpc_tcp('localhost', 5005)]
    )
    check_roundtrip(flight.FlightEndpoint, mixed_endpoint)


def test_roundtrip_errors():
    """Ensure that Flight errors propagate from server to client."""
    foo_pattern = ".*foo.*"
    table = pa.Table.from_arrays(
        [pa.array([-10, -5, 0, 5, 10])], names=['a'])
    # DoPut command name -> exception type expected on the client.
    put_errors = {
        'internal': flight.FlightInternalError,
        'timedout': flight.FlightTimedOutError,
        'cancel': flight.FlightCancelledError,
        'unauthenticated': flight.FlightUnauthenticatedError,
        'unauthorized': flight.FlightUnauthorizedError,
    }

    with ErrorFlightServer() as server:
        with FlightClient(('localhost', server.port)) as client:
            error_cases = ErrorFlightServer.error_cases()
            for action_type, exc_type in error_cases.items():
                with pytest.raises(exc_type, match=foo_pattern):
                    list(client.do_action(flight.Action(action_type, b"")))
            with pytest.raises(flight.FlightInternalError,
                               match=foo_pattern):
                list(client.list_flights())

            for command, exc_type in put_errors.items():
                # Once with data written before closing ...
                with pytest.raises(exc_type, match=foo_pattern):
                    descriptor = flight.FlightDescriptor.for_command(
                        command)
                    writer, reader = client.do_put(descriptor,
                                                   table.schema)
                    writer.write_table(table)
                    writer.close()

                # ... and once closing the writer straight away.
                with pytest.raises(exc_type, match=foo_pattern):
                    descriptor = flight.FlightDescriptor.for_command(
                        command)
                    writer, reader = client.do_put(descriptor,
                                                   table.schema)
                    writer.close()


def test_do_put_independent_read_write():
    """Ensure that separate threads can read/write on a DoPut."""
    # ARROW-6063: previously this would cause gRPC to abort when the
    # writer was closed (due to simultaneous reads), or would hang
    # forever.
    table = pa.Table.from_arrays(
        [pa.array([-10, -5, 0, 5, 10])], names=['a'])

    with MetadataFlightServer() as server:
        with FlightClient(('localhost', server.port)) as client:
            descriptor = flight.FlightDescriptor.for_path('')
            writer, metadata_reader = client.do_put(descriptor,
                                                    table.schema)

            replies_seen = 0

            def count_replies():
                # Count metadata messages until read() returns None.
                nonlocal replies_seen
                while metadata_reader.read() is not None:
                    replies_seen += 1

            reader_thread = threading.Thread(target=count_replies)
            reader_thread.start()

            batches = table.to_batches(max_chunksize=1)
            with writer:
                for idx, batch in enumerate(batches):
                    writer.write_with_metadata(batch,
                                               struct.pack('<i', idx))
                # Causes the server to stop writing and end the call
                writer.done_writing()
                # Thus reader thread will break out of loop
                reader_thread.join()
            # writer.close() won't segfault since reader thread has
            # stopped
            assert replies_seen == len(batches)


def test_server_middleware_same_thread():
    """Check that server middleware runs on the RPC's own thread."""
    middleware = {"test": HeaderServerMiddlewareFactory()}
    with HeaderFlightServer(middleware=middleware) as server, \
            FlightClient(('localhost', server.port)) as client:
        action = flight.Action(b"test", b"")
        responses = list(client.do_action(action))
        # Exactly one result is expected, carrying the marker value.
        assert len(responses) == 1
        body = responses[0].body.to_pybytes()
        assert b"right value" == body


def test_middleware_reject():
    """Test rejecting an RPC with server middleware.

    A client without auth middleware can reach ``list_actions`` but has
    ``do_action`` rejected; a client that installs
    ``SelectiveAuthClientMiddlewareFactory`` gets ``do_action`` through.
    """
    with HeaderFlightServer(middleware={
        "test": SelectiveAuthServerMiddlewareFactory(),
    }) as server:
        location = ('localhost', server.port)

        with FlightClient(location) as client:
            # The middleware allows this through without auth.
            with pytest.raises(pa.ArrowNotImplementedError):
                list(client.list_actions())

            # But not anything else.
            with pytest.raises(flight.FlightUnauthenticatedError):
                list(client.do_action(flight.Action(b"", b"")))

        # Manage the authenticated client with a context manager as well,
        # so it is closed even when an assertion below fails (it used to
        # be created bare and was never closed).
        with FlightClient(
                location,
                middleware=[SelectiveAuthClientMiddlewareFactory()]
        ) as client:
            response = next(client.do_action(flight.Action(b"", b"")))
            assert b"password" == response.body.to_pybytes()


def test_middleware_mapping():
    """Check that middleware records the Flight method of every RPC."""
    server_recorder = RecordingServerMiddlewareFactory()
    client_recorder = RecordingClientMiddlewareFactory()
    # Methods in the order the RPCs are issued below.
    expected = [
        flight.FlightMethod.LIST_FLIGHTS,
        flight.FlightMethod.GET_FLIGHT_INFO,
        flight.FlightMethod.GET_SCHEMA,
        flight.FlightMethod.DO_GET,
        flight.FlightMethod.DO_PUT,
        flight.FlightMethod.DO_ACTION,
        flight.FlightMethod.LIST_ACTIONS,
        flight.FlightMethod.DO_EXCHANGE,
    ]
    with FlightServerBase(middleware={"test": server_recorder}) as server, \
        FlightClient(
            ('localhost', server.port),
            middleware=[client_recorder]
    ) as client:

        descriptor = flight.FlightDescriptor.for_command(b"")
        not_implemented = NotImplementedError

        # Every call is expected to fail with NotImplementedError; only
        # the methods seen by the middleware matter here.
        with pytest.raises(not_implemented):
            list(client.list_flights())
        with pytest.raises(not_implemented):
            client.get_flight_info(descriptor)
        with pytest.raises(not_implemented):
            client.get_schema(descriptor)
        with pytest.raises(not_implemented):
            client.do_get(flight.Ticket(b""))
        with pytest.raises(not_implemented):
            writer, _ = client.do_put(descriptor, pa.schema([]))
            writer.close()
        with pytest.raises(not_implemented):
            list(client.do_action(flight.Action(b"", b"")))
        with pytest.raises(not_implemented):
            list(client.list_actions())
        with pytest.raises(not_implemented):
            writer, _ = client.do_exchange(descriptor)
            writer.close()

        assert server_recorder.methods == expected
        assert client_recorder.methods == expected


def test_extra_info():
    """Check that a FlightUnauthorizedError carries the expected extra_info."""
    with ErrorFlightServer() as server, \
            FlightClient(('localhost', server.port)) as client:
        # pytest.raises reports a clear failure when nothing is raised;
        # the previous ``assert False`` sentinel is stripped by
        # ``python -O`` and would let the test pass silently.
        with pytest.raises(flight.FlightUnauthorizedError) as exc_info:
            list(client.do_action(flight.Action("protobuf", b"")))

        extra_info = exc_info.value.extra_info
        assert extra_info is not None
        assert extra_info == b'this is an error message'


@pytest.mark.requires_testing_data
def test_mtls():
    """Run a DoGet over a mutual TLS (mTLS) connection with gRPC."""
    tls = example_tls_certs()
    expected = simple_ints_table()
    # The same certificate/key pair identifies the server and the client.
    key_pair = tls["certificates"][0]
    root_cert = tls["root_cert"]

    with ConstantFlightServer(
            tls_certificates=[key_pair],
            verify_client=True,
            root_certificates=root_cert) as server, \
        FlightClient(
            ('localhost', server.port),
            tls_root_certs=root_cert,
            cert_chain=key_pair.cert,
            private_key=key_pair.key) as client:
        received = client.do_get(flight.Ticket(b'ints')).read_all()
        assert received.equals(expected)


def test_doexchange_get():
    """Use DoExchange as a stand-in for DoGet."""
    column = pa.array(range(10 * 1024))
    expected = pa.Table.from_arrays([column], names=["a"])

    with ExchangeFlightServer() as server, \
            FlightClient(("localhost", server.port)) as client:
        writer, reader = client.do_exchange(
            flight.FlightDescriptor.for_command(b"get"))
        # Nothing is written; the whole table is read back instead.
        with writer:
            received = reader.read_all()
        assert expected == received


def test_doexchange_put():
    """Use DoExchange as a stand-in for DoPut."""
    column = pa.array(range(10 * 1024))
    data = pa.Table.from_arrays([column], names=["a"])
    batches = data.to_batches(max_chunksize=512)
    # The reply is expected to carry the batch count as UTF-8 text.
    expected_buf = str(len(batches)).encode("utf-8")

    with ExchangeFlightServer() as server, \
            FlightClient(("localhost", server.port)) as client:
        writer, reader = client.do_exchange(
            flight.FlightDescriptor.for_command(b"put"))
        with writer:
            writer.begin(data.schema)
            for piece in batches:
                writer.write_batch(piece)
            writer.done_writing()
            reply = reader.read_chunk()
            assert reply.data is None
            assert reply.app_metadata == expected_buf


def test_doexchange_echo():
    """Round-trip metadata and data through a DoExchange echo server."""
    column = pa.array(range(10 * 1024))
    data = pa.Table.from_arrays([column], names=["a"])
    batches = data.to_batches(max_chunksize=512)

    with ExchangeFlightServer() as server, \
            FlightClient(("localhost", server.port)) as client:
        writer, reader = client.do_exchange(
            flight.FlightDescriptor.for_command(b"echo"))
        with writer:
            # Phase 1: metadata-only messages, before any data is sent.
            for number in range(10):
                payload = str(number).encode("utf-8")
                writer.write_metadata(payload)
                echoed = reader.read_chunk()
                assert echoed.data is None
                assert echoed.app_metadata == payload

            # Phase 2: data batches carrying no metadata.
            writer.begin(data.schema)
            for sent in batches:
                writer.write_batch(sent)
                assert reader.schema == data.schema
                echoed = reader.read_chunk()
                assert echoed.data == sent
                assert echoed.app_metadata is None

            # Phase 3: data batches, each tagged with its index.
            for number, sent in enumerate(batches):
                payload = str(number).encode("utf-8")
                writer.write_with_metadata(sent, payload)
                echoed = reader.read_chunk()
                assert echoed.data == sent
                assert echoed.app_metadata == payload


def test_doexchange_echo_v4():
    """Round-trip data through the echo server with V4 IPC metadata."""
    column = pa.array(range(10 * 1024))
    data = pa.Table.from_arrays([column], names=["a"])
    batches = data.to_batches(max_chunksize=512)

    # The same write options are given to the server and to the client.
    v4_options = pa.ipc.IpcWriteOptions(
        metadata_version=pa.ipc.MetadataVersion.V4)
    with ExchangeFlightServer(options=v4_options) as server, \
            FlightClient(("localhost", server.port)) as client:
        writer, reader = client.do_exchange(
            flight.FlightDescriptor.for_command(b"echo"))
        with writer:
            # Data batches only, none of them carrying metadata.
            writer.begin(data.schema, options=v4_options)
            for sent in batches:
                writer.write_batch(sent)
                assert reader.schema == data.schema
                echoed = reader.read_chunk()
                assert echoed.data == sent
                assert echoed.app_metadata is None


def test_doexchange_transform():
    """Send a table to the service and check the transformed result."""
    # Three columns of 1024 consecutive ints, starting at 0, 1 and 2.
    columns = [pa.array(range(start, 1024 + start)) for start in range(3)]
    data = pa.Table.from_arrays(columns, names=["a", "b", "c"])
    # Row-wise sums of the columns above: 3, 6, 9, ...
    sums = pa.array(range(3, 1024 * 3 + 3, 3))
    expected = pa.Table.from_arrays([sums], names=["sum"])

    with ExchangeFlightServer() as server, \
            FlightClient(("localhost", server.port)) as client:
        writer, reader = client.do_exchange(
            flight.FlightDescriptor.for_command(b"transform"))
        with writer:
            writer.begin(data.schema)
            writer.write_table(data)
            writer.done_writing()
            result = reader.read_all()
        assert expected == result


def test_middleware_multi_header():
    """Send and receive multiple (binary-valued) headers via middleware."""
    server_middleware = {"test": MultiHeaderServerMiddlewareFactory()}
    with MultiHeaderFlightServer(middleware=server_middleware) as server:
        headers = MultiHeaderClientMiddlewareFactory()
        with FlightClient(
                ('localhost', server.port),
                middleware=[headers]) as client:
            response = next(client.do_action(flight.Action(b"", b"")))
            # The server echoes the headers it got back to us.
            echoed_text = response.body.to_pybytes().decode("utf-8")
            client_headers = ast.literal_eval(echoed_text)
            # Don't directly compare; gRPC may add headers like User-Agent.
            expected_headers = MultiHeaderClientMiddleware.EXPECTED
            for name, values in expected_headers.items():
                # Compare on the lower-cased, text form of each name.
                key = name.lower()
                if isinstance(key, bytes):
                    key = key.decode("ascii")
                assert client_headers.get(key) == values
                assert headers.last_headers.get(key) == values


@pytest.mark.requires_testing_data
def test_generic_options():
    """Test setting generic client options.

    Each client is opened as a context manager so that it is closed even
    when the expected exception is not raised; previously a failing
    assertion skipped ``client.close()`` and leaked the client.
    """
    certs = example_tls_certs()

    with ConstantFlightServer(tls_certificates=certs["certificates"]) as s:
        location = ('localhost', s.port)

        # Try setting a string argument that will make requests fail
        options = [("grpc.ssl_target_name_override", "fakehostname")]
        with flight.connect(location,
                            tls_root_certs=certs["root_cert"],
                            generic_options=options) as client:
            with pytest.raises(flight.FlightUnavailableError):
                client.do_get(flight.Ticket(b'ints'))

        # Try setting an int argument that will make requests fail
        options = [("grpc.max_receive_message_length", 32)]
        with flight.connect(location,
                            tls_root_certs=certs["root_cert"],
                            generic_options=options) as client:
            with pytest.raises((pa.ArrowInvalid,
                                flight.FlightCancelledError)):
                client.do_get(flight.Ticket(b'ints'))


class CancelFlightServer(FlightServerBase):
    """A server for testing StopToken."""

    def do_get(self, context, ticket):
        """Return a stream that repeats an empty batch without end."""
        empty_schema = pa.schema([])
        empty_batch = pa.RecordBatch.from_arrays([], schema=empty_schema)
        endless = itertools.repeat(empty_batch)
        return flight.GeneratorStream(empty_schema, endless)

    def do_exchange(self, context, descriptor, reader, writer):
        """Write an empty batch every half second until cancelled."""
        empty_schema = pa.schema([])
        empty_batch = pa.RecordBatch.from_arrays([], schema=empty_schema)
        writer.begin(empty_schema)
        while True:
            if context.is_cancelled():
                break
            writer.write_batch(empty_batch)
            time.sleep(0.5)


@pytest.mark.threading
def test_interrupt():
    """Check that SIGINT interrupts a blocking Flight ``read_all``.

    A helper thread raises SIGINT half a second after it starts while
    the calling thread is inside ``read_all``; the read must then fail
    with KeyboardInterrupt or ArrowCancelled.  This is checked for both
    a DoGet reader and a DoExchange reader.
    """
    # The test is skipped unless it runs on the main Python thread.
    if threading.current_thread().ident != threading.main_thread().ident:
        pytest.skip("test only works from main Python thread")

    def signal_from_thread():
        # Wait half a second before raising SIGINT, leaving time for the
        # caller to enter read_all().
        time.sleep(0.5)
        signal.raise_signal(signal.SIGINT)

    exc_types = (KeyboardInterrupt, pa.ArrowCancelled)

    def test(read_all):
        # Run ``read_all`` while the helper thread raises SIGINT, then
        # check which exception came out of it.
        try:
            try:
                t = threading.Thread(target=signal_from_thread)
                with pytest.raises(exc_types) as exc_info:
                    t.start()
                    read_all()
            finally:
                # Always wait for the helper thread before continuing.
                t.join()
        except KeyboardInterrupt:
            # In case KeyboardInterrupt didn't interrupt read_all
            # above, at least prevent it from stopping the test suite
            pytest.fail("KeyboardInterrupt didn't interrupt Flight read_all")
        # __context__ is sometimes None
        e = exc_info.value
        # Accept the expected type either as the exception itself or as
        # the exception that was being handled when it was raised.
        assert isinstance(e, (pa.ArrowCancelled, KeyboardInterrupt)) or \
            isinstance(e.__context__, (pa.ArrowCancelled, KeyboardInterrupt))

    with CancelFlightServer() as server, \
            FlightClient(("localhost", server.port)) as client:

        # First case: interrupt a DoGet read.
        reader = client.do_get(flight.Ticket(b""))
        test(reader.read_all)

        # Second case: interrupt a DoExchange read.
        descriptor = flight.FlightDescriptor.for_command(b"echo")
        writer, reader = client.do_exchange(descriptor)
        test(reader.read_all)
        try:
            writer.close()
        except (KeyboardInterrupt, flight.FlightCancelledError):
            # Silence the Cancelled/Interrupt exception
            pass


def test_never_sends_data():
    """Regression test for ARROW-12779."""
    expected_error = "application server implementation error"
    empty_ticket = flight.Ticket(b'')
    data_ticket = flight.Ticket(b'yield_data')
    with NeverSendsDataFlightServer() as server, \
            flight.connect(('localhost', server.port)) as client:
        with pytest.raises(flight.FlightServerError, match=expected_error):
            client.do_get(empty_ticket).read_all()

        # Check that the server handler will ignore empty tables
        # up to a certain extent
        received = client.do_get(data_ticket).read_all()
        assert received.num_rows == 5


@pytest.mark.large_memory
@pytest.mark.slow
def test_large_descriptor():
    """Regression test for ARROW-13253.

    Placed here with appropriate marks since some CI pipelines can't run
    the C++ equivalent.
    """
    serialize_error = "Failed to serialize Flight descriptor"
    # One byte more than 2 ** 31 bytes of command payload.
    oversized_command = b' ' * (2 ** 31 + 1)
    large_descriptor = flight.FlightDescriptor.for_command(
        oversized_command)
    with FlightServerBase() as server, \
            flight.connect(('localhost', server.port)) as client:
        with pytest.raises(OSError, match=serialize_error):
            writer, _ = client.do_put(large_descriptor, pa.schema([]))
            writer.close()
        with pytest.raises(pa.ArrowException, match=serialize_error):
            client.do_exchange(large_descriptor)


@pytest.mark.large_memory
@pytest.mark.slow
def test_large_metadata_client():
    """Regression test for ARROW-13253 (oversized ``app_metadata``)."""
    overflow = "app_metadata size overflow"
    descriptor = flight.FlightDescriptor.for_command(b'')
    # One byte more than 2 ** 31 bytes of metadata.
    metadata = b' ' * (2 ** 31 + 1)

    # Part 1: the client tries to send the oversized metadata.
    with EchoFlightServer() as server, \
            flight.connect(('localhost', server.port)) as client:
        with pytest.raises(pa.ArrowCapacityError, match=overflow):
            writer, _ = client.do_put(descriptor, pa.schema([]))
            with writer:
                writer.write_metadata(metadata)
                writer.close()
        with pytest.raises(pa.ArrowCapacityError, match=overflow):
            writer, reader = client.do_exchange(descriptor)
            with writer:
                writer.write_metadata(metadata)

    # Drop the local reference to the large buffer before part 2.
    del metadata

    # Part 2: reading from LargeMetadataFlightServer must fail too.
    with LargeMetadataFlightServer() as server, \
            flight.connect(('localhost', server.port)) as client:
        with pytest.raises(flight.FlightServerError, match=overflow):
            reader = client.do_get(flight.Ticket(b''))
            reader.read_all()
        with pytest.raises(pa.ArrowException, match=overflow):
            writer, reader = client.do_exchange(descriptor)
            with writer:
                reader.read_all()


class ActionNoneFlightServer(EchoFlightServer):
    """A server that implements a side effect to a non iterable action.

    The ``append`` action records a value and returns ``None``; the
    ``get_value`` action returns the recorded values as a JSON list.
    """
    # Class-level default kept so ``ActionNoneFlightServer.VALUES`` still
    # resolves; every instance shadows it with its own list in __init__.
    VALUES = []

    def __init__(self, *args, **kwargs):
        """Create the server with its own, initially empty, value list.

        All arguments are forwarded unchanged to the base class.
        """
        # A list shared through the class attribute would keep growing
        # across instances (and across repeated runs of a test in one
        # process).  Assign before delegating so the attribute exists
        # regardless of what the base initializer does.
        self.VALUES = []
        super().__init__(*args, **kwargs)

    def do_action(self, context, action):
        """Handle ``get_value`` and ``append``; reject anything else.

        Raises
        ------
        NotImplementedError
            If ``action.type`` is neither ``"get_value"`` nor ``"append"``.
        """
        if action.type == "get_value":
            return [json.dumps(self.VALUES).encode('utf-8')]
        elif action.type == "append":
            self.VALUES.append(True)
            # Deliberately not an iterable: the side effect above is the
            # whole point of this action.
            return None
        raise NotImplementedError


def test_none_action_side_effect():
    """Check that an action runs even if its result iterator is unused.

    See https://issues.apache.org/jira/browse/ARROW-14255
    """
    append = flight.Action("append", b"")
    get_value = flight.Action("get_value", b"")

    with ActionNoneFlightServer() as server, \
            FlightClient(('localhost', server.port)) as client:
        # The result of the first call is intentionally discarded.
        client.do_action(append)
        results = client.do_action(get_value)
        payload = next(results).body.to_pybytes()
        assert json.loads(payload) == [True]


@pytest.mark.slow  # Takes a while for gRPC to "realize" writes fail
def test_write_error_propagation():
    """
    Ensure that exceptions during writing preserve error context.

    The server raises a FlightCancelledError carrying ``extra_info`` from
    both DoPut and DoExchange.  For each RPC the client keeps writing
    until the error surfaces, then checks that ``extra_info`` is
    preserved both on the failing write and on ``writer.close()``.

    See https://issues.apache.org/jira/browse/ARROW-16592.
    """
    expected_message = "foo"
    expected_info = b"bar"
    # The exception instance raised by both server handlers below.
    exc = flight.FlightCancelledError(
        expected_message, extra_info=expected_info)
    descriptor = flight.FlightDescriptor.for_command(b"")
    schema = pa.schema([("int64", pa.int64())])

    class FailServer(flight.FlightServerBase):
        # Both handlers fail immediately with the prepared exception.
        def do_put(self, context, descriptor, reader, writer):
            raise exc

        def do_exchange(self, context, descriptor, reader, writer):
            raise exc

    with FailServer() as server, \
            FlightClient(('localhost', server.port)) as client:
        # DoPut
        writer, reader = client.do_put(descriptor, schema)

        # Set a concurrent reader - ensure this doesn't block the
        # writer side from calling Close()
        def _reader():
            # Read until a FlightError ends the loop.
            try:
                while True:
                    reader.read()
            except flight.FlightError:
                return

        thread = threading.Thread(target=_reader, daemon=True)
        thread.start()

        # Keep writing until the error reaches the writer; the loop can
        # only end by raising.
        with pytest.raises(flight.FlightCancelledError) as exc_info:
            while True:
                writer.write_batch(pa.record_batch([[1]], schema=schema))
        assert exc_info.value.extra_info == expected_info

        # Closing the writer must report the same error details.
        with pytest.raises(flight.FlightCancelledError) as exc_info:
            writer.close()
        assert exc_info.value.extra_info == expected_info
        thread.join()

        # DoExchange
        writer, reader = client.do_exchange(descriptor)

        # Same concurrent reader as above, reading chunks this time.
        def _reader():
            try:
                while True:
                    reader.read_chunk()
            except flight.FlightError:
                return

        thread = threading.Thread(target=_reader, daemon=True)
        thread.start()
        # Keep writing metadata until the error reaches the writer.
        with pytest.raises(flight.FlightCancelledError) as exc_info:
            while True:
                writer.write_metadata(b" ")
        assert exc_info.value.extra_info == expected_info

        with pytest.raises(flight.FlightCancelledError) as exc_info:
            writer.close()
        assert exc_info.value.extra_info == expected_info
        thread.join()


def test_interpreter_shutdown():
    """
    Check that the gRPC server is stopped at interpreter shutdown.

    The scenario lives in a separate script that is run through
    ``util.invoke_script``.

    See https://issues.apache.org/jira/browse/ARROW-16597.
    """
    script_name = "arrow_16597.py"
    util.invoke_script(script_name)


class TracingFlightServer(FlightServerBase):
    """A server that echoes back trace context values."""

    def do_action(self, context, action):
        """Return every trace context entry encoded as ``key: value``."""
        tracing = context.get_middleware("tracing")
        # Don't turn this method into a generator since then
        # trace_context will be evaluated after we've exited the scope
        # of the OTel span (and so the value we want won't be present)
        entries = tracing.trace_context.items()
        return (f"{name}: {content}".encode("utf-8")
                for name, content in entries)


def test_tracing():
    """Smoke-test the tracing server middleware."""
    middleware = {"tracing": flight.TracingServerMiddlewareFactory()}
    with TracingFlightServer(middleware=middleware) as server, \
            FlightClient(('localhost', server.port)) as client:
        # We can't tell if Arrow was built with OpenTelemetry support,
        # so we can't count on any particular values being there; we
        # can only ensure things don't blow up either way.
        # Pretend we have an OTel implementation
        trace_headers = [
            (b"traceparent", b"00-000ff00f00f0ff000f0f00ff0f00fff0-"
                             b"000f0000f0f00000-00"),
            (b"tracestate", b""),
        ]
        options = flight.FlightCallOptions(headers=trace_headers)
        # Drain the results; their content is not checked.
        list(client.do_action((b"", b""), options=options))


def test_do_put_does_not_crash_when_schema_is_none():
    """Check that ``do_put`` with ``schema=None`` raises a TypeError."""
    msg = ("Argument 'schema' has incorrect type "
           r"\(expected pyarrow.lib.Schema, got NoneType\)")
    # Open the client as a context manager so it is closed once the
    # check is done; it used to be created bare and was never closed.
    with FlightClient('grpc+tls://localhost:9643',
                      disable_server_verification=True) as client:
        with pytest.raises(TypeError, match=msg):
            client.do_put(flight.FlightDescriptor.for_command('foo'),
                          schema=None)


def test_headers_trailers():
    """Check that server-sent headers and trailers reach the client."""

    class HeadersTrailersFlightServer(FlightServerBase):
        def get_flight_info(self, context, descriptor):
            # Attach a text and a "-bin" entry as headers, then as
            # trailers, before answering with an empty FlightInfo.
            for name, value in (("x-header", "header-value"),
                                ("x-header-bin", "header\x01value")):
                context.add_header(name, value)
            for name, value in (("x-trailer", "trailer-value"),
                                ("x-trailer-bin", "trailer\x01value")):
                context.add_trailer(name, value)
            empty_schema = pa.schema([])
            return flight.FlightInfo(empty_schema, descriptor, [], -1, -1)

    class HeadersTrailersMiddleware(ClientMiddleware):
        def __init__(self, factory):
            self.factory = factory

        def received_headers(self, headers):
            # Flatten {key: [values]} into (key, value) pairs kept on the
            # factory, where the test body can inspect them.
            recorded = self.factory.headers
            for key, values in headers.items():
                recorded.extend((key, value) for value in values)

    class HeadersTrailersMiddlewareFactory(ClientMiddlewareFactory):
        def __init__(self):
            self.headers = []

        def start_call(self, info):
            return HeadersTrailersMiddleware(self)

    factory = HeadersTrailersMiddlewareFactory()
    with HeadersTrailersFlightServer() as server, \
            FlightClient(
                ("localhost", server.port),
                middleware=[factory]) as client:
        client.get_flight_info(flight.FlightDescriptor.for_path(""))
        received = factory.headers
        # The "-bin" values are expected back as bytes, the others as str.
        assert ("x-header", "header-value") in received
        assert ("x-header-bin", b"header\x01value") in received
        assert ("x-trailer", "trailer-value") in received
        assert ("x-trailer-bin", b"trailer\x01value") in received
