/*
 * Copyright 2025 Bloomberg Finance LP
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <buildboxcommon_commandline.h>
#include <buildboxcommon_connectionoptions.h>
#include <buildboxcommon_fileutils.h>
#include <buildboxcommon_grpcclient.h>
#include <buildboxcommon_logging.h>
#include <buildboxcommon_logging_commandline.h>
#include <buildboxcommon_protos.h>
#include <buildboxcommon_systemutils.h>

#include "casd_wrap_cmdlinespec.h"

#include <algorithm>
#include <cerrno>
#include <chrono>
#include <fcntl.h>
#include <filesystem>
#include <iostream>
#include <iterator>
#include <memory>
#include <optional>
#include <signal.h>
#include <span>
#include <stdio.h>
#include <string>
#include <system_error>
#include <unistd.h>
#include <unordered_map>
#include <vector>

#include <nlohmann/json.hpp>

using namespace casdwrap;

class NestedServer {
  public:
    NestedServer(const buildboxcommon::GrpcClient &grpcClient,
                 buildboxcommon::LocalContentAddressableStorage::StubInterface
                     *localCasClient,
                 const std::string &rootPath, const std::string &socketPath)
    {
        readerWriter = localCasClient->NestedServer(&context);

        buildboxcommon::NestedServerRequest request;
        request.set_instance_name(grpcClient.instanceName());
        request.set_path(rootPath);

        // Specify the credentials used to access the socket
        const auto creds =
            buildboxcommon::SystemUtils::getProcessCredentials();
        request.mutable_access_credentials()->set_uid(creds.uid);
        request.mutable_access_credentials()->set_gid(creds.gid);

        buildboxcommon::RemoteApisSocketConfig remoteApisSocketConfig;
        remoteApisSocketConfig.set_instance_name(grpcClient.instanceName());
        remoteApisSocketConfig.set_path(socketPath);
        remoteApisSocketConfig.set_action_cache_update_enabled(true);
        request.mutable_remote_apis_socket()->CopyFrom(remoteApisSocketConfig);

        BUILDBOX_LOG_DEBUG("Creating nested server for \"" << request.path()
                                                           << "\"");

        const bool successfulWrite = readerWriter->Write(request);
        if (!successfulWrite) {
            const grpc::Status writeErrorStatus = readerWriter->Finish();
            BUILDBOXCOMMON_THROW_EXCEPTION(
                std::runtime_error, "Error creating nested server for \""
                                        << request.path() << "\": \""
                                        << writeErrorStatus.error_message()
                                        << "\"");
        }

        buildboxcommon::NestedServerResponse response;
        const bool successfulRead = readerWriter->Read(&response);
        if (!successfulRead) {
            readerWriter->WritesDone();
            const grpc::Status readErrorStatus = readerWriter->Finish();
            BUILDBOXCOMMON_THROW_EXCEPTION(
                std::runtime_error, "Error creating nested server for \""
                                        << request.path() << "\": \""
                                        << readErrorStatus.error_message()
                                        << "\"");
        }
    }

    ~NestedServer()
    {
        // Send an empty request to tell buildbox-casd to clean up
        readerWriter->Write(buildboxcommon::NestedServerRequest());
        readerWriter->WritesDone();
    }

  private:
    grpc::ClientContext context;
    std::unique_ptr<grpc::ClientReaderWriterInterface<
        buildboxcommon::NestedServerRequest,
        buildboxcommon::NestedServerResponse>>
        readerWriter;
};

namespace {

int wrapBwrap(const std::string &toolPath,
              const std::vector<std::string> &command,
              const buildboxcommon::GrpcClient &grpcClient,
              buildboxcommon::LocalContentAddressableStorage::StubInterface
                  *localCasClient,
              const std::string &socketPath, bool clearEnv)
{
    const auto separator = std::ranges::find(command, "--");
    if (separator == command.end()) {
        std::cerr << "Error: bwrap must be invoked with `--` as separator "
                     "before the command"
                  << std::endl;
        return 1;
    }

    std::vector<std::string> toolArgs;
    toolArgs.push_back(toolPath);

    toolArgs.insert(toolArgs.end(), command.begin() + 1, separator);

    auto infoFDs = buildboxcommon::SystemUtils::createPipe();
    auto blockFDs = buildboxcommon::SystemUtils::createPipe();

    toolArgs.push_back("--info-fd");
    toolArgs.push_back(std::to_string(infoFDs[1]));

    toolArgs.push_back("--block-fd");
    toolArgs.push_back(std::to_string(blockFDs[0]));

    // Let bwrap create a dummy directory as last operation to be able to
    // detect when bwrap has finished setting up the filesystem
    toolArgs.push_back("--dir");
    toolArgs.push_back("/" + socketPath);

    toolArgs.insert(toolArgs.end(), separator, command.end());

    const auto pid = fork();
    if (pid == -1) {
        throw std::system_error(errno, std::system_category(),
                                "Error calling `fork()`");
    }

    if (pid == 0) {
        // Child process

        // Remove CLOEXEC flag from write end of info pipe
        if (fcntl(infoFDs[1], F_SETFD, 0) < 0) {
            _Exit(1);
        }

        // Remove CLOEXEC flag from read end of block pipe
        if (fcntl(blockFDs[0], F_SETFD, 0) < 0) {
            _Exit(1);
        }

        int status;
        if (clearEnv) {
            status = buildboxcommon::SystemUtils::executeCommand(
                toolArgs, false,
                std::unordered_map<std::string, std::string>{});
        }
        else {
            status = buildboxcommon::SystemUtils::executeCommand(toolArgs);
        }
        perror("Error calling `execve()`");
        _Exit(status);
    }

    // Parent process

    close(infoFDs[1]);
    close(blockFDs[0]);

    FILE *infoFile = fdopen(infoFDs[0], "r");
    auto json = nlohmann::json::parse(infoFile);
    fclose(infoFile);

    if (!json.contains("child-pid")) {
        throw std::runtime_error(
            "bwrap JSON object is missing the `child-pid` key");
    }

    pid_t bwrapChildPid = json["child-pid"].get<int>();

    BUILDBOX_LOG_DEBUG("Child PID \"" << bwrapChildPid << "\"");

    const std::string commandRoot =
        "/proc/" + std::to_string(bwrapChildPid) + "/root";

    BUILDBOX_LOG_DEBUG("Waiting for bwrap to setup the filesystem");

    // Wait for bwrap to create dummy directory as indication that
    // the filesystem setup is complete
    const std::string dummyDirectory = commandRoot + "/" + socketPath;
    while (!buildboxcommon::FileUtils::isDirectory(dummyDirectory.c_str())) {
        bool timedOut = false;
        int exitCode = buildboxcommon::SystemUtils::waitPidOrSignalWithTimeout(
            pid,
            std::optional<std::chrono::milliseconds>{
                std::chrono::milliseconds(100)},
            &timedOut);
        if (exitCode >= 0) {
            BUILDBOX_LOG_ERROR("Error: bwrap unexpectedly exited with code "
                               << exitCode);
            return exitCode == 0 ? 1 : exitCode;
        }
    }
    // Delete dummy directory
    rmdir(dummyDirectory.c_str());

    NestedServer nestedServer(grpcClient, localCasClient, commandRoot,
                              socketPath);

    // Let bwrap execute the command
    write(blockFDs[1], "+", 1);
    close(blockFDs[1]);

    return buildboxcommon::SystemUtils::waitPid(pid);
}

int wrapGeneric(const std::string &toolPath,
                const std::vector<std::string> &command,
                const buildboxcommon::GrpcClient &grpcClient,
                buildboxcommon::LocalContentAddressableStorage::StubInterface
                    *localCasClient,
                const std::string &rootPath, const std::string &socketPath,
                bool clearEnv)
{
    NestedServer nestedServer(grpcClient, localCasClient, rootPath,
                              socketPath);

    std::vector<std::string> toolArgs;
    toolArgs.push_back(toolPath);
    toolArgs.insert(toolArgs.end(), command.begin() + 1, command.end());

    const auto pid = fork();
    if (pid == -1) {
        throw std::system_error(errno, std::system_category(),
                                "Error calling `fork()`");
    }

    if (pid == 0) {
        // Child process

        int status;
        if (clearEnv) {
            status = buildboxcommon::SystemUtils::executeCommand(
                toolArgs, false,
                std::unordered_map<std::string, std::string>{});
        }
        else {
            status = buildboxcommon::SystemUtils::executeCommand(toolArgs);
        }
        perror("Error calling `execve()`");
        _Exit(status);
    }

    // Parent process

    return buildboxcommon::SystemUtils::waitPid(pid);
}

} // namespace

/**
 * Entry point: parse the command line and connect to buildbox-casd. Then run
 * the given sandboxing tool with a nested server exposed at `--socket-path`.
 *
 * A tool given `--root` is handled generically. Otherwise only `bwrap` is
 * supported; its sandbox root is discovered at runtime.
 *
 * @return the wrapped tool's status, 0 for `--help`/`--version`, or 1 on any
 *         setup error (including any std::exception escaping the wrappers).
 */
int main(int argc, char *argv[])
{
    auto args = std::span(argv, argc);
    // Initialize logger
    buildboxcommon::logging::Logger::getLoggerInstance().initialize(args[0]);

    // Ignore SIGPIPE in case of using sockets + grpc without MSG_NOSIGNAL
    // support configured
    struct sigaction sa{};
    sigemptyset(&sa.sa_mask);
    sa.sa_handler = SIG_IGN;
    sa.sa_flags = 0;
    if (sigaction(SIGPIPE, &sa, nullptr) == -1) {
        BUILDBOX_LOG_ERROR("Unable to ignore SIGPIPE");
        return 1;
    }

    // Command-line specification, including the connection options
    CmdLineSpec casdWrapSpec(
        buildboxcommon::ConnectionOptionsCommandLine("", ""));

    try {
        buildboxcommon::CommandLine cml(casdWrapSpec.d_spec);
        const bool success = cml.parse(argc, argv);
        if (!success || argc == 1) {
            cml.usage();
            return 1;
        }
        if (cml.exists("help") || cml.exists("version")) {
            return 0;
        }

        buildboxcommon::LogLevel logLevel = buildboxcommon::LogLevel::ERROR;
        if (!buildboxcommon::parseLoggingOptions(cml, logLevel)) {
            return 1;
        }
        BUILDBOX_LOG_SET_LEVEL(logLevel);

        // Without this check `d_command.at(0)` below would surface as an
        // opaque out-of-range error
        if (casdWrapSpec.d_command.empty()) {
            std::cerr << "Error: No command specified" << std::endl;
            cml.usage();
            return 1;
        }

        std::string rootPath;
        if (cml.exists("root")) {
            rootPath =
                std::filesystem::absolute(cml.getString("root")).string();
        }

        std::string socketPath = cml.getString("socket-path");
        // Strip leading slashes (clears the string if it is all slashes)
        socketPath.erase(0, socketPath.find_first_not_of('/'));
        if (socketPath.empty()) {
            std::cerr << "Error: Invalid socket path: '"
                      << cml.getString("socket-path") << "'" << std::endl;
            return 1;
        }

        buildboxcommon::ConnectionOptions connOptions;
        if (!buildboxcommon::ConnectionOptionsCommandLine::configureChannel(
                cml, "", &connOptions)) {
            return 1;
        }

        if (connOptions.d_url.empty()) {
            std::cerr << "Error: buildbox-casd URL is missing" << std::endl;
            return 1;
        }

        BUILDBOX_LOG_DEBUG("Connection: " << connOptions);

        buildboxcommon::GrpcClient grpcClient;
        grpcClient.init(connOptions);

        auto localCasClient =
            buildboxcommon::LocalContentAddressableStorage::NewStub(
                grpcClient.channel());

        const std::string toolPath =
            buildboxcommon::SystemUtils::getPathToCommand(
                casdWrapSpec.d_command.front());
        if (toolPath.empty()) {
            std::cerr << "Error: Command not found: "
                      << casdWrapSpec.d_command[0] << std::endl;
            return 1;
        }

        const std::string toolBasename =
            buildboxcommon::FileUtils::pathBasename(toolPath);

        const bool clearEnv = cml.getBool("clearenv");

        if (!rootPath.empty()) {
            return wrapGeneric(toolPath, casdWrapSpec.d_command, grpcClient,
                               localCasClient.get(), rootPath, socketPath,
                               clearEnv);
        }
        if (toolBasename == "bwrap") {
            return wrapBwrap(toolPath, casdWrapSpec.d_command, grpcClient,
                             localCasClient.get(), socketPath, clearEnv);
        }

        std::cerr << "Error: Unsupported sandboxing tool '" << toolBasename
                  << "'" << std::endl;
        return 1;
    }
    catch (const std::exception &e) {
        BUILDBOX_LOG_ERROR("Error in buildbox-casd-wrap: " << e.what());
        return 1;
    }
}
