2012/10/20

C++で2種類のコンテナを同時にソートする

C++でデータを扱っていると,2種類のコンテナを同時にソートしたい場合が時々出てくる.どういうことかというと,
// Demonstrates aoba::sort_by_key: `keys` is sorted in ascending order and
// `values` is permuted in lockstep, so every value stays paired with the key
// it started next to.
int main() {
    int keys[]   = {5, 3, 1, 2, 4};
    int values[] = {1, 2, 3, 4, 5};
    // Derive the element count from the array instead of repeating a literal.
    const std::size_t count = sizeof(keys) / sizeof(keys[0]);
    aoba::sort_by_key(keys, keys + count, values);
    for (std::size_t i = 0; i < count; ++i) {
        std::cout << keys[i] << " ";   // 1 2 3 4 5
    }
    std::cout << '\n';
    for (std::size_t i = 0; i < count; ++i) {
        std::cout << values[i] << " "; // 3 4 2 5 1
    }
    std::cout << '\n';
    return 0;
}
と,片方のソート結果を元にもう片方をソートするような感じ.これはCUDAライブラリであるthrustに標準搭載されており,sort_by_keyという名前がついている.

さて,こういうときに一からソートアルゴリズムを組み直すなんてことは馬鹿らしい.じゃあどうするか.基本的には既存のstd::sortに渡すイテレータを変形することで,なんとか解決しようとする.そうなると,boostにzip_iteratorが存在することに気付けば,イテレータをboost::tupleの組で表現することによって,まとめてソートさせればいいような気がしてくる.こんな風に:
// Naive attempt: zip the key and value ranges with boost::zip_iterator and
// sort the zipped range with a comparator on the tuples.
// keys_first/keys_last/values_first/values_last and compare_functor are
// placeholders for the caller's own ranges and comparator.
std::sort(
    boost::make_zip_iterator(
        boost::make_tuple(
            keys_first, values_first)),
    boost::make_zip_iterator(
        boost::make_tuple(
            keys_last, values_last)),
    compare_functor());
しかし,そうは問屋が卸さない.何故なら,boost::zip_iteratorはWritableではないからだ.これはソートアルゴリズムの要件を満たしていないため,適用させようとしてもコンパイルエラーが発生してしまう!

それじゃあどのようにして解決するかというと,一番楽なのは,Writableコンセプトを満たしたzipイテレータを新しく作り,それを用いてstd::sortを行う方法だ.とはいえ,C++でそれをやるのは若干手間のかかる作業になる.具体的にはこんな感じで実装した:
#pragma once
#include <algorithm>
#include <functional>
#include <iterator>
#include <utility>
#include <boost/iterator/iterator_facade.hpp>
#include <boost/tuple/tuple.hpp>
namespace aoba {
namespace detail {
template<typename KeyIterator, typename ValueIterator>
struct sort_keyvalue_iter_helper_type {
typedef boost::tuple<
typename std::iterator_traits<KeyIterator>::value_type,
typename std::iterator_traits<ValueIterator>::value_type> value_type;
typedef boost::tuple<
typename std::iterator_traits<KeyIterator>::reference,
typename std::iterator_traits<ValueIterator>::reference> reference;
};
template<typename KeyIterator, typename ValueIterator>
class sort_keyvalue_iterator
: public boost::iterator_facade<
sort_keyvalue_iterator<KeyIterator, ValueIterator>,
typename sort_keyvalue_iter_helper_type<
KeyIterator, ValueIterator>::value_type,
std::random_access_iterator_tag,
typename sort_keyvalue_iter_helper_type<
KeyIterator, ValueIterator>::reference,
typename std::iterator_traits<KeyIterator>::difference_type>
{
public:
sort_keyvalue_iterator() {}
sort_keyvalue_iterator(KeyIterator key_iter, ValueIterator value_iter)
: m_key_iter(key_iter), m_value_iter(value_iter) {}
private:
KeyIterator m_key_iter;
ValueIterator m_value_iter;
friend class boost::iterator_core_access;
void increment()
{
++m_key_iter;
++m_value_iter;
}
void decrement()
{
--m_key_iter;
--m_value_iter;
}
bool equal(const sort_keyvalue_iterator& other) const
{
return (m_key_iter == other.m_key_iter);
}
typename sort_keyvalue_iter_helper_type<
KeyIterator, ValueIterator>::reference dereference() const
{
return (typename sort_keyvalue_iter_helper_type<
KeyIterator, ValueIterator>::reference(
*m_key_iter, *m_value_iter));
}
void advance(
typename std::iterator_traits<KeyIterator>::difference_type n)
{
m_key_iter += n;
m_value_iter += n;
}
typename std::iterator_traits<KeyIterator>::difference_type
distance_to(const sort_keyvalue_iterator& other) const
{
return std::distance(m_key_iter, other.m_key_iter);
}
};
template<typename KeyIterator, typename ValueIterator, typename Compare>
struct sort_keyvalue_iter_compare
: public std::binary_function<
typename sort_keyvalue_iter_helper_type<
KeyIterator, ValueIterator>::value_type,
typename sort_keyvalue_iter_helper_type<
KeyIterator, ValueIterator>::value_type, bool>
{
public:
typedef typename sort_keyvalue_iter_helper_type<
KeyIterator, ValueIterator>::value_type T;
sort_keyvalue_iter_compare(Compare comp) : m_comp(comp) {}
bool operator()(const T& left, const T& right)
{
return m_comp(boost::get<0>(left), boost::get<0>(right));
}
private:
Compare m_comp;
};
template<typename KeyIterator, typename ValueIterator>
sort_keyvalue_iterator<KeyIterator, ValueIterator>
make_sort_keyvalue_iterator(
KeyIterator sort_iter, ValueIterator permute_iter)
{
return sort_keyvalue_iterator<KeyIterator, ValueIterator>(
sort_iter, permute_iter);
}
} // namespace detail
template<typename KeyIterator, typename ValueIterator, typename Compare>
void sort_by_key(
KeyIterator keys_first, KeyIterator keys_last,
ValueIterator values_begin, Compare comp)
{
std::sort(
detail::make_sort_keyvalue_iterator(
keys_first, values_begin),
detail::make_sort_keyvalue_iterator(
keys_last, values_begin + std::distance(keys_first, keys_last)),
detail::sort_keyvalue_iter_compare<
KeyIterator, ValueIterator, Compare>(comp));
}
// Convenience overload: orders keys with std::less on the key iterator's
// value_type.
template<typename KeyIterator, typename ValueIterator>
void sort_by_key(
    KeyIterator keys_first, KeyIterator keys_last,
    ValueIterator values_begin)
{
    typedef typename std::iterator_traits<KeyIterator>::value_type key_type;
    typedef std::less<key_type> ascending;
    sort_by_key(keys_first, keys_last, values_begin, ascending());
}
template<typename KeyIterator, typename ValueIterator, typename Compare>
void stable_sort_by_key(
KeyIterator keys_first, KeyIterator keys_last,
ValueIterator values_begin, Compare comp)
{
std::stable_sort(
detail::make_sort_keyvalue_iterator(
keys_first, values_begin),
detail::make_sort_keyvalue_iterator(
keys_last, values_begin + std::distance(keys_first, keys_last)),
detail::sort_keyvalue_iter_compare<
KeyIterator, ValueIterator, Compare>(comp));
}
// Convenience overload: stable sort by key using std::less on the key
// iterator's value_type.
template<typename KeyIterator, typename ValueIterator>
void stable_sort_by_key(
    KeyIterator keys_first, KeyIterator keys_last,
    ValueIterator values_begin)
{
    typedef typename std::iterator_traits<KeyIterator>::value_type key_type;
    typedef std::less<key_type> ascending;
    stable_sort_by_key(keys_first, keys_last, values_begin, ascending());
}
template<typename KeyIterator, typename ValueIterator, typename Compare>
void partial_sort_by_key(
KeyIterator keys_first, KeyIterator keys_middle, KeyIterator keys_last,
ValueIterator values_first, Compare comp)
{
std::partial_sort(
detail::make_sort_keyvalue_iterator(
keys_first,
values_first),
detail::make_sort_keyvalue_iterator(
keys_middle,
values_first + std::distance(keys_first, keys_middle)),
detail::make_sort_keyvalue_iterator(
keys_last,
values_first + std::distance(keys_first, keys_last)),
detail::sort_keyvalue_iter_compare<
KeyIterator, ValueIterator, Compare>(comp));
}
// Convenience overload: partial sort by key using std::less on the key
// iterator's value_type.
template<typename KeyIterator, typename ValueIterator>
void partial_sort_by_key(
    KeyIterator keys_first, KeyIterator keys_middle, KeyIterator keys_last,
    ValueIterator values_first)
{
    typedef typename std::iterator_traits<KeyIterator>::value_type key_type;
    typedef std::less<key_type> ascending;
    partial_sort_by_key(
        keys_first, keys_middle, keys_last, values_first, ascending());
}
template<typename KeyIterator, typename ValueIterator,
typename OutputKeyIterator, typename OutputValueIterator,
typename Compare>
void partial_sort_copy_by_key(
KeyIterator keys_first, KeyIterator keys_last,
ValueIterator values_first,
OutputKeyIterator result_keys_first, OutputKeyIterator result_keys_last,
OutputValueIterator result_values_first, Compare comp)
{
std::partial_sort_copy(
detail::make_sort_keyvalue_iterator(
keys_first, values_first),
detail::make_sort_keyvalue_iterator(
keys_last,
values_first + std::distance(keys_first, keys_last)),
detail::make_sort_keyvalue_iterator(
result_keys_first,
result_values_first),
detail::make_sort_keyvalue_iterator(
result_keys_last,
result_values_first +
std::distance(result_keys_first, result_keys_last)),
detail::sort_keyvalue_iter_compare<
KeyIterator, ValueIterator, Compare>(comp));
}
template<typename KeyIterator, typename ValueIterator,
typename OutputKeyIterator, typename OutputValueIterator>
void partial_sort_copy_by_key(
KeyIterator keys_first, KeyIterator keys_last,
ValueIterator values_first,
OutputKeyIterator result_keys_first, OutputKeyIterator result_keys_last,
OutputValueIterator result_values_first)
{
partial_sort_copy_by_key(
keys_first, keys_last, values_first,
result_keys_first, result_keys_last, result_values_first,
std::less<typename std::iterator_traits<KeyIterator>::value_type>());
}
template<typename KeyIterator, typename ValueIterator, typename Compare>
void nth_element_by_key(
KeyIterator keys_first, KeyIterator nth_keys, KeyIterator keys_last,
ValueIterator values_first, Compare comp)
{
std::nth_element(
detail::make_sort_keyvalue_iterator(
keys_first, values_first),
detail::make_sort_keyvalue_iterator(
nth_keys,
values_first + std::distance(keys_first, nth_keys)),
detail::make_sort_keyvalue_iterator(
keys_last,
values_first + std::distance(keys_first, keys_last)),
detail::sort_keyvalue_iter_compare<
KeyIterator, ValueIterator, Compare>(comp));
}
// Convenience overload: nth_element by key using std::less on the key
// iterator's value_type.
template<typename KeyIterator, typename ValueIterator>
void nth_element_by_key(
    KeyIterator keys_first, KeyIterator nth_keys, KeyIterator keys_last,
    ValueIterator values_first)
{
    typedef typename std::iterator_traits<KeyIterator>::value_type key_type;
    typedef std::less<key_type> ascending;
    nth_element_by_key(
        keys_first, nth_keys, keys_last, values_first, ascending());
}
} // namespace aoba
view raw sortings.hpp hosted with ❤ by GitHub
実装したといっても,実際には『Sorting two arrays simultaneously』の内容を若干手直ししただけなので基本的にはそんなに大変じゃないし,ただ利用するぶんには単純にsort_by_key()を呼び出すだけでよい.

しかし,なんでC++だとこんなに行数が多くなってしまうのか.Haskellなら一行で済むのに.