// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/smoothquant.hpp"
#include <string>

// Bundle of element types for one (InputType, OutputType) pairing.
// Only XDataType and QYDataType follow the template arguments; the scale and
// compute types are float for every pairing.
template <typename InputType, typename OutputType>
struct MoeSmoothquantTypeConfig
{
    // Taken from the template arguments.
    using XDataType  = InputType;
    using QYDataType = OutputType;

    // Fixed to float regardless of the template arguments.
    using SmoothScaleDataType = float;
    using YScaleDataType      = float;
    using ComputeDataType     = float;
};

// Runtime arguments passed to moe_smoothquant() / moe_smoothquant_<>().
// Adds no members of its own: every field comes from the base class
// ck_tile::MoeSmoothquantHostArgs, which is defined outside this file.
struct moe_smoothquant_args : public ck_tile::MoeSmoothquantHostArgs
{
};

// Compile-time description of one kernel configuration.
// This is used to pattern-match the internal kernel implementation, not to instantiate the
// kernel. From the per-thread repeat counts, the thread counts and the vector size it derives
// the block tile, the warp layout and the warp tile, and packs them into `Shape`.
template <typename InputType_,
          typename OutputType_,
          ck_tile::index_t Repeat_M_,         // per-thread repeat count along M
          ck_tile::index_t Repeat_N_,         // per-thread repeat count along N
          ck_tile::index_t ThreadPerBlock_M_, // number of threads along M
          ck_tile::index_t ThreadPerBlock_N_, // number of threads along N
          ck_tile::index_t Vector_N_,         // vector size along N
          bool kPadN_,
          bool kTwoPass_>
struct moe_smoothquant_traits_
{
    using InputType  = ck_tile::remove_cvref_t<InputType_>;
    using OutputType = ck_tile::remove_cvref_t<OutputType_>;

    // Values forwarded unchanged from the template parameters.
    static constexpr bool kPadN                = kPadN_;
    static constexpr bool kTwoPass             = kTwoPass_;
    static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
    static constexpr ck_tile::index_t Repeat_N = Repeat_N_;

    // True when the threads along N fit inside a single warp.
    static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();

    // The thread count of the block must be a whole number of warps.
    static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
    static constexpr ck_tile::index_t total_warps =
        (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();

    // Warp count along M.
    static constexpr ck_tile::index_t BlockWarps_M = []() {
        if constexpr(!is_warp_per_row)
        {
            // Several warps span one row along N; the remaining factor of total_warps goes to M.
            // (That ThreadPerBlock_N_ is a multiple of the warp size is asserted in
            // BlockWarps_N.)
            constexpr auto warps_per_row = ThreadPerBlock_N_ / ck_tile::get_warp_size();
            return total_warps / warps_per_row;
        }
        else
        {
            // ThreadPerBlock_N_ must divide the warp size evenly.
            static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
            constexpr auto rows_per_warp = ck_tile::get_warp_size() / ThreadPerBlock_N_;
            return total_warps * rows_per_warp;
        }
    }();

    // Warp count along N.
    static constexpr ck_tile::index_t BlockWarps_N = []() {
        if constexpr(!is_warp_per_row)
        {
            // ThreadPerBlock_N_ must be a whole number of warps.
            static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
            constexpr auto warps_per_row = ThreadPerBlock_N_ / ck_tile::get_warp_size();
            return warps_per_row;
        }
        else
        {
            static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
            return 1;
        }
    }();

    // Block tile extent: repeat count times thread count (times vector size along N).
    static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
    static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;

    // Warp tile extent: the per-warp share of the threads (times vector size along N).
    static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
    static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;

    // <M, N> pairs consumed by Generic2dBlockShape.
    using BlockTile  = ck_tile::sequence<Block_M, Block_N>;
    using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
    using WarpTile   = ck_tile::sequence<Warp_M, Warp_N>;
    using Vector     = ck_tile::sequence<1, Vector_N_>;

    using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
};

// Per-configuration entry point, selected at compile time through Traits_.
// Only declared here; the definition is not part of this header.
//   s: stream configuration
//   a: runtime arguments, taken by value
// NOTE(review): Traits_ is presumably an instantiation of moe_smoothquant_traits_<...>, and the
// returned float presumably the measured kernel time — confirm against the definition.
template <typename Traits_>
float moe_smoothquant_(const ck_tile::stream_config& s, moe_smoothquant_args a);

// This is the public API; it will be generated by script.
// Runtime selection of the data types, given as strings.
// NOTE(review): the set of accepted type-name strings is not defined in this header — check the
// implementation of moe_smoothquant().
struct moe_smoothquant_traits
{
    std::string in_type;  // input type
    std::string out_type; // output type
};

// Public entry point. Takes the runtime type selection, the runtime arguments and the stream
// configuration, in that order. Only declared here; the definition is not part of this header.
// NOTE(review): the returned float is presumably the measured kernel time — confirm against the
// definition.
float moe_smoothquant(moe_smoothquant_traits, moe_smoothquant_args, const ck_tile::stream_config&);
