// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/miopen_common.h"
#include "fast_gelu.h"
#include "core/providers/rocm/tensor/gelu_impl.h"
#include "contrib_ops/cpu/bert/bias_gelu_helper.h"
#ifdef USE_ROCM
#include "contrib_ops/rocm/bert/elementwise.h"
#else
#include "contrib_ops/rocm/bert/transformer_common.h"
#endif

namespace onnxruntime {
namespace contrib {
namespace rocm {

// Registers FastGelu<T> as the kernel implementing the `FastGelu` operator for element type T.
//
// The expansion forwards to ONNX_OPERATOR_TYPED_KERNEL_EX with:
//   - operator name FastGelu in domain kMSDomain, version argument 1,
//   - execution provider kRocmExecutionProvider,
//   - a kernel definition that constrains type parameter "T" to tensors of element type T,
//   - FastGelu<T> as the kernel class.
//
// The macro body already ends with a semicolon, so invocations are written without one.
#define REGISTER_KERNEL_TYPED(T)                                  \
  ONNX_OPERATOR_TYPED_KERNEL_EX(                                  \
      FastGelu,                                                   \
      kMSDomain,                                                  \
      1,                                                          \
      T,                                                          \
      kRocmExecutionProvider,                                     \
      (*KernelDefBuilder::Create())                               \
          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
      FastGelu<T>);

// One registration per supported element type: float, MLFloat16, BFloat16 and double.
REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
REGISTER_KERNEL_TYPED(BFloat16)
REGISTER_KERNEL_TYPED(double)

// Makes names from ONNX_NAMESPACE visible inside onnxruntime::contrib::rocm.
// NOTE(review): the definitions below do not appear to use any ONNX_NAMESPACE name unqualified —
// confirm whether this directive is still required.
using namespace ONNX_NAMESPACE;

// Constructs the kernel by forwarding `op_kernel_info` to the RocmKernel base.
//
// When USE_ROCM is not defined, `use_half2_` is additionally initialised to the negation of
// TransformerOptions::DisableHalf2() read from the TransformerOptions singleton.
template <typename T>
FastGelu<T>::FastGelu(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) {
#if !defined(USE_ROCM)
  use_half2_ = !TransformerOptions::GetInstance()->DisableHalf2();
#endif
}

// Applies FastGelu to input 0, using input 1 as a bias when it is present, and writes the result
// to output 0, which is requested with the same shape as input 0.
//
// Returns:
//   - the error from bias_gelu_helper::CheckInputs if input validation fails,
//   - Status::OK() without launching anything when input 0 has zero elements,
//   - otherwise the status returned by the kernel launch.
//
// With USE_ROCM defined the work goes through LaunchElementwiseKernel with functor::FastGeLU;
// otherwise it goes through LaunchFastGeluKernel, which also receives `use_half2_`.
//
// NOTE(review): both element counts are narrowed from int64_t to int before the launch. Nothing
// visible here rejects tensors with more than INT_MAX elements — confirm whether an upstream
// check guarantees the counts fit.
template <typename T>
Status FastGelu<T>::ComputeInternal(OpKernelContext* context) const {
  using HipT = typename ToHipType<T>::MappedType;

  ORT_RETURN_IF_ERROR(bias_gelu_helper::CheckInputs(context));

  const Tensor* input = context->Input<Tensor>(0);
  const Tensor* bias = context->Input<Tensor>(1);
  // The output is requested before the empty-input check, so it is produced even when no kernel runs.
  Tensor* output = context->Output(0, input->Shape());

  const int64_t input_length = input->Shape().Size();
  if (input_length == 0) {
    return Status::OK();
  }

  // The bias is optional: a missing bias is passed on as a null pointer with length 0.
  const bool has_bias = (bias != nullptr);
  const int64_t bias_length = has_bias ? bias->Shape().Size() : 0;

  const int input_count = static_cast<int>(input_length);
  const int bias_count = static_cast<int>(bias_length);

  // Reinterpret the tensor buffers as the HIP-side element type mapped from T.
  const HipT* input_data = reinterpret_cast<const HipT*>(input->Data<T>());
  const HipT* bias_data = has_bias ? reinterpret_cast<const HipT*>(bias->Data<T>()) : nullptr;
  HipT* output_data = reinterpret_cast<HipT*>(output->MutableData<T>());

#ifdef USE_ROCM
  return LaunchElementwiseKernel<functor::FastGeLU, HipT>(
      GetTuningContext(), context->GetComputeStream(),
      input_data, input_count,
      bias_data, bias_count,
      output_data);
#else
  return LaunchFastGeluKernel<HipT>(GetDeviceProp(),
                                     Stream(context),
                                     input_count,
                                     bias_count,
                                     input_data,
                                     bias_data,
                                     output_data,
                                     use_half2_);
#endif
}

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
