// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <thread>
#include <utility>

#include "core/common/safeint.h"
#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cuda/bert/transformer_cuda_common.h"
#include "sharded_moe.h"

using namespace onnxruntime::cuda;
using namespace ::onnxruntime::common;
using namespace ONNX_NAMESPACE;

namespace onnxruntime {
namespace contrib {
namespace cuda {

#if defined(ORT_USE_NCCL)

// Records a failing CUDA status in the caller-provided `cuda_result` variable and returns from the
// enclosing void function. The argument is evaluated exactly once, and the do/while(0) wrapper makes
// the macro behave like a single statement (safe inside an unbraced if/else). Use as `CHECK_CUDA(expr);`.
#define CHECK_CUDA(res)                      \
  do {                                       \
    const auto check_cuda_status_ = (res);   \
    if (check_cuda_status_ != cudaSuccess) { \
      cuda_result = check_cuda_status_;      \
      return;                                \
    }                                        \
  } while (0)

// Records a failing NCCL status in the caller-provided `nccl_result` variable and returns from the
// enclosing void function. The argument is evaluated exactly once, and the do/while(0) wrapper makes
// the macro behave like a single statement (safe inside an unbraced if/else). Use as `CHECK_NCCL(expr);`.
#define CHECK_NCCL(res)                      \
  do {                                       \
    const auto check_nccl_status_ = (res);   \
    if (check_nccl_status_ != ncclSuccess) { \
      nccl_result = check_nccl_status_;      \
      return;                                \
    }                                        \
  } while (0)

// Registers ShardedMoE<T> as the implementation of the "ShardedMoE" operator (domain kMSDomain,
// version 1) for kCudaExecutionProvider, with type constraint "T" bound to tensors of element type T.
// The kernel definition also declares MayInplace(0, 0).
// NOTE(review): MayInplace(0, 0) presumably allows output 0 to reuse input 0's buffer — confirm
// against the KernelDefBuilder documentation.
#define REGISTER_KERNEL_TYPED(T)                                                                            \
  ONNX_OPERATOR_TYPED_KERNEL_EX(                                                                            \
      ShardedMoE, kMSDomain, 1, T, kCudaExecutionProvider,                                                  \
      (*KernelDefBuilder::Create()).MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
      ShardedMoE<T>);

// Element types for which the kernel is registered.
REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)

// Constructs the kernel: reads the required "tensor_shards" and "local_experts_start_index"
// attributes, sizes rank_to_experts_start_index_ to one entry per NCCL rank, and exchanges the
// per-rank expert start indices via SynchronizeExpertsStartIndex().
//
// Any failure throws through ORT_ENFORCE; the underlying Status message is included so that
// attribute, CUDA and NCCL errors are not reduced to a bare "condition was false".
template <typename T>
ShardedMoE<T>::ShardedMoE(const OpKernelInfo& op_kernel_info) : NcclKernel(op_kernel_info), MoEBase(op_kernel_info) {
  const Status tensor_shards_status = op_kernel_info.GetAttr<int64_t>("tensor_shards", &tensor_shards_);
  ORT_ENFORCE(tensor_shards_status.IsOK(),
              "Failed to read attribute 'tensor_shards': ", tensor_shards_status.ErrorMessage());

  const Status start_index_status =
      op_kernel_info.GetAttr<int64_t>("local_experts_start_index", &local_experts_start_index_);
  ORT_ENFORCE(start_index_status.IsOK(),
              "Failed to read attribute 'local_experts_start_index': ", start_index_status.ErrorMessage());

  // One slot per rank; SynchronizeExpertsStartIndex() copies exactly nccl_->Size() entries into it.
  rank_to_experts_start_index_.resize(nccl_->Size());

  auto allocator = op_kernel_info.GetAllocator(OrtMemTypeDefault);
  const Status sync_status = SynchronizeExpertsStartIndex(allocator);
  ORT_ENFORCE(sync_status.IsOK(),
              "Failed to synchronize expert start indices across ranks: ", sync_status.ErrorMessage());
}

// Runs the sharded mixture-of-experts computation for one invocation of the operator.
//
// Inputs (by index): 0 input, 1 router_probs, 2 fc1 weights, 3 fc1 bias (optional), 4 fc2 weights,
// 5 fc2 bias (optional), 6 fc3 weights (optional), 7 fc3 bias (optional).
// Output 0 has the same shape as input 0.
//
// Steps:
//   1. Validate the inputs and the parallel configuration. Unsupported or inconsistent
//      configurations are rejected before any scratch memory is allocated or any kernel is launched.
//   2. Run the expert FC layers through CutlassMoeFCRunner into `fc2_output`.
//   3. Depending on moe_params.parallel_type, combine the per-rank results into `fc2_output_bc`
//      (custom all-reduce for TP, per-rank ncclBroadcast for EP); with no parallelism `fc2_output`
//      is consumed directly.
//   4. Launch finalize_moe_routing_kernelLauncher to produce the output tensor.
//
// Returns Status::OK() on success, or the first failing Status.
template <typename T>
Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {
  using CudaT = typename ToCudaType<T>::MappedType;
  auto stream = context->GetComputeStream();

  const auto& device_prop = GetDeviceProp();
  const int sm = device_prop.major * 10 + device_prop.minor;

  AllocatorPtr allocator;
  ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));

  const Tensor* input = context->Input<Tensor>(0);
  const Tensor* router_probs = context->Input<Tensor>(1);
  const Tensor* fc1_experts_weights = context->Input<Tensor>(2);
  const Tensor* fc1_experts_bias_optional = context->Input<Tensor>(3);
  const Tensor* fc2_experts_weights = context->Input<Tensor>(4);
  const Tensor* fc2_experts_bias_optional = context->Input<Tensor>(5);
  const Tensor* fc3_experts_weights_optional = context->Input<Tensor>(6);
  const Tensor* fc3_experts_bias_optional = context->Input<Tensor>(7);

  MoEParameters moe_params(tensor_shards_);
  ORT_RETURN_IF_ERROR(::onnxruntime::contrib::moe_helper::CheckInputs<Tensor>(
      moe_params, input, router_probs,
      fc1_experts_weights, fc1_experts_bias_optional, nullptr, nullptr,
      fc2_experts_weights, fc2_experts_bias_optional, nullptr, nullptr,
      fc3_experts_weights_optional, fc3_experts_bias_optional, nullptr, nullptr,
      1,  // no quantization so pack size is 1
      activation_type_ == ort_fastertransformer::ActivationType::SwiGLU,
      0));  // no block-wise quantization for sharded MoE

  ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0, "num_experts should be divisible by world_size");

  // Reject configurations that cannot be completed before spending any GPU memory or compute on them.
  if (moe_params.parallel_type == MoEParallelType::EPAndTP) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expert and Tensor Parallelism is not supported yet");
  }
  if (moe_params.parallel_type == MoEParallelType::TP) {
    ORT_RETURN_IF_NOT(moe_params.tensor_shards == nccl_->Size(),
                      "tensor_shards should be equal to world_size for tensor parallelism");
  }

  ort_fastertransformer::CutlassMoeFCRunner<CudaT, CudaT> moe_runner(sm,
                                                                     activation_type_,
                                                                     fc3_experts_weights_optional != nullptr,
                                                                     normalize_routing_weights_,
                                                                     use_sparse_mixer_);

  size_t ws_size = moe_runner.getWorkspaceSize(
      static_cast<size_t>(moe_params.num_rows), static_cast<size_t>(moe_params.hidden_size),
      static_cast<size_t>(moe_params.inter_size), static_cast<size_t>(moe_params.num_experts), static_cast<size_t>(k_));

  // All per-row scratch buffers hold k_ entries per input row.
  size_t fc2_output_size = k_ * moe_params.num_rows * moe_params.hidden_size * sizeof(CudaT);
  size_t expert_scales_size = k_ * moe_params.num_rows * sizeof(CudaT);
  size_t expanded_source_row_to_expanded_dest_row_size = k_ * moe_params.num_rows * sizeof(int);
  size_t expert_for_source_row_size = k_ * moe_params.num_rows * sizeof(int);

  // TODO: allocate one buffer and reuse it.
  IAllocatorUniquePtr<void> work_space = IAllocator::MakeUniquePtr<void>(allocator, ws_size, false, stream);
  IAllocatorUniquePtr<void> fc2_output = IAllocator::MakeUniquePtr<void>(allocator, fc2_output_size, false, stream);
  // A separate destination buffer is only needed when a collective combines results across ranks;
  // without parallelism the finalize kernel reads fc2_output directly.
  IAllocatorUniquePtr<void> fc2_output_bc;
  if (moe_params.parallel_type != MoEParallelType::None) {
    fc2_output_bc = IAllocator::MakeUniquePtr<void>(allocator, fc2_output_size, false, stream);
  }
  IAllocatorUniquePtr<void> expert_scales =
      IAllocator::MakeUniquePtr<void>(allocator, expert_scales_size, false, stream);
  IAllocatorUniquePtr<void> expanded_source_row_to_expanded_dest_row =
      IAllocator::MakeUniquePtr<void>(allocator, expanded_source_row_to_expanded_dest_row_size, false, stream);
  IAllocatorUniquePtr<void> expert_for_source_row =
      IAllocator::MakeUniquePtr<void>(allocator, expert_for_source_row_size, false, stream);

  // Maps an optional tensor to its device data pointer, or nullptr when the input is absent.
  const auto data_or_null = [](const Tensor* tensor) -> const CudaT* {
    return tensor == nullptr ? nullptr : reinterpret_cast<const CudaT*>(tensor->template Data<T>());
  };

  // No quantization scales are used by this kernel, so every scale argument is null.
  const CudaT* fc_scales_ptr = nullptr;

  moe_runner.run_moe_fc(
      reinterpret_cast<const CudaT*>(input->template Data<T>()),
      reinterpret_cast<const CudaT*>(router_probs->template Data<T>()),
      reinterpret_cast<const CudaT*>(fc1_experts_weights->template Data<T>()), fc_scales_ptr,
      data_or_null(fc1_experts_bias_optional),
      activation_type_,
      data_or_null(fc3_experts_weights_optional),
      fc_scales_ptr,
      data_or_null(fc3_experts_bias_optional),
      reinterpret_cast<const CudaT*>(fc2_experts_weights->template Data<T>()), fc_scales_ptr,
      static_cast<int>(moe_params.num_rows), static_cast<int>(moe_params.hidden_size),
      static_cast<int>(moe_params.inter_size), static_cast<int>(moe_params.num_experts),
      static_cast<int>(moe_params.local_num_experts), static_cast<int>(local_experts_start_index_),
      static_cast<int>(k_), reinterpret_cast<char*>(work_space.get()), reinterpret_cast<CudaT*>(fc2_output.get()),
      reinterpret_cast<CudaT*>(expert_scales.get()),
      reinterpret_cast<int*>(expanded_source_row_to_expanded_dest_row.get()),
      reinterpret_cast<int*>(expert_for_source_row.get()), Stream(context));

  Tensor* output = context->Output(0, input->Shape());

  if (moe_params.parallel_type == MoEParallelType::TP) {
    // tensor_shards == world size was verified above, before any work was enqueued.
    ORT_RETURN_IF_ERROR(FuncCustomAllReduce(nccl_,
                                            Stream(context),
                                            fc2_output.get(),
                                            fc2_output_bc.get(),
                                            static_cast<int64_t>(fc2_output_size / sizeof(CudaT)),
                                            input->DataType(),
                                            collective::IPCMemoryResourcePack::GetGlobalInstance()));
  }

  if (moe_params.parallel_type == MoEParallelType::EP) {
    size_t stride_count = moe_params.hidden_size;
    size_t stride_bytes = stride_count * sizeof(CudaT);
    int64_t total_past_rows = 0;
    int64_t total_covered_rows = 0;

    // NOTE(review): an early return from inside the group (a failing ncclBroadcast) leaves the
    // group opened by ncclGroupStart() without a matching ncclGroupEnd() — confirm whether the
    // communicator needs explicit cleanup in that case.
    NCCL_RETURN_IF_ERROR(ncclGroupStart());
    for (int rank = 0; rank < nccl_->Size(); ++rank) {
      // Each rank is the broadcast root for the row range reported for its own experts.
      int64_t experts_start_index = rank_to_experts_start_index_[rank];
      moe_runner.get_total_rows_info(experts_start_index, moe_params.local_num_experts, total_past_rows,
                                     total_covered_rows);
      const char* src = reinterpret_cast<const char*>(fc2_output.get()) + total_past_rows * stride_bytes;
      char* dst = reinterpret_cast<char*>(fc2_output_bc.get()) + total_past_rows * stride_bytes;
      NCCL_RETURN_IF_ERROR(ncclBroadcast(src, dst, total_covered_rows * stride_count,
                                         GetNcclDataType(input->DataType()), rank, nccl_->Comm(), Stream(context)));
    }
    NCCL_RETURN_IF_ERROR(ncclGroupEnd());
  }

  // Use the collective's destination buffer when one was allocated, otherwise the local FC2 result.
  void* routed_output = fc2_output_bc ? fc2_output_bc.get() : fc2_output.get();

  ort_fastertransformer::finalize_moe_routing_kernelLauncher(
      reinterpret_cast<CudaT*>(routed_output), reinterpret_cast<CudaT*>(output->template MutableData<T>()),
      data_or_null(fc2_experts_bias_optional),
      reinterpret_cast<CudaT*>(expert_scales.get()),
      reinterpret_cast<int*>(expanded_source_row_to_expanded_dest_row.get()),
      reinterpret_cast<int*>(expert_for_source_row.get()), static_cast<int>(moe_params.num_rows),
      static_cast<int>(moe_params.hidden_size), static_cast<int>(k_), Stream(context));

  return Status::OK();
}

// Exchanges local_experts_start_index_ between all NCCL ranks and stores the gathered values in
// rank_to_experts_start_index_ (entry i receives the value contributed by rank i).
//
// The value is staged through device buffers obtained from `allocator`: copied host-to-device,
// all-gathered with ncclAllGather (stream argument nullptr), then copied back to the host.
//
// Precondition: rank_to_experts_start_index_ already holds exactly nccl_->Size() elements; this is
// checked because the final copy writes that many elements into the vector's storage.
//
// Returns Status::OK() on success, or the failing precondition / CUDA / NCCL status.
template <typename T>
Status ShardedMoE<T>::SynchronizeExpertsStartIndex(AllocatorPtr& allocator) const {
  using IndexType = int64_t;
  constexpr size_t kIndexTypeSize = sizeof(IndexType);

  // Guard the raw device-to-host copy below against overrunning the host vector.
  ORT_RETURN_IF_NOT(rank_to_experts_start_index_.size() == static_cast<size_t>(nccl_->Size()),
                    "rank_to_experts_start_index_ must have one entry per rank");

  IAllocatorUniquePtr<IndexType> experts_start_index_d =
      IAllocator::MakeUniquePtr<IndexType>(allocator, 1, false);
  IAllocatorUniquePtr<IndexType> rank_to_experts_start_index_d =
      IAllocator::MakeUniquePtr<IndexType>(allocator, nccl_->Size(), false);

  CUDA_RETURN_IF_ERROR(cudaMemcpy(experts_start_index_d.get(), &local_experts_start_index_, kIndexTypeSize,
                                  cudaMemcpyHostToDevice));
  // Every rank contributes one element and receives one element from each rank.
  NCCL_RETURN_IF_ERROR(ncclAllGather(reinterpret_cast<const char*>(experts_start_index_d.get()),
                                     reinterpret_cast<char*>(rank_to_experts_start_index_d.get()), 1,
                                     GetNcclDataType(DataTypeImpl::GetType<IndexType>()), nccl_->Comm(),
                                     nullptr));

  // The method is const, so the member vector's storage is written through a const_cast.
  CUDA_RETURN_IF_ERROR(cudaMemcpy(const_cast<int64_t*>(rank_to_experts_start_index_.data()),
                                  rank_to_experts_start_index_d.get(), nccl_->Size() * kIndexTypeSize,
                                  cudaMemcpyDeviceToHost));

  return Status::OK();
}
#endif

}  // namespace cuda
}  // namespace contrib
}  // namespace onnxruntime
