00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021 #ifndef FST_LIB_ENCODE_H__
00022 #define FST_LIB_ENCODE_H__
00023
00024 #include <climits>
00025 #include <tr1/unordered_map>
00026 using std::tr1::unordered_map;
00027 using std::tr1::unordered_multimap;
00028 #include <string>
00029 #include <vector>
00030 using std::vector;
00031 #include <fst/map.h>
00032 #include <fst/rmfinalepsilon.h>
00033
00034 namespace fst {
00035
00036 static const uint32 kEncodeLabels = 0x0001;
00037 static const uint32 kEncodeWeights = 0x0002;
00038 static const uint32 kEncodeFlags = 0x0003;
00039
00040 static const uint32 kEncodeHasISymbols = 0x0004;
00041 static const uint32 kEncodeHasOSymbols = 0x0008;
00042
00043 enum EncodeType { ENCODE = 1, DECODE = 2 };
00044
00045
00046 static const int32 kEncodeMagicNumber = 2129983209;
00047
00048
00049
00050
00051
00052
00053
00054 template <class A> class EncodeTable {
00055 public:
00056 typedef typename A::Label Label;
00057 typedef typename A::Weight Weight;
00058
00059
00060 struct Tuple {
00061 Tuple() {}
00062 Tuple(Label ilabel_, Label olabel_, Weight weight_)
00063 : ilabel(ilabel_), olabel(olabel_), weight(weight_) {}
00064 Tuple(const Tuple& tuple)
00065 : ilabel(tuple.ilabel), olabel(tuple.olabel), weight(tuple.weight) {}
00066
00067 Label ilabel;
00068 Label olabel;
00069 Weight weight;
00070 };
00071
00072
00073 class TupleEqual {
00074 public:
00075 bool operator()(const Tuple* x, const Tuple* y) const {
00076 return (x->ilabel == y->ilabel &&
00077 x->olabel == y->olabel &&
00078 x->weight == y->weight);
00079 }
00080 };
00081
00082
00083
00084 class TupleKey {
00085 public:
00086 TupleKey()
00087 : encode_flags_(kEncodeLabels | kEncodeWeights) {}
00088
00089 TupleKey(const TupleKey& key)
00090 : encode_flags_(key.encode_flags_) {}
00091
00092 explicit TupleKey(uint32 encode_flags)
00093 : encode_flags_(encode_flags) {}
00094
00095 size_t operator()(const Tuple* x) const {
00096 size_t hash = x->ilabel;
00097 const int lshift = 5;
00098 const int rshift = CHAR_BIT * sizeof(size_t) - 5;
00099 if (encode_flags_ & kEncodeLabels)
00100 hash = hash << lshift ^ hash >> rshift ^ x->olabel;
00101 if (encode_flags_ & kEncodeWeights)
00102 hash = hash << lshift ^ hash >> rshift ^ x->weight.Hash();
00103 return hash;
00104 }
00105
00106 private:
00107 int32 encode_flags_;
00108 };
00109
00110 typedef unordered_map<const Tuple*,
00111 Label,
00112 TupleKey,
00113 TupleEqual> EncodeHash;
00114
00115 explicit EncodeTable(uint32 encode_flags)
00116 : flags_(encode_flags),
00117 encode_hash_(1024, TupleKey(encode_flags)),
00118 isymbols_(0), osymbols_(0) {}
00119
00120 ~EncodeTable() {
00121 for (size_t i = 0; i < encode_tuples_.size(); ++i) {
00122 delete encode_tuples_[i];
00123 }
00124 delete isymbols_;
00125 delete osymbols_;
00126 }
00127
00128
00129 Label Encode(const A &arc) {
00130 const Tuple tuple(arc.ilabel,
00131 flags_ & kEncodeLabels ? arc.olabel : 0,
00132 flags_ & kEncodeWeights ? arc.weight : Weight::One());
00133 typename EncodeHash::const_iterator it = encode_hash_.find(&tuple);
00134 if (it == encode_hash_.end()) {
00135 encode_tuples_.push_back(new Tuple(tuple));
00136 encode_hash_[encode_tuples_.back()] = encode_tuples_.size();
00137 return encode_tuples_.size();
00138 } else {
00139 return it->second;
00140 }
00141 }
00142
00143
00144 Label GetLabel(const A &arc) const {
00145 const Tuple tuple(arc.ilabel,
00146 flags_ & kEncodeLabels ? arc.olabel : 0,
00147 flags_ & kEncodeWeights ? arc.weight : Weight::One());
00148 typename EncodeHash::const_iterator it = encode_hash_.find(&tuple);
00149 if (it == encode_hash_.end()) {
00150 return kNoLabel;
00151 } else {
00152 return it->second;
00153 }
00154 }
00155
00156
00157 const Tuple* Decode(Label key) const {
00158 if (key < 1 || key > encode_tuples_.size()) {
00159 LOG(FATAL) << "EncodeTable::Decode: unknown decode key: " << key;
00160 return 0;
00161 }
00162 return encode_tuples_[key - 1];
00163 }
00164
00165 size_t Size() const { return encode_tuples_.size(); }
00166
00167 bool Write(ostream &strm, const string &source) const;
00168
00169 static EncodeTable<A> *Read(istream &strm, const string &source);
00170
00171 const uint32 flags() const { return flags_ & kEncodeFlags; }
00172
00173 int RefCount() const { return ref_count_.count(); }
00174 int IncrRefCount() { return ref_count_.Incr(); }
00175 int DecrRefCount() { return ref_count_.Decr(); }
00176
00177
00178 SymbolTable *InputSymbols() const { return isymbols_; }
00179
00180 SymbolTable *OutputSymbols() const { return osymbols_; }
00181
00182 void SetInputSymbols(const SymbolTable* syms) {
00183 if (isymbols_) delete isymbols_;
00184 if (syms) {
00185 isymbols_ = syms->Copy();
00186 flags_ |= kEncodeHasISymbols;
00187 } else {
00188 isymbols_ = 0;
00189 flags_ &= ~kEncodeHasISymbols;
00190 }
00191 }
00192
00193 void SetOutputSymbols(const SymbolTable* syms) {
00194 if (osymbols_) delete osymbols_;
00195 if (syms) {
00196 osymbols_ = syms->Copy();
00197 flags_ |= kEncodeHasOSymbols;
00198 } else {
00199 osymbols_ = 0;
00200 flags_ &= ~kEncodeHasOSymbols;
00201 }
00202 }
00203
00204 private:
00205 uint32 flags_;
00206 vector<Tuple*> encode_tuples_;
00207 EncodeHash encode_hash_;
00208 RefCounter ref_count_;
00209 SymbolTable *isymbols_;
00210 SymbolTable *osymbols_;
00211
00212 DISALLOW_COPY_AND_ASSIGN(EncodeTable);
00213 };
00214
00215 template <class A> inline
00216 bool EncodeTable<A>::Write(ostream &strm, const string &source) const {
00217 WriteType(strm, kEncodeMagicNumber);
00218 WriteType(strm, flags_);
00219 int64 size = encode_tuples_.size();
00220 WriteType(strm, size);
00221 for (size_t i = 0; i < size; ++i) {
00222 const Tuple* tuple = encode_tuples_[i];
00223 WriteType(strm, tuple->ilabel);
00224 WriteType(strm, tuple->olabel);
00225 tuple->weight.Write(strm);
00226 }
00227
00228 if (flags_ & kEncodeHasISymbols)
00229 isymbols_->Write(strm);
00230
00231 if (flags_ & kEncodeHasOSymbols)
00232 osymbols_->Write(strm);
00233
00234 strm.flush();
00235 if (!strm) {
00236 LOG(ERROR) << "EncodeTable::Write: write failed: " << source;
00237 return false;
00238 }
00239 return true;
00240 }
00241
00242 template <class A> inline
00243 EncodeTable<A> *EncodeTable<A>::Read(istream &strm, const string &source) {
00244 int32 magic_number = 0;
00245 ReadType(strm, &magic_number);
00246 if (magic_number != kEncodeMagicNumber) {
00247 LOG(ERROR) << "EncodeTable::Read: Bad encode table header: " << source;
00248 return 0;
00249 }
00250 uint32 flags;
00251 ReadType(strm, &flags);
00252 EncodeTable<A> *table = new EncodeTable<A>(flags);
00253
00254 int64 size;
00255 ReadType(strm, &size);
00256 if (!strm) {
00257 LOG(ERROR) << "EncodeTable::Read: read failed: " << source;
00258 return 0;
00259 }
00260
00261 for (size_t i = 0; i < size; ++i) {
00262 Tuple* tuple = new Tuple();
00263 ReadType(strm, &tuple->ilabel);
00264 ReadType(strm, &tuple->olabel);
00265 tuple->weight.Read(strm);
00266 if (!strm) {
00267 LOG(ERROR) << "EncodeTable::Read: read failed: " << source;
00268 return 0;
00269 }
00270 table->encode_tuples_.push_back(tuple);
00271 table->encode_hash_[table->encode_tuples_.back()] =
00272 table->encode_tuples_.size();
00273 }
00274
00275 if (flags & kEncodeHasISymbols)
00276 table->isymbols_ = SymbolTable::Read(strm, source);
00277
00278 if (flags & kEncodeHasOSymbols)
00279 table->osymbols_ = SymbolTable::Read(strm, source);
00280
00281 return table;
00282 }
00283
00284
00285
00286
00287
00288
00289
00290
00291
00292
00293
00294
00295
00296
00297
00298
00299
00300
00301
00302
00303 template <class A> class EncodeMapper {
00304 typedef typename A::Weight Weight;
00305 typedef typename A::Label Label;
00306 public:
00307 EncodeMapper(uint32 flags, EncodeType type)
00308 : flags_(flags), type_(type),
00309 table_(new EncodeTable<A>(flags)) {}
00310
00311 EncodeMapper(const EncodeMapper& mapper)
00312 : flags_(mapper.flags_),
00313 type_(mapper.type_),
00314 table_(mapper.table_) {
00315 table_->IncrRefCount();
00316 }
00317
00318
00319 EncodeMapper(const EncodeMapper& mapper, EncodeType type)
00320 : flags_(mapper.flags_),
00321 type_(type),
00322 table_(mapper.table_) {
00323 table_->IncrRefCount();
00324 }
00325
00326 ~EncodeMapper() {
00327 if (!table_->DecrRefCount()) delete table_;
00328 }
00329
00330 A operator()(const A &arc);
00331
00332 MapFinalAction FinalAction() const {
00333 return (type_ == ENCODE && (flags_ & kEncodeWeights)) ?
00334 MAP_REQUIRE_SUPERFINAL : MAP_NO_SUPERFINAL;
00335 }
00336
00337 MapSymbolsAction InputSymbolsAction() const { return MAP_CLEAR_SYMBOLS; }
00338
00339 MapSymbolsAction OutputSymbolsAction() const { return MAP_CLEAR_SYMBOLS;}
00340
00341 uint64 Properties(uint64 props) {
00342 uint64 mask = kFstProperties;
00343 if (flags_ & kEncodeLabels)
00344 mask &= kILabelInvariantProperties & kOLabelInvariantProperties;
00345 if (flags_ & kEncodeWeights)
00346 mask &= kILabelInvariantProperties & kWeightInvariantProperties &
00347 (type_ == ENCODE ? kAddSuperFinalProperties :
00348 kRmSuperFinalProperties);
00349 return props & mask;
00350 }
00351
00352 const uint32 flags() const { return flags_; }
00353 const EncodeType type() const { return type_; }
00354 const EncodeTable<A> &table() const { return *table_; }
00355
00356 bool Write(ostream &strm, const string& source) {
00357 return table_->Write(strm, source);
00358 }
00359
00360 bool Write(const string& filename) {
00361 ofstream strm(filename.c_str(), ofstream::out | ofstream::binary);
00362 if (!strm) {
00363 LOG(ERROR) << "EncodeMap: Can't open file: " << filename;
00364 return false;
00365 }
00366 return Write(strm, filename);
00367 }
00368
00369 static EncodeMapper<A> *Read(istream &strm,
00370 const string& source,
00371 EncodeType type = ENCODE) {
00372 EncodeTable<A> *table = table->Read(strm, source);
00373 return table ? new EncodeMapper(table->flags(), type, table) : 0;
00374 }
00375
00376 static EncodeMapper<A> *Read(const string& filename,
00377 EncodeType type = ENCODE) {
00378 ifstream strm(filename.c_str(), ifstream::in | ifstream::binary);
00379 if (!strm) {
00380 LOG(ERROR) << "EncodeMap: Can't open file: " << filename;
00381 return false;
00382 }
00383 return Read(strm, filename, type);
00384 }
00385
00386 SymbolTable *InputSymbols() const { return table_->InputSymbols(); }
00387
00388 SymbolTable *OutputSymbols() const { return table_->OutputSymbols(); }
00389
00390 void SetInputSymbols(const SymbolTable* syms) {
00391 table_->SetInputSymbols(syms);
00392 }
00393
00394 void SetOutputSymbols(const SymbolTable* syms) {
00395 table_->SetOutputSymbols(syms);
00396 }
00397
00398 private:
00399 uint32 flags_;
00400 EncodeType type_;
00401 EncodeTable<A>* table_;
00402
00403 explicit EncodeMapper(uint32 flags, EncodeType type, EncodeTable<A> *table)
00404 : flags_(flags), type_(type), table_(table) {}
00405 void operator=(const EncodeMapper &);
00406 };
00407
00408 template <class A> inline
00409 A EncodeMapper<A>::operator()(const A &arc) {
00410 if (type_ == ENCODE) {
00411 if ((arc.nextstate == kNoStateId && !(flags_ & kEncodeWeights)) ||
00412 (arc.nextstate == kNoStateId && (flags_ & kEncodeWeights) &&
00413 arc.weight == Weight::Zero())) {
00414 return arc;
00415 } else {
00416 Label label = table_->Encode(arc);
00417 return A(label,
00418 flags_ & kEncodeLabels ? label : arc.olabel,
00419 flags_ & kEncodeWeights ? Weight::One() : arc.weight,
00420 arc.nextstate);
00421 }
00422 } else if (type_ == DECODE) {
00423 if (arc.nextstate == kNoStateId) {
00424 return arc;
00425 } else {
00426 if (arc.ilabel == 0) return arc;
00427 if (flags_ & kEncodeLabels && arc.ilabel != arc.olabel)
00428 LOG(FATAL) << "EncodeMapper: Label-encoded arc has different "
00429 "input and output labels";
00430 if (flags_ & kEncodeWeights && arc.weight != Weight::One())
00431 LOG(FATAL) <<
00432 "EncodeMapper: Weight-encoded arc has non-trivial weight";
00433 const typename EncodeTable<A>::Tuple* tuple =
00434 table_->Decode(arc.ilabel);
00435 return A(tuple->ilabel,
00436 flags_ & kEncodeLabels ? tuple->olabel : arc.olabel,
00437 flags_ & kEncodeWeights ? tuple->weight : arc.weight,
00438 arc.nextstate);
00439 }
00440 } else {
00441 LOG(FATAL) << "EncodeMapper: Unknown operation";
00442 }
00443 }
00444
00445
00446
00447 template<class A> inline
00448 void Encode(MutableFst<A> *fst, EncodeMapper<A>* mapper) {
00449 mapper->SetInputSymbols(fst->InputSymbols());
00450 mapper->SetOutputSymbols(fst->OutputSymbols());
00451 Map(fst, mapper);
00452 }
00453
00454 template<class A> inline
00455 void Decode(MutableFst<A>* fst, const EncodeMapper<A>& mapper) {
00456 Map(fst, EncodeMapper<A>(mapper, DECODE));
00457 RmFinalEpsilon(fst);
00458 fst->SetInputSymbols(mapper.InputSymbols());
00459 fst->SetOutputSymbols(mapper.OutputSymbols());
00460 }
00461
00462
00463
00464
00465
00466
00467
00468
00469 template <class A>
00470 class EncodeFst : public MapFst<A, A, EncodeMapper<A> > {
00471 public:
00472 typedef A Arc;
00473 typedef EncodeMapper<A> C;
00474
00475 EncodeFst(const Fst<A> &fst, EncodeMapper<A>* encoder)
00476 : MapFst<A, A, C>(fst, encoder, MapFstOptions()) {
00477 encoder->SetInputSymbols(fst.InputSymbols());
00478 encoder->SetOutputSymbols(fst.OutputSymbols());
00479 }
00480
00481 EncodeFst(const Fst<A> &fst, const EncodeMapper<A>& encoder)
00482 : MapFst<A, A, C>(fst, encoder, MapFstOptions()) {}
00483
00484
00485 EncodeFst(const EncodeFst<A> &fst, bool copy = false)
00486 : MapFst<A, A, C>(fst, copy) {}
00487
00488
00489 virtual EncodeFst<A> *Copy(bool safe = false) const {
00490 if (safe)
00491 LOG(FATAL) << "EncodeFst::Copy(true): not allowed.";
00492 return new EncodeFst(*this);
00493 }
00494 };
00495
00496
00497
00498
00499
00500
00501
00502
00503 template <class A>
00504 class DecodeFst : public MapFst<A, A, EncodeMapper<A> > {
00505 public:
00506 typedef A Arc;
00507 typedef EncodeMapper<A> C;
00508 typedef MapFstImpl< A, A, EncodeMapper<A> > Impl;
00509 using ImplToFst<Impl>::GetImpl;
00510
00511 DecodeFst(const Fst<A> &fst, const EncodeMapper<A>& encoder)
00512 : MapFst<A, A, C>(fst,
00513 EncodeMapper<A>(encoder, DECODE),
00514 MapFstOptions()) {
00515 GetImpl()->SetInputSymbols(encoder.InputSymbols());
00516 GetImpl()->SetOutputSymbols(encoder.OutputSymbols());
00517 }
00518
00519
00520 DecodeFst(const DecodeFst<A> &fst, bool safe = false)
00521 : MapFst<A, A, C>(fst, safe) {}
00522
00523
00524 virtual DecodeFst<A> *Copy(bool safe = false) const {
00525 return new DecodeFst(*this, safe);
00526 }
00527 };
00528
00529
00530
00531 template <class A>
00532 class StateIterator< EncodeFst<A> >
00533 : public StateIterator< MapFst<A, A, EncodeMapper<A> > > {
00534 public:
00535 explicit StateIterator(const EncodeFst<A> &fst)
00536 : StateIterator< MapFst<A, A, EncodeMapper<A> > >(fst) {}
00537 };
00538
00539
00540
00541 template <class A>
00542 class ArcIterator< EncodeFst<A> >
00543 : public ArcIterator< MapFst<A, A, EncodeMapper<A> > > {
00544 public:
00545 ArcIterator(const EncodeFst<A> &fst, typename A::StateId s)
00546 : ArcIterator< MapFst<A, A, EncodeMapper<A> > >(fst, s) {}
00547 };
00548
00549
00550
00551 template <class A>
00552 class StateIterator< DecodeFst<A> >
00553 : public StateIterator< MapFst<A, A, EncodeMapper<A> > > {
00554 public:
00555 explicit StateIterator(const DecodeFst<A> &fst)
00556 : StateIterator< MapFst<A, A, EncodeMapper<A> > >(fst) {}
00557 };
00558
00559
00560
00561 template <class A>
00562 class ArcIterator< DecodeFst<A> >
00563 : public ArcIterator< MapFst<A, A, EncodeMapper<A> > > {
00564 public:
00565 ArcIterator(const DecodeFst<A> &fst, typename A::StateId s)
00566 : ArcIterator< MapFst<A, A, EncodeMapper<A> > >(fst, s) {}
00567 };
00568
00569
00570
00571 typedef EncodeFst<StdArc> StdEncodeFst;
00572
00573 typedef DecodeFst<StdArc> StdDecodeFst;
00574
00575 }
00576
00577 #endif /// FST_LIB_ENCODE_H__
00578