// Copyright California Institute of Technology 2025
//
// simple-frames is free software; you may redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 3 (GPLv3) of the
// License or at your discretion, any later version.
//
// simple-frames is distributed in the hope that it will be useful, but
// without any warranty or even the implied warranty of merchantability
// or fitness for a particular purpose. See the GNU General Public
// License (GPLv3) for more details.
//
// Neither the names of the California Institute of Technology (Caltech),
// The Massachusetts Institute of Technology (M.I.T), The Laser
// Interferometer Gravitational-Wave Observatory (LIGO), nor the names
// of its contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// You should have received a copy of the licensing terms for this
// software included in the file COPYING-GPL-3 located in the top-level
// directory of this package. If you did not, you can view a copy at
// http://dcc.ligo.org/M1500244/LICENSE
//
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <framecpp/FrameCPP.hh>
#include <framecpp/IFrameStream.hh>

#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace
{
    /// Convert @p len consecutive elements starting at @p input to doubles.
    ///
    /// Each element goes through an ordinary implicit numeric conversion, so
    /// 64-bit integers with magnitude above 2^53 may lose precision.
    /// @p input may be null only when @p len is 0.
    template<typename T>
    std::vector<double>
    to_double_vec(const T* input, std::size_t len)
    {
        // The iterator-range constructor allocates once (pointer distance is
        // known) and converts element-wise; no back_inserter / <iterator>.
        return std::vector<double>(input, input + len);
    }

    /// Start time and duration of one frame, as integers.
    /// SimpleFrame fills gps with the whole-second GPS start time and dt with
    /// the frame duration cast to an integer.
    struct FrameTimes
    {
        std::uint64_t gps{0};
        std::uint64_t dt{0};
    };
}

/// Element types, exposed to Python as simple_frames.DataType.
///
/// The names appear to mirror the FR_VECT_* element types that
/// SimpleFrame accepts when reading a channel (noted beside each value).
/// Values are spelled out but equal the implicit 0..5 numbering.
enum DataType
{
    Int16   = 0,  // FR_VECT_2S
    Int32   = 1,  // FR_VECT_4S
    Int64   = 2,  // FR_VECT_8S
    Float32 = 3,  // FR_VECT_4R
    Float64 = 4,  // FR_VECT_8R
    UInt32  = 5   // FR_VECT_4U
};

/// Samples of a single channel from a single frame, stored as double.
///
/// The GPS start and duration are stored exactly as passed to the
/// constructor; SimpleFrame passes the frame's whole-second GPS start time and
/// its duration cast to an integer.
class ChannelData
{
public:
    /// @param name      channel name
    /// @param gps       GPS start time
    /// @param duration  length of the data span (same units as data_rate() divides by)
    /// @param data      samples, already converted to double
    ChannelData(std::string name, std::uint64_t gps, std::uint64_t duration, std::vector<double> data):
    name_{std::move(name)}, gps_{gps}, duration_{duration}, data_{std::move(data)}
    {}

    /// Channel name (by reference: no copy per call).
    const std::string& name() const { return name_; }
    /// GPS start time.
    std::uint64_t gps_start() const { return gps_; }
    /// Duration of the data span.
    std::uint64_t duration() const { return duration_; }
    /// gps_start() + duration().
    std::uint64_t gps_end() const { return gps_ + duration_; }
    /// Always Float64: the samples are held as double.
    DataType data_type() const { return DataType::Float64; }

    /// Number of samples per unit of duration(), i.e. data().size() / duration().
    /// Returns 0.0 when there are no samples or the duration is 0, which also
    /// avoids a division by zero.
    double data_rate() const
    {
        if (data_.empty() || duration_ == 0)
        {
            return 0.0;
        }
        return static_cast<double>(data_.size())/static_cast<double>(duration_);
    }

    /// The samples. Returned by const reference so that reading them (for
    /// example through the Python "data" property, which builds its own list)
    /// does not first copy the whole vector.
    const std::vector<double>&
    data() const
    {
        return data_;
    }
private:
    std::string name_;
    std::uint64_t gps_;
    std::uint64_t duration_;
    std::vector<double> data_;
};

class SimpleFrame
{
public:
    explicit SimpleFrame(const std::string& filename): read_stream_(filename.c_str()) {};

    std::size_t
    num_frames() const
    {
        return read_stream_.GetNumberOfFrames(  );
    }

    std::size_t
    num_channels()
    {
        std::size_t count = 0;
        auto frame = load_frame_n(0);
        auto raw = frame->GetRawData();
        auto& proc_ref = frame->RefProcData();

        count = proc_ref.size();
        if (raw)
        {
            count += raw->RefFirstAdc(  ).size();
        }
        return count;
    }

    std::vector<std::string>
    channel_names()
    {
        std::vector<std::string> result{};
        auto frame = load_frame_n(0);
        auto raw = frame->GetRawData();
        auto& proc_ref = frame->RefProcData();

        std::size_t count = proc_ref.size();
        if (raw)
        {
            count += raw->RefFirstAdc(  ).size();
        }
        result.reserve(count);

        if (raw)
        {
            for (const auto &adc:raw->RefFirstAdc())
            {
                auto &name = adc->GetName();
                result.emplace_back(name);
            }
        }
        for (const auto& proc:proc_ref)
        {
            auto& name = proc->GetName();
            result.emplace_back(name);
        }
        return result;
    }

    ChannelData
    read_channel_as_double(const std::string& channel_name, INT_4U frame_number = 0)
    {
        return read_as_double(channel_name, frame_number, get_frame_times( frame_number ));
    }

    std::unordered_map<std::string, std::shared_ptr<ChannelData>>
    read_channels_as_double(std::vector<std::string> channel_names, INT_4U frame_number = 0)
    {
        std::unordered_map<std::string, std::shared_ptr<ChannelData>> result;
        auto times = get_frame_times( frame_number );
        for (const auto& channel_name:channel_names)
        {
            try
            {
                auto data = std::make_shared<ChannelData>(read_as_double( channel_name, frame_number, times ));
                result.insert(std::make_pair(channel_name, data));
            }
            catch(...)
            {}
        }
        return result;
    }


private:
    ChannelData
    read_as_double(const std::string& channel_name, INT_4U frame_number, const FrameTimes& frame_times)
    {
        std::vector<double> output{};
        boost::shared_ptr<FrameCPP::Version::FrAdcData> adc{nullptr};
        boost::shared_ptr<FrameCPP::Version::FrProcData> proc{nullptr};
        boost::shared_ptr<FrameCPP::Version::FrVect> ref_data{nullptr};

        try
        {
            adc = read_stream_.ReadFrAdcData( frame_number, channel_name );
            ref_data = adc->RefData()[0];
        }
        catch (...)
        {
            proc = read_stream_.ReadFrProcData( frame_number, channel_name );
            ref_data = proc->RefData()[0];
        }

        auto data = ref_data->GetDataUncompressed();
        auto data_len = ref_data->GetNBytes(  );

        switch (ref_data->GetType())
        {
        case FrameCPP::Compression::FR_VECT_2S:
            output = to_double_vec<std::int16_t>(reinterpret_cast< const std::int16_t* >(data.get( )), data_len/sizeof(std::int16_t));
            break;
        case FrameCPP::Compression::FR_VECT_4S:
            output = to_double_vec<std::int32_t>(reinterpret_cast< const std::int32_t* >(data.get( )), data_len/sizeof(std::int32_t));
            break;
        case FrameCPP::Compression::FR_VECT_8S:
            output = to_double_vec<std::int64_t>(reinterpret_cast< const std::int64_t* >(data.get( )), data_len/sizeof(std::int64_t));
            break;
        case FrameCPP::Compression::FR_VECT_4R:
            output = to_double_vec<float>(reinterpret_cast< const float* >(data.get( )), data_len/sizeof(float));
            break;
        case FrameCPP::Compression::FR_VECT_8R:
            output = to_double_vec<double>(reinterpret_cast< const double* >(data.get( )), data_len/sizeof(double));
            break;
        case FrameCPP::Compression::FR_VECT_4U:
            output = to_double_vec<std::uint32_t>(reinterpret_cast< const std::uint32_t* >(data.get( )), data_len/sizeof(std::uint32_t));
            break;
        default:
            throw std::runtime_error("unsupported data type");
        }
        return {channel_name, frame_times.gps, frame_times.dt, output};
    }

    FrameTimes
    get_frame_times(INT_4U frame_number)
    {
        FrameTimes result{};
        auto frame = load_frame_n(frame_number);
        result.gps = frame->GetGTime(  ).GetSeconds();
        result.dt = static_cast<std::uint64_t>(frame->GetDt());
        return result;
    }

    boost::shared_ptr<FrameCPP::FrameH>
    load_frame_n(INT_4U index)
    {
        auto it = frame_cache_.find(index);
        if (it != frame_cache_.end())
        {
            return it->second;
        }
        auto frame = read_stream_.ReadFrameN(index);
        frame_cache_.insert(std::make_pair(index, frame));
        return frame;
    }
    // void
    // load_toc()
    // {
    //     if (toc_)
    //     {
    //         return;
    //     }
    //     toc_ = read_stream_.GetTOC(  );
    // }

    FrameCPP::IFrameFStream read_stream_;
    std::unordered_map<INT_4U, boost::shared_ptr<FrameCPP::FrameH>> frame_cache_{};
    //boost::shared_ptr<const FrameCPP::FrTOC> toc_{nullptr};
};

// Python bindings. pybind11 requires the module name given here to match the
// file name of the built extension (simple_frames).
PYBIND11_MODULE(simple_frames, m) {
    m.doc() = "A simple interface for reading (and maybe eventually writing) frames";

    pybind11::enum_<DataType>(m, "DataType", "Element type of channel data")
    .value("Int16", DataType::Int16)
    .value("Int32", DataType::Int32)
    .value("Int64", DataType::Int64)
    .value("Float32", DataType::Float32)
    .value("Float64", DataType::Float64)
    .value("UInt32", DataType::UInt32)
    .export_values();

    pybind11::class_<ChannelData, std::shared_ptr<ChannelData>>(m, "ChannelData",
        "Samples of one channel from one frame, converted to float")
    .def_property_readonly( "name", &ChannelData::name, "Channel name" )
    .def_property_readonly( "gps_start", &ChannelData::gps_start, "GPS start time in whole seconds" )
    .def_property_readonly( "gps_end", &ChannelData::gps_end, "gps_start + duration" )
    .def_property_readonly( "duration", &ChannelData::duration, "Frame duration as an integer" )
    .def_property_readonly( "data_rate", &ChannelData::data_rate,
        "len(data) / duration; 0.0 if there is no data or the duration is 0" )
    .def_property_readonly( "data_type", &ChannelData::data_type,
        "Always DataType.Float64: samples are converted to float" )
    .def_property_readonly( "data", &ChannelData::data,
        "Samples as a list of floats (a new list is built on every access)" );

    pybind11::class_<SimpleFrame>(m, "SimpleFrame", "Reader for a single frame file")
    .def(pybind11::init<const std::string&>(), pybind11::arg("filename"), "Open a frame file for reading")
    .def("num_frames", &SimpleFrame::num_frames, "Number of frames in the file")
    .def("num_channels", &SimpleFrame::num_channels,
         "Number of ADC and processed-data channels in the first frame")
    .def("channel_names", &SimpleFrame::channel_names,
         "Names of the ADC channels, then the processed-data channels, in the first frame")
    .def("read_channel_as_double", &SimpleFrame::read_channel_as_double,
         pybind11::arg("channel_name"), pybind11::arg("frame_number") = 0,
         "Read one channel from the given frame as a ChannelData; raises if it cannot be read")
    .def("read_channels_as_double", &SimpleFrame::read_channels_as_double,
         pybind11::arg("channel_names"), pybind11::arg("frame_number") = 0,
         "Read several channels from the given frame into a dict keyed by channel name; "
         "channels that cannot be read are omitted");
}