#!/usr/bin/env python3
# SPIR-V built-in library: type conversion functions
#
# ===----------------------------------------------------------------------===//
#
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#
# ===----------------------------------------------------------------------===//
#
# This script generates the file convert-spirv.cl, which contains all of the
# SPIR-V conversion functions.

import itertools
import os
import sys

from os.path import dirname, join, abspath

sys.path.insert(0, abspath(join(dirname(__file__), "..", "..", "..", "generic")))

from gen_convert_common import (
    types,
    int_types,
    signed_types,
    unsigned_types,
    float_types,
    int64_types,
    float64_types,
    vector_sizes,
    half_sizes,
    saturation,
    rounding_modes,
    float_prefix,
    float_suffix,
    bool_type,
    unsigned_type,
    sizeof_type,
    limit_max,
    limit_min,
    conditional_guard,
    close_conditional_guard,
    clc_core_fn_name,
)

# Plain "char" is removed, in place, from the type lists imported from
# gen_convert_common, so no overloads are generated for it by this script.
# NOTE(review): signed char is presumably still covered through "schar", which
# spirv_fn_name spells as "char" in the emitted builtin name — confirm.
for _char_list in (types, int_types, signed_types):
    _char_list.remove("char")
del _char_list

# Prepend the empty mode so the variant without a rounding-mode suffix is
# generated alongside the explicitly rounded ones.
rounding_modes = ["", *rounding_modes]

# Emit the preamble of the generated OpenCL C file: the "autogenerated" banner
# and license text, the clc core / libspirv includes, and extension pragmas that
# are only enabled when the cl_khr_fp16, cl_khr_fp64 and cles_khr_int64 macros
# are defined. The conversion wrappers printed below follow this preamble on
# stdout.
# NOTE(review): the banner names convert_type.py and
# generate-conversion-type-cl.sh; confirm these still match this script's
# actual filename and the driver script.
print(
    """/* !!!! AUTOGENERATED FILE generated by convert_type.py !!!!!

   DON'T CHANGE THIS FILE. MAKE YOUR CHANGES TO convert_type.py AND RUN:
   $ ./generate-conversion-type-cl.sh

   SPIR-V type conversion functions

   ===----------------------------------------------------------------------===

   Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
   See https://llvm.org/LICENSE.txt for license information.
   SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

   ===----------------------------------------------------------------------===
*/

#include <core/clc_core.h>
#include <libspirv/spirv.h>

#ifdef cl_khr_fp16
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#endif

#ifdef cl_khr_fp64
#pragma OPENCL EXTENSION cl_khr_fp64 : enable

#if defined(__EMBEDDED_PROFILE__) && !defined(cles_khr_int64)
#error Embedded profile that supports cl_khr_fp64 also has to support cles_khr_int64
#endif

#endif

#ifdef cles_khr_int64
#pragma OPENCL EXTENSION cles_khr_int64 : enable
#endif

"""
)


def spirv_fn_name(src, dst, size="", mode="", sat="", force_sat_decoration=False):
    """
    Return the SPIR-V conversion builtin name for a ``src`` -> ``dst`` conversion.

    Args:
        src: Source scalar type name (classified via ``float_types``,
            ``signed_types`` and ``unsigned_types``).
        dst: Destination scalar type name. "schar" is spelled "char" in the
            returned name.
        size: Vector-size suffix appended after the destination type
            ("" by default, i.e. the scalar form).
        mode: Rounding-mode suffix ("" for none). It only appears in names of
            conversions that involve a floating-point type.
        sat: Saturation suffix ("" when not saturating).
        force_sat_decoration: For saturated signed<->unsigned integer
            conversions, select the decorated spelling (see below).

    For saturated signed<->unsigned integer conversions, two spellings can
    co-exist: SatConvertUToS/SatConvertSToU, and UConvert/SConvert with the
    saturation suffix appended. The SatConvert* form is returned by default.
    ``force_sat_decoration=True`` returns the *Convert + saturation-suffix form.

    Raises:
        ValueError: if the (src, dst) pair matches no SPIR-V conversion.
    """
    is_src_float = src in float_types
    is_src_unsigned = src in unsigned_types
    is_src_signed = src in signed_types
    is_dst_float = dst in float_types
    is_dst_unsigned = dst in unsigned_types
    is_dst_signed = dst in signed_types
    # Use the dedicated SatConvert* instruction unless the caller explicitly
    # asked for the "*Convert + sat suffix" spelling.
    use_sat_insn = sat != "" and not force_sat_decoration

    # The emitted name spells signed char as "char". The signedness flags above
    # were already computed from the original "schar".
    if dst == "schar":
        dst = "char"

    if is_src_unsigned and is_dst_signed and use_sat_insn:
        return "__spirv_SatConvertUToS_R{DST}{N}".format(DST=dst, N=size)
    elif is_src_signed and is_dst_unsigned and use_sat_insn:
        return "__spirv_SatConvertSToU_R{DST}{N}".format(DST=dst, N=size)
    elif is_src_float and is_dst_signed:
        return "__spirv_ConvertFToS_R{DST}{N}{SAT}{MODE}".format(
            DST=dst, N=size, SAT=sat, MODE=mode
        )
    elif is_src_float and is_dst_unsigned:
        return "__spirv_ConvertFToU_R{DST}{N}{SAT}{MODE}".format(
            DST=dst, N=size, SAT=sat, MODE=mode
        )
    elif is_src_signed and is_dst_float:
        return "__spirv_ConvertSToF_R{DST}{N}{MODE}".format(DST=dst, N=size, MODE=mode)
    elif is_src_unsigned and is_dst_float:
        return "__spirv_ConvertUToF_R{DST}{N}{MODE}".format(DST=dst, N=size, MODE=mode)
    elif is_src_float and is_dst_float:
        return "__spirv_FConvert_R{DST}{N}{MODE}".format(DST=dst, N=size, MODE=mode)
    elif is_dst_unsigned:
        return "__spirv_UConvert_R{DST}{N}{SAT}".format(DST=dst, N=size, SAT=sat)
    elif is_dst_signed:
        return "__spirv_SConvert_R{DST}{N}{SAT}".format(DST=dst, N=size, SAT=sat)

    # Raise explicitly rather than `assert False`. Asserts are stripped under
    # `python -O`, which would make this return None and write "None" into the
    # generated source.
    raise ValueError(
        "Unhandled param set: {}, {}, {}, {}, {}".format(src, dst, size, mode, sat)
    )


def is_same_size(src, dst):
    """Return True when ``sizeof_type`` records the same size for ``src`` and ``dst``."""
    src_size = sizeof_type[src]
    dst_size = sizeof_type[dst]
    return src_size == dst_size


def is_signed_unsigned_conversion(src, dst):
    """
    Return True when the conversion crosses between ``unsigned_types`` and
    ``signed_types``, in either direction.
    """
    unsigned_to_signed = src in unsigned_types and dst in signed_types
    signed_to_unsigned = src in signed_types and dst in unsigned_types
    return unsigned_to_signed or signed_to_unsigned


def generate_spirv_fn_impl(src, dst, size="", mode="", sat="", force_decoration=False):
    """
    Print one OpenCL C overload of a SPIR-V conversion builtin that simply
    returns the matching clc core conversion of its argument.

    ``conditional_guard(src, dst)`` is called before the definition is printed.
    Its return value is handed to ``close_conditional_guard`` afterwards.
    NOTE(review): presumably these open/close a preprocessor guard for optional
    types (fp16/fp64/int64) — confirm in gen_convert_common.

    ``force_decoration`` is forwarded as ``force_sat_decoration`` to
    ``spirv_fn_name`` to pick the saturated-name spelling.
    """
    guard = conditional_guard(src, dst)

    builtin = spirv_fn_name(
        src,
        dst,
        size=size,
        sat=sat,
        mode=mode,
        force_sat_decoration=force_decoration,
    )
    core = clc_core_fn_name(dst, size=size, sat=sat, mode=mode)

    # Trailing "" yields a final newline; print() then adds a blank line
    # between consecutive definitions.
    definition = "\n".join(
        [
            "_CLC_DEF _CLC_OVERLOAD _CLC_CONSTFN",
            f"{dst}{size} {builtin}({src}{size} x)",
            "{",
            f"  return {core}(x);",
            "}",
            "",
        ]
    )
    print(definition)

    close_conditional_guard(guard)


def generate_spirv_fn(src, dst, size="", mode="", sat=""):
    """
    Print the SPIR-V wrapper(s) for one conversion.

    The default spelling is always printed. For a saturating conversion that
    crosses signed/unsigned integer types of different sizes, the alternative
    UConvert/SConvert + saturation-suffix spelling is printed as well.
    """
    generate_spirv_fn_impl(
        src, dst, size=size, mode=mode, sat=sat, force_decoration=False
    )

    # The decorated alias exists only for saturated signed<->unsigned
    # conversions whose component sizes differ.
    needs_decorated_alias = (
        sat != ""
        and is_signed_unsigned_conversion(src, dst)
        and not is_same_size(src, dst)
    )
    if not needs_decorated_alias:
        return

    generate_spirv_fn_impl(
        src, dst, size=size, mode=mode, sat=sat, force_decoration=True
    )


# Each sweep below walks the Cartesian product of its parameter lists. The
# rightmost list varies fastest, matching a nested-loop ordering.

# Float -> integer: __spirv_ConvertFToU / __spirv_ConvertFToS, with
# saturation and rounding-mode variants.
for src, dst, size, mode, sat in itertools.product(
    float_types, int_types, vector_sizes, rounding_modes, saturation
):
    generate_spirv_fn(src, dst, size, mode, sat)

# Integer -> float: __spirv_ConvertUToF / __spirv_ConvertSToF, with
# rounding-mode variants.
for src, dst, size, mode in itertools.product(
    int_types, float_types, vector_sizes, rounding_modes
):
    generate_spirv_fn(src, dst, size, mode)

# Float -> float: __spirv_FConvert, with rounding-mode variants.
for src, dst, size, mode in itertools.product(
    float_types, float_types, vector_sizes, rounding_modes
):
    generate_spirv_fn(src, dst, size, mode)

# Integer -> unsigned integer: __spirv_UConvert, with saturation variants.
for src, dst, size, sat in itertools.product(
    int_types, unsigned_types, vector_sizes, saturation
):
    generate_spirv_fn(src, dst, size, sat=sat)

# Integer -> signed integer: __spirv_SConvert, with saturation variants.
for src, dst, size, sat in itertools.product(
    int_types, signed_types, vector_sizes, saturation
):
    generate_spirv_fn(src, dst, size, sat=sat)
